test: 增加 repository 测试
This commit is contained in:
7
.github/workflows/backend-ci.yml
vendored
7
.github/workflows/backend-ci.yml
vendored
@@ -17,9 +17,12 @@ jobs:
|
|||||||
go-version-file: backend/go.mod
|
go-version-file: backend/go.mod
|
||||||
check-latest: true
|
check-latest: true
|
||||||
cache: true
|
cache: true
|
||||||
- name: Run tests
|
- name: Unit tests
|
||||||
working-directory: backend
|
working-directory: backend
|
||||||
run: go test ./...
|
run: make test-unit
|
||||||
|
- name: Integration tests
|
||||||
|
working-directory: backend
|
||||||
|
run: make test-integration
|
||||||
|
|
||||||
golangci-lint:
|
golangci-lint:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
.PHONY: wire build build-embed
|
.PHONY: wire build build-embed test-unit test-integration test-cover-integration clean-coverage
|
||||||
|
|
||||||
wire:
|
wire:
|
||||||
@echo "生成 Wire 代码..."
|
@echo "生成 Wire 代码..."
|
||||||
@@ -14,3 +14,20 @@ build-embed:
|
|||||||
@echo "构建后端(嵌入前端)..."
|
@echo "构建后端(嵌入前端)..."
|
||||||
@go build -tags embed -o bin/server ./cmd/server
|
@go build -tags embed -o bin/server ./cmd/server
|
||||||
@echo "构建完成: bin/server (with embedded frontend)"
|
@echo "构建完成: bin/server (with embedded frontend)"
|
||||||
|
|
||||||
|
test-unit:
|
||||||
|
@go test ./... $(TEST_ARGS)
|
||||||
|
|
||||||
|
test-integration:
|
||||||
|
@go test -tags integration ./internal/repository -count=1 -race -parallel=8
|
||||||
|
|
||||||
|
test-cover-integration:
|
||||||
|
@echo "运行集成测试并生成覆盖率报告..."
|
||||||
|
@go test -tags=integration -cover -coverprofile=coverage.out -count=1 -race -parallel=8 ./internal/repository/...
|
||||||
|
@go tool cover -func=coverage.out | tail -1
|
||||||
|
@go tool cover -html=coverage.out -o coverage.html
|
||||||
|
@echo "覆盖率报告已生成: coverage.html"
|
||||||
|
|
||||||
|
clean-coverage:
|
||||||
|
@rm -f coverage.out coverage.html
|
||||||
|
@echo "覆盖率文件已清理"
|
||||||
@@ -11,8 +11,11 @@ require (
|
|||||||
github.com/google/wire v0.7.0
|
github.com/google/wire v0.7.0
|
||||||
github.com/imroc/req/v3 v3.56.0
|
github.com/imroc/req/v3 v3.56.0
|
||||||
github.com/lib/pq v1.10.9
|
github.com/lib/pq v1.10.9
|
||||||
github.com/redis/go-redis/v9 v9.3.0
|
github.com/redis/go-redis/v9 v9.7.3
|
||||||
github.com/spf13/viper v1.18.2
|
github.com/spf13/viper v1.18.2
|
||||||
|
github.com/stretchr/testify v1.11.1
|
||||||
|
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
|
||||||
|
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
golang.org/x/crypto v0.44.0
|
golang.org/x/crypto v0.44.0
|
||||||
@@ -24,52 +27,99 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
dario.cat/mergo v1.0.2 // indirect
|
||||||
|
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||||
github.com/bytedance/sonic v1.9.1 // indirect
|
github.com/bytedance/sonic v1.9.1 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||||
|
github.com/containerd/errdefs v1.0.0 // indirect
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||||
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
|
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||||
|
github.com/docker/go-connections v0.6.0 // indirect
|
||||||
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
|
github.com/ebitengine/purego v0.8.4 // indirect
|
||||||
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
|
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
github.com/google/go-querystring v1.1.0 // indirect
|
github.com/google/go-querystring v1.1.0 // indirect
|
||||||
github.com/google/subcommands v1.2.0 // indirect
|
github.com/google/subcommands v1.2.0 // indirect
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
github.com/icholy/digest v1.1.0 // indirect
|
github.com/icholy/digest v1.1.0 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||||
github.com/jackc/pgx/v5 v5.4.3 // indirect
|
github.com/jackc/pgx/v5 v5.5.4 // indirect
|
||||||
|
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/compress v1.18.1 // indirect
|
github.com/klauspost/compress v1.18.1 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||||
github.com/leodido/go-urn v1.2.4 // indirect
|
github.com/leodido/go-urn v1.2.4 // indirect
|
||||||
github.com/magiconair/properties v1.8.7 // indirect
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||||
|
github.com/magiconair/properties v1.8.10 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||||
|
github.com/mdelapenya/tlscert v0.2.0 // indirect
|
||||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||||
|
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||||
|
github.com/moby/go-archive v0.1.0 // indirect
|
||||||
|
github.com/moby/patternmatcher v0.6.0 // indirect
|
||||||
|
github.com/moby/sys/sequential v0.6.0 // indirect
|
||||||
|
github.com/moby/sys/user v0.4.0 // indirect
|
||||||
|
github.com/moby/sys/userns v0.1.0 // indirect
|
||||||
|
github.com/moby/term v0.5.0 // indirect
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
|
github.com/morikuni/aec v1.0.0 // indirect
|
||||||
|
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||||
|
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
|
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
|
||||||
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||||
github.com/quic-go/qpack v0.5.1 // indirect
|
github.com/quic-go/qpack v0.5.1 // indirect
|
||||||
github.com/quic-go/quic-go v0.56.0 // indirect
|
github.com/quic-go/quic-go v0.56.0 // indirect
|
||||||
github.com/refraction-networking/utls v1.8.1 // indirect
|
github.com/refraction-networking/utls v1.8.1 // indirect
|
||||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||||
|
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||||
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||||
github.com/spf13/afero v1.11.0 // indirect
|
github.com/spf13/afero v1.11.0 // indirect
|
||||||
github.com/spf13/cast v1.6.0 // indirect
|
github.com/spf13/cast v1.6.0 // indirect
|
||||||
github.com/spf13/pflag v1.0.5 // indirect
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
github.com/subosito/gotenv v1.6.0 // indirect
|
github.com/subosito/gotenv v1.6.0 // indirect
|
||||||
|
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
github.com/tidwall/pretty v1.2.0 // indirect
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
|
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||||
|
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
|
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||||
|
go.opentelemetry.io/otel v1.37.0 // indirect
|
||||||
|
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
||||||
|
go.opentelemetry.io/otel/sdk v1.37.0 // indirect
|
||||||
|
go.opentelemetry.io/otel/trace v1.37.0 // indirect
|
||||||
go.uber.org/atomic v1.9.0 // indirect
|
go.uber.org/atomic v1.9.0 // indirect
|
||||||
go.uber.org/multierr v1.9.0 // indirect
|
go.uber.org/multierr v1.9.0 // indirect
|
||||||
golang.org/x/arch v0.3.0 // indirect
|
golang.org/x/arch v0.3.0 // indirect
|
||||||
@@ -79,6 +129,7 @@ require (
|
|||||||
golang.org/x/sys v0.38.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
golang.org/x/text v0.31.0 // indirect
|
golang.org/x/text v0.31.0 // indirect
|
||||||
golang.org/x/tools v0.38.0 // indirect
|
golang.org/x/tools v0.38.0 // indirect
|
||||||
google.golang.org/protobuf v1.31.0 // indirect
|
google.golang.org/grpc v1.75.1 // indirect
|
||||||
|
google.golang.org/protobuf v1.36.10 // indirect
|
||||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
160
backend/go.sum
160
backend/go.sum
@@ -1,3 +1,11 @@
|
|||||||
|
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
|
||||||
|
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
|
||||||
|
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
|
||||||
|
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||||
|
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
|
||||||
|
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||||
@@ -7,17 +15,43 @@ github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0
|
|||||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
|
||||||
|
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||||
|
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||||
|
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||||
|
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||||
|
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||||
|
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
|
||||||
|
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||||
|
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
||||||
|
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
|
||||||
|
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
|
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||||
|
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||||
|
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||||
|
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||||
|
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
||||||
|
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||||
|
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
|
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
||||||
|
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||||
@@ -28,6 +62,13 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
|
|||||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||||
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
|
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||||
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
|
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||||
|
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
@@ -40,9 +81,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
|||||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
|
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
|
||||||
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
|
||||||
@@ -54,6 +94,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
|||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||||
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
|
github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18=
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
|
||||||
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
|
||||||
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
|
||||||
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
||||||
@@ -64,8 +106,10 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
|
|||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
|
github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8=
|
||||||
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||||
|
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
@@ -85,36 +129,70 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
|
|||||||
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
|
||||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||||
|
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||||
|
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||||
|
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
|
||||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||||
|
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
|
||||||
|
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
|
||||||
|
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
|
||||||
|
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
|
||||||
|
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
|
||||||
|
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
|
||||||
|
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
|
||||||
|
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
|
||||||
|
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
|
||||||
|
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
|
||||||
|
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
|
||||||
|
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
|
||||||
|
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||||
|
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
|
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||||
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
|
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||||
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
|
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
|
||||||
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||||
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||||
github.com/quic-go/quic-go v0.56.0 h1:q/TW+OLismmXAehgFLczhCDTYB3bFmua4D9lsNBWxvY=
|
github.com/quic-go/quic-go v0.56.0 h1:q/TW+OLismmXAehgFLczhCDTYB3bFmua4D9lsNBWxvY=
|
||||||
github.com/quic-go/quic-go v0.56.0/go.mod h1:9gx5KsFQtw2oZ6GZTyh+7YEvOxWCL9WZAepnHxgAo6c=
|
github.com/quic-go/quic-go v0.56.0/go.mod h1:9gx5KsFQtw2oZ6GZTyh+7YEvOxWCL9WZAepnHxgAo6c=
|
||||||
github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0=
|
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||||
github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
|
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
||||||
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
|
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
|
||||||
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||||
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
|
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
|
||||||
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
|
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
|
||||||
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
|
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
|
||||||
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
|
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
|
||||||
|
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
||||||
|
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
||||||
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
|
||||||
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
|
||||||
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||||
@@ -128,6 +206,8 @@ github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMV
|
|||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||||
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
@@ -135,10 +215,16 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
|
|||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||||
|
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
|
||||||
|
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
|
||||||
|
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk=
|
||||||
|
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0/go.mod h1:h+u/2KoREGTnTl9UwrQ/g+XhasAT8E6dClclAADeXoQ=
|
||||||
|
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0 h1:OG4qwcxp2O0re7V7M9lY9w0v6wWgWf7j7rtkpAnGMd0=
|
||||||
|
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0/go.mod h1:Bc+EDhKMo5zI5V5zdBkHiMVzeAXbtI4n5isS/nzf6zw=
|
||||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
@@ -148,12 +234,36 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
|||||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
|
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||||
|
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||||
|
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||||
|
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
|
||||||
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
|
||||||
|
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
|
||||||
|
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
||||||
|
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
|
||||||
|
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
|
||||||
|
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
||||||
|
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
||||||
|
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
|
||||||
|
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
|
||||||
|
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
|
||||||
|
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
|
||||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
|
||||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||||
@@ -173,8 +283,14 @@ golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
|||||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
|
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||||
@@ -186,9 +302,15 @@ golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
|||||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
|
||||||
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
|
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
|
||||||
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 h1:i8QOKZfYg6AbGVZzUAY3LrNWCKF8O6zFisU9Wl9RER4=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4/go.mod h1:HSkG/KdJWusxU1F6CNrwNDjBMgisKxGnc5dAZfT0mjQ=
|
||||||
|
google.golang.org/grpc v1.75.1 h1:/ODCNEuf9VghjgO3rqLcfg8fiOP0nSluljWFlDxELLI=
|
||||||
|
google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
|
||||||
|
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||||
|
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
@@ -201,6 +323,6 @@ gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
|
|||||||
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
|
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
|
||||||
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
||||||
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||||
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||||
|
|||||||
580
backend/internal/repository/account_repo_integration_test.go
Normal file
580
backend/internal/repository/account_repo_integration_test.go
Normal file
@@ -0,0 +1,580 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AccountRepoSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
db *gorm.DB
|
||||||
|
repo *AccountRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.db = testTx(s.T())
|
||||||
|
s.repo = NewAccountRepository(s.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountRepoSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(AccountRepoSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Create / GetByID / Update / Delete ---
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestCreate() {
|
||||||
|
account := &model.Account{
|
||||||
|
Name: "test-create",
|
||||||
|
Platform: model.PlatformAnthropic,
|
||||||
|
Type: model.AccountTypeOAuth,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.repo.Create(s.ctx, account)
|
||||||
|
s.Require().NoError(err, "Create")
|
||||||
|
s.Require().NotZero(account.ID, "expected ID to be set")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal("test-create", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestGetByID_NotFound() {
|
||||||
|
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||||
|
s.Require().Error(err, "expected error for non-existent ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdate() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "original"})
|
||||||
|
|
||||||
|
account.Name = "updated"
|
||||||
|
err := s.repo.Update(s.ctx, account)
|
||||||
|
s.Require().NoError(err, "Update")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err, "GetByID after update")
|
||||||
|
s.Require().Equal("updated", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestDelete() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "to-delete"})
|
||||||
|
|
||||||
|
err := s.repo.Delete(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err, "Delete")
|
||||||
|
|
||||||
|
_, err = s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().Error(err, "expected error after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||||
|
|
||||||
|
err := s.repo.Delete(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err, "Delete should cascade remove bindings")
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
s.db.Model(&model.AccountGroup{}).Where("account_id = ?", account.ID).Count(&count)
|
||||||
|
s.Require().Zero(count, "expected bindings to be removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- List / ListWithFilters ---
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestList() {
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc2"})
|
||||||
|
|
||||||
|
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
|
s.Require().NoError(err, "List")
|
||||||
|
s.Require().Len(accounts, 2)
|
||||||
|
s.Require().Equal(int64(2), page.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestListWithFilters() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setup func(db *gorm.DB)
|
||||||
|
platform string
|
||||||
|
accType string
|
||||||
|
status string
|
||||||
|
search string
|
||||||
|
wantCount int
|
||||||
|
validate func(accounts []model.Account)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "filter_by_platform",
|
||||||
|
setup: func(db *gorm.DB) {
|
||||||
|
mustCreateAccount(s.T(), db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic})
|
||||||
|
mustCreateAccount(s.T(), db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI})
|
||||||
|
},
|
||||||
|
platform: model.PlatformOpenAI,
|
||||||
|
wantCount: 1,
|
||||||
|
validate: func(accounts []model.Account) {
|
||||||
|
s.Require().Equal(model.PlatformOpenAI, accounts[0].Platform)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter_by_type",
|
||||||
|
setup: func(db *gorm.DB) {
|
||||||
|
mustCreateAccount(s.T(), db, &model.Account{Name: "t1", Type: model.AccountTypeOAuth})
|
||||||
|
mustCreateAccount(s.T(), db, &model.Account{Name: "t2", Type: model.AccountTypeApiKey})
|
||||||
|
},
|
||||||
|
accType: model.AccountTypeApiKey,
|
||||||
|
wantCount: 1,
|
||||||
|
validate: func(accounts []model.Account) {
|
||||||
|
s.Require().Equal(model.AccountTypeApiKey, accounts[0].Type)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter_by_status",
|
||||||
|
setup: func(db *gorm.DB) {
|
||||||
|
mustCreateAccount(s.T(), db, &model.Account{Name: "s1", Status: model.StatusActive})
|
||||||
|
mustCreateAccount(s.T(), db, &model.Account{Name: "s2", Status: model.StatusDisabled})
|
||||||
|
},
|
||||||
|
status: model.StatusDisabled,
|
||||||
|
wantCount: 1,
|
||||||
|
validate: func(accounts []model.Account) {
|
||||||
|
s.Require().Equal(model.StatusDisabled, accounts[0].Status)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter_by_search",
|
||||||
|
setup: func(db *gorm.DB) {
|
||||||
|
mustCreateAccount(s.T(), db, &model.Account{Name: "alpha-account"})
|
||||||
|
mustCreateAccount(s.T(), db, &model.Account{Name: "beta-account"})
|
||||||
|
},
|
||||||
|
search: "alpha",
|
||||||
|
wantCount: 1,
|
||||||
|
validate: func(accounts []model.Account) {
|
||||||
|
s.Require().Contains(accounts[0].Name, "alpha")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
// 每个 case 重新获取隔离资源
|
||||||
|
db := testTx(s.T())
|
||||||
|
repo := NewAccountRepository(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
tt.setup(db)
|
||||||
|
|
||||||
|
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(accounts, tt.wantCount)
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(accounts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListByGroup / ListActive / ListByPlatform ---
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestListByGroup() {
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
|
||||||
|
acc1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Status: model.StatusActive})
|
||||||
|
acc2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Status: model.StatusActive})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
|
||||||
|
|
||||||
|
accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "ListByGroup")
|
||||||
|
s.Require().Len(accounts, 2)
|
||||||
|
// Should be ordered by priority
|
||||||
|
s.Require().Equal(acc2.ID, accounts[0].ID, "expected acc2 first (priority=1)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestListActive() {
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "active1", Status: model.StatusActive})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "inactive1", Status: model.StatusDisabled})
|
||||||
|
|
||||||
|
accounts, err := s.repo.ListActive(s.ctx)
|
||||||
|
s.Require().NoError(err, "ListActive")
|
||||||
|
s.Require().Len(accounts, 1)
|
||||||
|
s.Require().Equal("active1", accounts[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestListByPlatform() {
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
|
||||||
|
|
||||||
|
accounts, err := s.repo.ListByPlatform(s.ctx, model.PlatformAnthropic)
|
||||||
|
s.Require().NoError(err, "ListByPlatform")
|
||||||
|
s.Require().Len(accounts, 1)
|
||||||
|
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Preload and VirtualFields ---
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
||||||
|
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
|
||||||
|
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||||
|
Name: "acc1",
|
||||||
|
ProxyID: &proxy.ID,
|
||||||
|
})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().NotNil(got.Proxy, "expected Proxy preload")
|
||||||
|
s.Require().Equal(proxy.ID, got.Proxy.ID)
|
||||||
|
s.Require().Len(got.GroupIDs, 1, "expected GroupIDs to be populated")
|
||||||
|
s.Require().Equal(group.ID, got.GroupIDs[0])
|
||||||
|
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
|
||||||
|
s.Require().Equal(group.ID, got.Groups[0].ID)
|
||||||
|
|
||||||
|
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
|
||||||
|
s.Require().NoError(err, "ListWithFilters")
|
||||||
|
s.Require().Equal(int64(1), page.Total)
|
||||||
|
s.Require().Len(accounts, 1)
|
||||||
|
s.Require().NotNil(accounts[0].Proxy, "expected Proxy preload in list")
|
||||||
|
s.Require().Equal(proxy.ID, accounts[0].Proxy.ID)
|
||||||
|
s.Require().Len(accounts[0].GroupIDs, 1, "expected GroupIDs in list")
|
||||||
|
s.Require().Equal(group.ID, accounts[0].GroupIDs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
|
||||||
|
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
|
||||||
|
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc"})
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
|
||||||
|
groups, err := s.repo.GetGroups(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err, "GetGroups")
|
||||||
|
s.Require().Len(groups, 1, "expected 1 group")
|
||||||
|
s.Require().Equal(g1.ID, groups[0].ID)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.RemoveFromGroup(s.ctx, account.ID, g1.ID), "RemoveFromGroup")
|
||||||
|
groups, err = s.repo.GetGroups(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err, "GetGroups after remove")
|
||||||
|
s.Require().Empty(groups, "expected 0 groups after remove")
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{g1.ID, g2.ID}), "BindGroups")
|
||||||
|
groups, err = s.repo.GetGroups(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err, "GetGroups after bind")
|
||||||
|
s.Require().Len(groups, 2, "expected 2 groups after bind")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-empty"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
|
||||||
|
|
||||||
|
groups, err := s.repo.GetGroups(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Empty(groups, "expected 0 groups after binding empty list")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Schedulable ---
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestListSchedulable() {
|
||||||
|
now := time.Now()
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
|
||||||
|
|
||||||
|
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
|
||||||
|
|
||||||
|
future := now.Add(10 * time.Minute)
|
||||||
|
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
|
||||||
|
|
||||||
|
sched, err := s.repo.ListSchedulable(s.ctx)
|
||||||
|
s.Require().NoError(err, "ListSchedulable")
|
||||||
|
ids := idsOfAccounts(sched)
|
||||||
|
s.Require().Contains(ids, okAcc.ID)
|
||||||
|
s.Require().NotContains(ids, overloaded.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
|
||||||
|
now := time.Now()
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
|
||||||
|
|
||||||
|
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
|
||||||
|
|
||||||
|
future := now.Add(10 * time.Minute)
|
||||||
|
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
|
||||||
|
|
||||||
|
rateLimited := mustCreateAccount(s.T(), s.db, &model.Account{Name: "rl", Schedulable: true})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
|
||||||
|
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
|
||||||
|
|
||||||
|
sched, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "ListSchedulableByGroupID")
|
||||||
|
s.Require().Len(sched, 1, "expected only ok account schedulable")
|
||||||
|
s.Require().Equal(okAcc.ID, sched[0].ID)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, rateLimited.ID), "ClearRateLimit")
|
||||||
|
sched2, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "ListSchedulableByGroupID after ClearRateLimit")
|
||||||
|
s.Require().Len(sched2, 2, "expected 2 schedulable accounts after ClearRateLimit")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
|
||||||
|
|
||||||
|
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, model.PlatformAnthropic)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(accounts, 1)
|
||||||
|
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sp"})
|
||||||
|
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
|
||||||
|
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
|
||||||
|
|
||||||
|
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, model.PlatformAnthropic)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(accounts, 1)
|
||||||
|
s.Require().Equal(a1.ID, accounts[0].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-sched", Schedulable: true})
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().False(got.Schedulable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestSetOverloaded() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-over"})
|
||||||
|
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(got.OverloadUntil)
|
||||||
|
s.Require().WithinDuration(until, *got.OverloadUntil, time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestSetRateLimited() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-rl"})
|
||||||
|
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(got.RateLimitedAt)
|
||||||
|
s.Require().NotNil(got.RateLimitResetAt)
|
||||||
|
s.Require().WithinDuration(resetAt, *got.RateLimitResetAt, time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestClearRateLimit() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-clear"})
|
||||||
|
until := time.Now().Add(1 * time.Hour)
|
||||||
|
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||||
|
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, account.ID))
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Nil(got.RateLimitedAt)
|
||||||
|
s.Require().Nil(got.RateLimitResetAt)
|
||||||
|
s.Require().Nil(got.OverloadUntil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- UpdateLastUsed ---
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateLastUsed() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-used"})
|
||||||
|
s.Require().Nil(account.LastUsedAt)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(got.LastUsedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- SetError ---
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestSetError() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-err", Status: model.StatusActive})
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(model.StatusError, got.Status)
|
||||||
|
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- UpdateSessionWindow ---
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-win"})
|
||||||
|
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
|
||||||
|
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.UpdateSessionWindow(s.ctx, account.ID, &start, &end, "active"))
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(got.SessionWindowStart)
|
||||||
|
s.Require().NotNil(got.SessionWindowEnd)
|
||||||
|
s.Require().Equal("active", got.SessionWindowStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- UpdateExtra ---
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||||
|
Name: "acc-extra",
|
||||||
|
Extra: model.JSONB{"a": "1"},
|
||||||
|
})
|
||||||
|
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal("1", got.Extra["a"])
|
||||||
|
s.Require().Equal("2", got.Extra["b"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-extra-empty"})
|
||||||
|
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-nil-extra", Extra: nil})
|
||||||
|
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal("val", got.Extra["key"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetByCRSAccountID ---
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||||
|
crsID := "crs-12345"
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{
|
||||||
|
Name: "acc-crs",
|
||||||
|
Extra: model.JSONB{"crs_account_id": crsID},
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(got)
|
||||||
|
s.Require().Equal("acc-crs", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestGetByCRSAccountID_NotFound() {
|
||||||
|
got, err := s.repo.GetByCRSAccountID(s.ctx, "non-existent")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Nil(got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
|
||||||
|
got, err := s.repo.GetByCRSAccountID(s.ctx, "")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Nil(got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- BulkUpdate ---
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestBulkUpdate() {
|
||||||
|
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk1", Priority: 1})
|
||||||
|
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1})
|
||||||
|
|
||||||
|
newPriority := 99
|
||||||
|
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, ports.AccountBulkUpdate{
|
||||||
|
Priority: &newPriority,
|
||||||
|
})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().GreaterOrEqual(affected, int64(1), "expected at least one affected row")
|
||||||
|
|
||||||
|
got1, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||||
|
got2, _ := s.repo.GetByID(s.ctx, a2.ID)
|
||||||
|
s.Require().Equal(99, got1.Priority)
|
||||||
|
s.Require().Equal(99, got2.Priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
|
||||||
|
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||||
|
Name: "bulk-cred",
|
||||||
|
Credentials: model.JSONB{"existing": "value"},
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{
|
||||||
|
Credentials: model.JSONB{"new_key": "new_value"},
|
||||||
|
})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
got, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||||
|
s.Require().Equal("value", got.Credentials["existing"])
|
||||||
|
s.Require().Equal("new_value", got.Credentials["new_key"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
|
||||||
|
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
|
||||||
|
Name: "bulk-extra",
|
||||||
|
Extra: model.JSONB{"existing": "val"},
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{
|
||||||
|
Extra: model.JSONB{"new_key": "new_val"},
|
||||||
|
})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
got, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||||
|
s.Require().Equal("val", got.Extra["existing"])
|
||||||
|
s.Require().Equal("new_val", got.Extra["new_key"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
|
||||||
|
affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, ports.AccountBulkUpdate{})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Zero(affected)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
|
||||||
|
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"})
|
||||||
|
|
||||||
|
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Zero(affected)
|
||||||
|
}
|
||||||
|
|
||||||
|
func idsOfAccounts(accounts []model.Account) []int64 {
|
||||||
|
out := make([]int64, 0, len(accounts))
|
||||||
|
for i := range accounts {
|
||||||
|
out = append(out, accounts[i].ID)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
125
backend/internal/repository/api_key_cache_integration_test.go
Normal file
125
backend/internal/repository/api_key_cache_integration_test.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ApiKeyCacheSuite struct {
|
||||||
|
IntegrationRedisSuite
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing_key_returns_redis_nil",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||||
|
userID := int64(1)
|
||||||
|
|
||||||
|
_, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||||
|
|
||||||
|
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing key")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "increment_increases_count_and_sets_ttl",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||||
|
userID := int64(1)
|
||||||
|
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount")
|
||||||
|
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount 2")
|
||||||
|
|
||||||
|
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||||
|
require.NoError(s.T(), err, "GetCreateAttemptCount")
|
||||||
|
require.Equal(s.T(), 2, count, "count mismatch")
|
||||||
|
|
||||||
|
ttl, err := rdb.TTL(ctx, key).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL")
|
||||||
|
s.AssertTTLWithin(ttl, 1*time.Second, apiKeyRateLimitDuration)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete_removes_key",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||||
|
userID := int64(1)
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
|
||||||
|
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
|
||||||
|
|
||||||
|
_, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||||
|
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after delete")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
// 每个 case 重新获取隔离资源
|
||||||
|
rdb := testRedis(s.T())
|
||||||
|
cache := &apiKeyCache{rdb: rdb}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
tt.fn(ctx, rdb, cache)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyCacheSuite) TestDailyUsage() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "increment_increases_count",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||||
|
dailyKey := "daily:sk-test"
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage")
|
||||||
|
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage 2")
|
||||||
|
|
||||||
|
n, err := rdb.Get(ctx, dailyKey).Int()
|
||||||
|
require.NoError(s.T(), err, "Get dailyKey")
|
||||||
|
require.Equal(s.T(), 2, n, "expected daily usage=2")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "set_expiry_sets_ttl",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||||
|
dailyKey := "daily:sk-test-expiry"
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey))
|
||||||
|
require.NoError(s.T(), cache.SetDailyUsageExpiry(ctx, dailyKey, 1*time.Hour), "SetDailyUsageExpiry")
|
||||||
|
|
||||||
|
ttl, err := rdb.TTL(ctx, dailyKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL dailyKey")
|
||||||
|
require.Greater(s.T(), ttl, time.Duration(0), "expected ttl > 0")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
rdb := testRedis(s.T())
|
||||||
|
cache := &apiKeyCache{rdb: rdb}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
tt.fn(ctx, rdb, cache)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApiKeyCacheSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(ApiKeyCacheSuite))
|
||||||
|
}
|
||||||
355
backend/internal/repository/api_key_repo_integration_test.go
Normal file
355
backend/internal/repository/api_key_repo_integration_test.go
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ApiKeyRepoSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
db *gorm.DB
|
||||||
|
repo *ApiKeyRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.db = testTx(s.T())
|
||||||
|
s.repo = NewApiKeyRepository(s.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApiKeyRepoSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(ApiKeyRepoSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Create / GetByID / GetByKey ---
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestCreate() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
|
||||||
|
|
||||||
|
key := &model.ApiKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-create-test",
|
||||||
|
Name: "Test Key",
|
||||||
|
Status: model.StatusActive,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.repo.Create(s.ctx, key)
|
||||||
|
s.Require().NoError(err, "Create")
|
||||||
|
s.Require().NotZero(key.ID, "expected ID to be set")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal("sk-create-test", got.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
|
||||||
|
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||||
|
s.Require().Error(err, "expected error for non-existent ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbykey@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-key"})
|
||||||
|
|
||||||
|
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-getbykey",
|
||||||
|
Name: "My Key",
|
||||||
|
GroupID: &group.ID,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||||
|
s.Require().NoError(err, "GetByKey")
|
||||||
|
s.Require().Equal(key.ID, got.ID)
|
||||||
|
s.Require().NotNil(got.User, "expected User preload")
|
||||||
|
s.Require().Equal(user.ID, got.User.ID)
|
||||||
|
s.Require().NotNil(got.Group, "expected Group preload")
|
||||||
|
s.Require().Equal(group.ID, got.Group.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
|
||||||
|
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
|
||||||
|
s.Require().Error(err, "expected error for non-existent key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Update ---
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
|
||||||
|
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-update",
|
||||||
|
Name: "Original",
|
||||||
|
Status: model.StatusActive,
|
||||||
|
})
|
||||||
|
|
||||||
|
key.Name = "Renamed"
|
||||||
|
key.Status = model.StatusDisabled
|
||||||
|
err := s.repo.Update(s.ctx, key)
|
||||||
|
s.Require().NoError(err, "Update")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||||
|
s.Require().NoError(err, "GetByID after update")
|
||||||
|
s.Require().Equal("sk-update", got.Key, "Update should not change key")
|
||||||
|
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
|
||||||
|
s.Require().Equal("Renamed", got.Name)
|
||||||
|
s.Require().Equal(model.StatusDisabled, got.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargroup@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear"})
|
||||||
|
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-clear-group",
|
||||||
|
Name: "Group Key",
|
||||||
|
GroupID: &group.ID,
|
||||||
|
})
|
||||||
|
|
||||||
|
key.GroupID = nil
|
||||||
|
err := s.repo.Update(s.ctx, key)
|
||||||
|
s.Require().NoError(err, "Update")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Nil(got.GroupID, "expected GroupID to be cleared")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Delete ---
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestDelete() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
|
||||||
|
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-delete",
|
||||||
|
Name: "Delete Me",
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.repo.Delete(s.ctx, key.ID)
|
||||||
|
s.Require().NoError(err, "Delete")
|
||||||
|
|
||||||
|
_, err = s.repo.GetByID(s.ctx, key.ID)
|
||||||
|
s.Require().Error(err, "expected error after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListByUserID / CountByUserID ---
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
|
||||||
|
|
||||||
|
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
|
s.Require().NoError(err, "ListByUserID")
|
||||||
|
s.Require().Len(keys, 2)
|
||||||
|
s.Require().Equal(int64(2), page.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "paging@test.com"})
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-page-" + string(rune('a'+i)),
|
||||||
|
Name: "Key",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(keys, 2)
|
||||||
|
s.Require().Equal(int64(5), page.Total)
|
||||||
|
s.Require().Equal(3, page.Pages)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "count@test.com"})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
|
||||||
|
|
||||||
|
count, err := s.repo.CountByUserID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err, "CountByUserID")
|
||||||
|
s.Require().Equal(int64(2), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListByGroupID / CountByGroupID ---
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbygroup@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
|
||||||
|
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
|
||||||
|
|
||||||
|
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
|
s.Require().NoError(err, "ListByGroupID")
|
||||||
|
s.Require().Len(keys, 2)
|
||||||
|
s.Require().Equal(int64(2), page.Total)
|
||||||
|
// User preloaded
|
||||||
|
s.Require().NotNil(keys[0].User)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "countgroup@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
|
||||||
|
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
|
||||||
|
|
||||||
|
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "CountByGroupID")
|
||||||
|
s.Require().Equal(int64(1), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ExistsByKey ---
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-exists", Name: "K"})
|
||||||
|
|
||||||
|
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
|
||||||
|
s.Require().NoError(err, "ExistsByKey")
|
||||||
|
s.Require().True(exists)
|
||||||
|
|
||||||
|
notExists, err := s.repo.ExistsByKey(s.ctx, "sk-not-exists")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().False(notExists)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- SearchApiKeys ---
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "search@test.com"})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
|
||||||
|
|
||||||
|
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
|
||||||
|
s.Require().NoError(err, "SearchApiKeys")
|
||||||
|
s.Require().Len(found, 1)
|
||||||
|
s.Require().Contains(found[0].Name, "Production")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnokw@test.com"})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
|
||||||
|
|
||||||
|
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(found, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnouid@test.com"})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
|
||||||
|
|
||||||
|
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(found, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ClearGroupIDByGroupID ---
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargrp@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear-bulk"})
|
||||||
|
|
||||||
|
k1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
|
||||||
|
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
|
||||||
|
|
||||||
|
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||||
|
s.Require().Equal(int64(2), affected)
|
||||||
|
|
||||||
|
got1, _ := s.repo.GetByID(s.ctx, k1.ID)
|
||||||
|
got2, _ := s.repo.GetByID(s.ctx, k2.ID)
|
||||||
|
s.Require().Nil(got1.GroupID)
|
||||||
|
s.Require().Nil(got2.GroupID)
|
||||||
|
|
||||||
|
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||||
|
s.Require().Zero(count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
|
||||||
|
|
||||||
|
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "k@example.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-k"})
|
||||||
|
|
||||||
|
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-test-1",
|
||||||
|
Name: "My Key",
|
||||||
|
GroupID: &group.ID,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||||
|
s.Require().NoError(err, "GetByKey")
|
||||||
|
s.Require().Equal(key.ID, got.ID)
|
||||||
|
s.Require().NotNil(got.User)
|
||||||
|
s.Require().Equal(user.ID, got.User.ID)
|
||||||
|
s.Require().NotNil(got.Group)
|
||||||
|
s.Require().Equal(group.ID, got.Group.ID)
|
||||||
|
|
||||||
|
key.Name = "Renamed"
|
||||||
|
key.Status = model.StatusDisabled
|
||||||
|
key.GroupID = nil
|
||||||
|
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
|
||||||
|
|
||||||
|
got2, err := s.repo.GetByID(s.ctx, key.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
|
||||||
|
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
|
||||||
|
s.Require().Equal("Renamed", got2.Name)
|
||||||
|
s.Require().Equal(model.StatusDisabled, got2.Status)
|
||||||
|
s.Require().Nil(got2.GroupID)
|
||||||
|
|
||||||
|
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
|
s.Require().NoError(err, "ListByUserID")
|
||||||
|
s.Require().Equal(int64(1), page.Total)
|
||||||
|
s.Require().Len(keys, 1)
|
||||||
|
|
||||||
|
exists, err := s.repo.ExistsByKey(s.ctx, "sk-test-1")
|
||||||
|
s.Require().NoError(err, "ExistsByKey")
|
||||||
|
s.Require().True(exists, "expected key to exist")
|
||||||
|
|
||||||
|
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
|
||||||
|
s.Require().NoError(err, "SearchApiKeys")
|
||||||
|
s.Require().Len(found, 1)
|
||||||
|
s.Require().Equal(key.ID, found[0].ID)
|
||||||
|
|
||||||
|
// ClearGroupIDByGroupID
|
||||||
|
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
Key: "sk-test-2",
|
||||||
|
Name: "Group Key",
|
||||||
|
GroupID: &group.ID,
|
||||||
|
})
|
||||||
|
|
||||||
|
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "CountByGroupID")
|
||||||
|
s.Require().Equal(int64(1), countBefore, "expected 1 key in group before clear")
|
||||||
|
|
||||||
|
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||||
|
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||||
|
|
||||||
|
got3, err := s.repo.GetByID(s.ctx, k2.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Nil(got3.GroupID, "expected GroupID cleared")
|
||||||
|
|
||||||
|
countAfter, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "CountByGroupID after clear")
|
||||||
|
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
|
||||||
|
}
|
||||||
283
backend/internal/repository/billing_cache_integration_test.go
Normal file
283
backend/internal/repository/billing_cache_integration_test.go
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BillingCacheSuite struct {
|
||||||
|
IntegrationRedisSuite
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BillingCacheSuite) TestUserBalance() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing_key_returns_redis_nil",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||||
|
_, err := cache.GetUserBalance(ctx, 1)
|
||||||
|
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deduct_on_nonexistent_is_noop",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||||
|
userID := int64(1)
|
||||||
|
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
|
||||||
|
|
||||||
|
_, err := rdb.Get(ctx, balanceKey).Result()
|
||||||
|
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "set_and_get_with_ttl",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||||
|
userID := int64(2)
|
||||||
|
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||||
|
|
||||||
|
got, err := cache.GetUserBalance(ctx, userID)
|
||||||
|
require.NoError(s.T(), err, "GetUserBalance")
|
||||||
|
require.Equal(s.T(), 10.5, got, "balance mismatch")
|
||||||
|
|
||||||
|
ttl, err := rdb.TTL(ctx, balanceKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL")
|
||||||
|
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deduct_reduces_balance",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||||
|
userID := int64(3)
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||||
|
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
|
||||||
|
|
||||||
|
got, err := cache.GetUserBalance(ctx, userID)
|
||||||
|
require.NoError(s.T(), err, "GetUserBalance after deduct")
|
||||||
|
require.Equal(s.T(), 8.25, got, "deduct mismatch")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalidate_removes_key",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||||
|
userID := int64(100)
|
||||||
|
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
|
||||||
|
|
||||||
|
exists, err := rdb.Exists(ctx, balanceKey).Result()
|
||||||
|
require.NoError(s.T(), err, "Exists")
|
||||||
|
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
|
||||||
|
|
||||||
|
exists, err = rdb.Exists(ctx, balanceKey).Result()
|
||||||
|
require.NoError(s.T(), err, "Exists after invalidate")
|
||||||
|
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
|
||||||
|
|
||||||
|
_, err = cache.GetUserBalance(ctx, userID)
|
||||||
|
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "deduct_refreshes_ttl",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||||
|
userID := int64(103)
|
||||||
|
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
|
||||||
|
|
||||||
|
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL before deduct")
|
||||||
|
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
|
||||||
|
|
||||||
|
balance, err := cache.GetUserBalance(ctx, userID)
|
||||||
|
require.NoError(s.T(), err, "GetUserBalance")
|
||||||
|
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
|
||||||
|
|
||||||
|
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL after deduct")
|
||||||
|
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
rdb := testRedis(s.T())
|
||||||
|
cache := NewBillingCache(rdb)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
tt.fn(ctx, rdb, cache)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing_key_returns_redis_nil",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||||
|
userID := int64(10)
|
||||||
|
groupID := int64(20)
|
||||||
|
|
||||||
|
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||||
|
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update_usage_on_nonexistent_is_noop",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||||
|
userID := int64(11)
|
||||||
|
groupID := int64(21)
|
||||||
|
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
|
||||||
|
|
||||||
|
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||||
|
require.NoError(s.T(), err, "Exists")
|
||||||
|
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "set_and_get_with_ttl",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||||
|
userID := int64(12)
|
||||||
|
groupID := int64(22)
|
||||||
|
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||||
|
|
||||||
|
data := &ports.SubscriptionCacheData{
|
||||||
|
Status: "active",
|
||||||
|
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||||
|
DailyUsage: 1.0,
|
||||||
|
WeeklyUsage: 2.0,
|
||||||
|
MonthlyUsage: 3.0,
|
||||||
|
Version: 7,
|
||||||
|
}
|
||||||
|
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||||
|
|
||||||
|
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||||
|
require.NoError(s.T(), err, "GetSubscriptionCache")
|
||||||
|
require.Equal(s.T(), "active", gotSub.Status)
|
||||||
|
require.Equal(s.T(), int64(7), gotSub.Version)
|
||||||
|
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
|
||||||
|
|
||||||
|
ttl, err := rdb.TTL(ctx, subKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL subKey")
|
||||||
|
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update_usage_increments_all_fields",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||||
|
userID := int64(13)
|
||||||
|
groupID := int64(23)
|
||||||
|
|
||||||
|
data := &ports.SubscriptionCacheData{
|
||||||
|
Status: "active",
|
||||||
|
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||||
|
DailyUsage: 1.0,
|
||||||
|
WeeklyUsage: 2.0,
|
||||||
|
MonthlyUsage: 3.0,
|
||||||
|
Version: 1,
|
||||||
|
}
|
||||||
|
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
|
||||||
|
|
||||||
|
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||||
|
require.NoError(s.T(), err, "GetSubscriptionCache after update")
|
||||||
|
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
|
||||||
|
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
|
||||||
|
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalidate_removes_key",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||||
|
userID := int64(101)
|
||||||
|
groupID := int64(10)
|
||||||
|
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||||
|
|
||||||
|
data := &ports.SubscriptionCacheData{
|
||||||
|
Status: "active",
|
||||||
|
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||||
|
DailyUsage: 1.0,
|
||||||
|
WeeklyUsage: 2.0,
|
||||||
|
MonthlyUsage: 3.0,
|
||||||
|
Version: 1,
|
||||||
|
}
|
||||||
|
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||||
|
|
||||||
|
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||||
|
require.NoError(s.T(), err, "Exists")
|
||||||
|
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
|
||||||
|
|
||||||
|
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
|
||||||
|
|
||||||
|
exists, err = rdb.Exists(ctx, subKey).Result()
|
||||||
|
require.NoError(s.T(), err, "Exists after invalidate")
|
||||||
|
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
|
||||||
|
|
||||||
|
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||||
|
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing_status_returns_parsing_error",
|
||||||
|
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
|
||||||
|
userID := int64(102)
|
||||||
|
groupID := int64(11)
|
||||||
|
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||||
|
|
||||||
|
fields := map[string]any{
|
||||||
|
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
|
||||||
|
"daily_usage": 1.0,
|
||||||
|
"weekly_usage": 2.0,
|
||||||
|
"monthly_usage": 3.0,
|
||||||
|
"version": 1,
|
||||||
|
}
|
||||||
|
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
|
||||||
|
|
||||||
|
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||||
|
require.Error(s.T(), err, "expected error for missing status field")
|
||||||
|
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
|
||||||
|
require.Equal(s.T(), "invalid cache: missing status", err.Error())
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
rdb := testRedis(s.T())
|
||||||
|
cache := NewBillingCache(rdb)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
tt.fn(ctx, rdb, cache)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBillingCacheSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(BillingCacheSuite))
|
||||||
|
}
|
||||||
@@ -16,20 +16,28 @@ import (
|
|||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type claudeOAuthService struct{}
|
|
||||||
|
|
||||||
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
|
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
|
||||||
return &claudeOAuthService{}
|
return &claudeOAuthService{
|
||||||
|
baseURL: "https://claude.ai",
|
||||||
|
tokenURL: oauth.TokenURL,
|
||||||
|
clientFactory: createReqClient,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type claudeOAuthService struct {
|
||||||
|
baseURL string
|
||||||
|
tokenURL string
|
||||||
|
clientFactory func(proxyURL string) *req.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||||
client := createReqClient(proxyURL)
|
client := s.clientFactory(proxyURL)
|
||||||
|
|
||||||
var orgs []struct {
|
var orgs []struct {
|
||||||
UUID string `json:"uuid"`
|
UUID string `json:"uuid"`
|
||||||
}
|
}
|
||||||
|
|
||||||
targetURL := "https://claude.ai/api/organizations"
|
targetURL := s.baseURL + "/api/organizations"
|
||||||
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||||
|
|
||||||
resp, err := client.R().
|
resp, err := client.R().
|
||||||
@@ -61,9 +69,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||||
client := createReqClient(proxyURL)
|
client := s.clientFactory(proxyURL)
|
||||||
|
|
||||||
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
|
||||||
|
|
||||||
reqBody := map[string]any{
|
reqBody := map[string]any{
|
||||||
"response_type": "code",
|
"response_type": "code",
|
||||||
@@ -133,12 +141,12 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
|||||||
fullCode = authCode + "#" + responseState
|
fullCode = authCode + "#" + responseState
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
|
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
|
||||||
return fullCode, nil
|
return fullCode, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
|
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
|
||||||
client := createReqClient(proxyURL)
|
client := s.clientFactory(proxyURL)
|
||||||
|
|
||||||
// Parse code which may contain state in format "authCode#state"
|
// Parse code which may contain state in format "authCode#state"
|
||||||
authCode := code
|
authCode := code
|
||||||
@@ -161,7 +169,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
|||||||
}
|
}
|
||||||
|
|
||||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
|
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
|
||||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||||
|
|
||||||
var tokenResp oauth.TokenResponse
|
var tokenResp oauth.TokenResponse
|
||||||
@@ -171,7 +179,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
|||||||
SetHeader("Content-Type", "application/json").
|
SetHeader("Content-Type", "application/json").
|
||||||
SetBody(reqBody).
|
SetBody(reqBody).
|
||||||
SetSuccessResult(&tokenResp).
|
SetSuccessResult(&tokenResp).
|
||||||
Post(oauth.TokenURL)
|
Post(s.tokenURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||||
@@ -189,7 +197,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||||
client := createReqClient(proxyURL)
|
client := s.clientFactory(proxyURL)
|
||||||
|
|
||||||
formData := url.Values{}
|
formData := url.Values{}
|
||||||
formData.Set("grant_type", "refresh_token")
|
formData.Set("grant_type", "refresh_token")
|
||||||
@@ -202,7 +210,7 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
|||||||
SetContext(ctx).
|
SetContext(ctx).
|
||||||
SetFormDataFromValues(formData).
|
SetFormDataFromValues(formData).
|
||||||
SetSuccessResult(&tokenResp).
|
SetSuccessResult(&tokenResp).
|
||||||
Post(oauth.TokenURL)
|
Post(s.tokenURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("request failed: %w", err)
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
@@ -226,3 +234,13 @@ func createReqClient(proxyURL string) *req.Client {
|
|||||||
|
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func prefix(s string, n int) string {
|
||||||
|
if n <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(s) <= n {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:n]
|
||||||
|
}
|
||||||
|
|||||||
343
backend/internal/repository/claude_oauth_service_test.go
Normal file
343
backend/internal/repository/claude_oauth_service_test.go
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClaudeOAuthServiceSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
srv *httptest.Server
|
||||||
|
client *claudeOAuthService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
|
||||||
|
if s.srv != nil {
|
||||||
|
s.srv.Close()
|
||||||
|
s.srv = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestCapture holds captured request data for assertions in the main goroutine.
|
||||||
|
type requestCapture struct {
|
||||||
|
path string
|
||||||
|
method string
|
||||||
|
cookies []*http.Cookie
|
||||||
|
body []byte
|
||||||
|
formValues url.Values
|
||||||
|
bodyJSON map[string]any
|
||||||
|
contentType string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
handler http.HandlerFunc
|
||||||
|
wantErr bool
|
||||||
|
errContain string
|
||||||
|
wantUUID string
|
||||||
|
validate func(captured requestCapture)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`[{"uuid":"org-1"}]`))
|
||||||
|
},
|
||||||
|
wantUUID: "org-1",
|
||||||
|
validate: func(captured requestCapture) {
|
||||||
|
require.Equal(s.T(), "/api/organizations", captured.path, "unexpected path")
|
||||||
|
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
|
||||||
|
require.Equal(s.T(), "sessionKey", captured.cookies[0].Name)
|
||||||
|
require.Equal(s.T(), "sess", captured.cookies[0].Value)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non_200_returns_error",
|
||||||
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_, _ = w.Write([]byte("unauthorized"))
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errContain: "401",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid_json_returns_error",
|
||||||
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte("not-json"))
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
var captured requestCapture
|
||||||
|
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
captured.path = r.URL.Path
|
||||||
|
captured.cookies = r.Cookies()
|
||||||
|
tt.handler(w, r)
|
||||||
|
}))
|
||||||
|
defer s.srv.Close()
|
||||||
|
|
||||||
|
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
s.client.baseURL = s.srv.URL
|
||||||
|
|
||||||
|
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
if tt.errContain != "" {
|
||||||
|
require.ErrorContains(s.T(), err, tt.errContain)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), tt.wantUUID, got)
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(captured)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
handler http.HandlerFunc
|
||||||
|
wantErr bool
|
||||||
|
wantCode string
|
||||||
|
validate func(captured requestCapture)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "parses_redirect_uri",
|
||||||
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"redirect_uri": oauth.RedirectURI + "?code=AUTH&state=STATE",
|
||||||
|
})
|
||||||
|
},
|
||||||
|
wantCode: "AUTH#STATE",
|
||||||
|
validate: func(captured requestCapture) {
|
||||||
|
require.True(s.T(), strings.HasPrefix(captured.path, "/v1/oauth/") && strings.HasSuffix(captured.path, "/authorize"), "unexpected path: %s", captured.path)
|
||||||
|
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||||
|
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
|
||||||
|
require.Equal(s.T(), "sess", captured.cookies[0].Value)
|
||||||
|
require.Equal(s.T(), "org-1", captured.bodyJSON["organization_uuid"])
|
||||||
|
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||||
|
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
|
||||||
|
require.Equal(s.T(), "st", captured.bodyJSON["state"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing_code_returns_error",
|
||||||
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"redirect_uri": oauth.RedirectURI + "?state=STATE", // no code
|
||||||
|
})
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
var captured requestCapture
|
||||||
|
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
captured.path = r.URL.Path
|
||||||
|
captured.method = r.Method
|
||||||
|
captured.cookies = r.Cookies()
|
||||||
|
captured.body, _ = io.ReadAll(r.Body)
|
||||||
|
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||||
|
tt.handler(w, r)
|
||||||
|
}))
|
||||||
|
defer s.srv.Close()
|
||||||
|
|
||||||
|
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
s.client.baseURL = s.srv.URL
|
||||||
|
|
||||||
|
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), tt.wantCode, code)
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(captured)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
handler http.HandlerFunc
|
||||||
|
code string
|
||||||
|
wantErr bool
|
||||||
|
wantResp *oauth.TokenResponse
|
||||||
|
validate func(captured requestCapture)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "sends_state_when_embedded",
|
||||||
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||||
|
AccessToken: "at",
|
||||||
|
TokenType: "bearer",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
RefreshToken: "rt",
|
||||||
|
Scope: "s",
|
||||||
|
})
|
||||||
|
},
|
||||||
|
code: "AUTH#STATE2",
|
||||||
|
wantResp: &oauth.TokenResponse{
|
||||||
|
AccessToken: "at",
|
||||||
|
RefreshToken: "rt",
|
||||||
|
},
|
||||||
|
validate: func(captured requestCapture) {
|
||||||
|
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||||
|
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), "unexpected content-type")
|
||||||
|
require.Equal(s.T(), "AUTH", captured.bodyJSON["code"])
|
||||||
|
require.Equal(s.T(), "STATE2", captured.bodyJSON["state"])
|
||||||
|
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||||
|
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
|
||||||
|
require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non_200_returns_error",
|
||||||
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write([]byte("bad request"))
|
||||||
|
},
|
||||||
|
code: "AUTH",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
var captured requestCapture
|
||||||
|
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
captured.method = r.Method
|
||||||
|
captured.contentType = r.Header.Get("Content-Type")
|
||||||
|
captured.body, _ = io.ReadAll(r.Body)
|
||||||
|
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||||
|
tt.handler(w, r)
|
||||||
|
}))
|
||||||
|
defer s.srv.Close()
|
||||||
|
|
||||||
|
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
s.client.tokenURL = s.srv.URL
|
||||||
|
|
||||||
|
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "")
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||||
|
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(captured)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
handler http.HandlerFunc
|
||||||
|
wantErr bool
|
||||||
|
wantResp *oauth.TokenResponse
|
||||||
|
validate func(captured requestCapture)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "sends_form",
|
||||||
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{AccessToken: "at2", TokenType: "bearer", ExpiresIn: 3600})
|
||||||
|
},
|
||||||
|
wantResp: &oauth.TokenResponse{AccessToken: "at2"},
|
||||||
|
validate: func(captured requestCapture) {
|
||||||
|
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||||
|
require.Equal(s.T(), "refresh_token", captured.formValues.Get("grant_type"))
|
||||||
|
require.Equal(s.T(), "rt", captured.formValues.Get("refresh_token"))
|
||||||
|
require.Equal(s.T(), oauth.ClientID, captured.formValues.Get("client_id"))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non_200_returns_error",
|
||||||
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_, _ = w.Write([]byte("unauthorized"))
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
var captured requestCapture
|
||||||
|
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
captured.method = r.Method
|
||||||
|
captured.body, _ = io.ReadAll(r.Body)
|
||||||
|
captured.formValues, _ = url.ParseQuery(string(captured.body))
|
||||||
|
tt.handler(w, r)
|
||||||
|
}))
|
||||||
|
defer s.srv.Close()
|
||||||
|
|
||||||
|
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
s.client.tokenURL = s.srv.URL
|
||||||
|
|
||||||
|
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||||
|
if tt.validate != nil {
|
||||||
|
tt.validate(captured)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeOAuthServiceSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(ClaudeOAuthServiceSuite))
|
||||||
|
}
|
||||||
@@ -12,10 +12,14 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
type claudeUsageService struct{}
|
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||||
|
|
||||||
|
type claudeUsageService struct {
|
||||||
|
usageURL string
|
||||||
|
}
|
||||||
|
|
||||||
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||||
return &claudeUsageService{}
|
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||||
@@ -35,7 +39,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
|
|||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create request failed: %w", err)
|
return nil, fmt.Errorf("create request failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
105
backend/internal/repository/claude_usage_service_test.go
Normal file
105
backend/internal/repository/claude_usage_service_test.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClaudeUsageServiceSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
srv *httptest.Server
|
||||||
|
fetcher *claudeUsageService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClaudeUsageServiceSuite) TearDownTest() {
|
||||||
|
if s.srv != nil {
|
||||||
|
s.srv.Close()
|
||||||
|
s.srv = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// usageRequestCapture holds captured request data for assertions in the main goroutine.
|
||||||
|
type usageRequestCapture struct {
|
||||||
|
authorization string
|
||||||
|
anthropicBeta string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
|
||||||
|
var captured usageRequestCapture
|
||||||
|
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
captured.authorization = r.Header.Get("Authorization")
|
||||||
|
captured.anthropicBeta = r.Header.Get("anthropic-beta")
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, `{
|
||||||
|
"five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"},
|
||||||
|
"seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"},
|
||||||
|
"seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"}
|
||||||
|
}`)
|
||||||
|
}))
|
||||||
|
|
||||||
|
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||||
|
|
||||||
|
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
|
||||||
|
require.NoError(s.T(), err, "FetchUsage")
|
||||||
|
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
|
||||||
|
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
|
||||||
|
require.Equal(s.T(), 56.0, resp.SevenDaySonnet.Utilization, "SevenDaySonnet utilization mismatch")
|
||||||
|
|
||||||
|
// Assertions on captured request data
|
||||||
|
require.Equal(s.T(), "Bearer at", captured.authorization, "Authorization header mismatch")
|
||||||
|
require.Equal(s.T(), "oauth-2025-04-20", captured.anthropicBeta, "anthropic-beta header mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_, _ = io.WriteString(w, "nope")
|
||||||
|
}))
|
||||||
|
|
||||||
|
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||||
|
|
||||||
|
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
require.ErrorContains(s.T(), err, "status 401")
|
||||||
|
require.ErrorContains(s.T(), err, "nope")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, "not-json")
|
||||||
|
}))
|
||||||
|
|
||||||
|
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||||
|
|
||||||
|
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
require.ErrorContains(s.T(), err, "decode response failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Never respond - simulate slow server
|
||||||
|
<-r.Context().Done()
|
||||||
|
}))
|
||||||
|
|
||||||
|
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // Cancel immediately
|
||||||
|
|
||||||
|
_, err := s.fetcher.FetchUsage(ctx, "at", "")
|
||||||
|
require.Error(s.T(), err, "expected error for cancelled context")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeUsageServiceSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(ClaudeUsageServiceSuite))
|
||||||
|
}
|
||||||
@@ -0,0 +1,231 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConcurrencyCacheSuite struct {
|
||||||
|
IntegrationRedisSuite
|
||||||
|
cache ports.ConcurrencyCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||||
|
s.IntegrationRedisSuite.SetupTest()
|
||||||
|
s.cache = NewConcurrencyCache(s.rdb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
||||||
|
accountID := int64(10)
|
||||||
|
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
|
||||||
|
|
||||||
|
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
|
||||||
|
require.NoError(s.T(), err, "AcquireAccountSlot 1")
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
|
||||||
|
require.NoError(s.T(), err, "AcquireAccountSlot 2")
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
|
||||||
|
require.NoError(s.T(), err, "AcquireAccountSlot 3")
|
||||||
|
require.False(s.T(), ok, "expected third acquire to fail")
|
||||||
|
|
||||||
|
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||||
|
require.NoError(s.T(), err, "GetAccountConcurrency")
|
||||||
|
require.Equal(s.T(), 2, cur, "concurrency mismatch")
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
|
||||||
|
|
||||||
|
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||||
|
require.NoError(s.T(), err, "GetAccountConcurrency after release")
|
||||||
|
require.Equal(s.T(), 1, cur, "expected 1 after release")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
|
||||||
|
accountID := int64(11)
|
||||||
|
reqID := "req_ttl_test"
|
||||||
|
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID)
|
||||||
|
|
||||||
|
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
|
||||||
|
require.NoError(s.T(), err, "AcquireAccountSlot")
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL")
|
||||||
|
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
|
||||||
|
accountID := int64(12)
|
||||||
|
reqID := "dup-req"
|
||||||
|
|
||||||
|
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
// Acquiring with same reqID should be idempotent
|
||||||
|
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
|
||||||
|
accountID := int64(13)
|
||||||
|
reqID := "release-test"
|
||||||
|
|
||||||
|
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
|
||||||
|
// Releasing again should not error
|
||||||
|
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
|
||||||
|
// Releasing non-existent should not error
|
||||||
|
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
|
||||||
|
|
||||||
|
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), 0, cur)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
|
||||||
|
accountID := int64(14)
|
||||||
|
reqID := "max-zero-test"
|
||||||
|
|
||||||
|
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.False(s.T(), ok, "expected acquire to fail with max=0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
|
||||||
|
userID := int64(42)
|
||||||
|
reqID1, reqID2 := "req1", "req2"
|
||||||
|
|
||||||
|
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
|
||||||
|
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
|
||||||
|
require.NoError(s.T(), err, "AcquireUserSlot 2")
|
||||||
|
require.False(s.T(), ok, "expected second acquire to fail at max=1")
|
||||||
|
|
||||||
|
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
|
||||||
|
require.NoError(s.T(), err, "GetUserConcurrency")
|
||||||
|
require.Equal(s.T(), 1, cur, "expected concurrency=1")
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
|
||||||
|
// Releasing a non-existent slot should not error
|
||||||
|
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
|
||||||
|
|
||||||
|
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
|
||||||
|
require.NoError(s.T(), err, "GetUserConcurrency after release")
|
||||||
|
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
||||||
|
userID := int64(200)
|
||||||
|
reqID := "req_ttl_test"
|
||||||
|
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID)
|
||||||
|
|
||||||
|
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
|
||||||
|
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL")
|
||||||
|
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||||
|
userID := int64(20)
|
||||||
|
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||||
|
|
||||||
|
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||||
|
require.NoError(s.T(), err, "IncrementWaitCount 1")
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||||
|
require.NoError(s.T(), err, "IncrementWaitCount 2")
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||||
|
require.NoError(s.T(), err, "IncrementWaitCount 3")
|
||||||
|
require.False(s.T(), ok, "expected wait increment over max to fail")
|
||||||
|
|
||||||
|
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL waitKey")
|
||||||
|
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||||
|
|
||||||
|
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||||
|
if !errors.Is(err, redis.Nil) {
|
||||||
|
require.NoError(s.T(), err, "Get waitKey")
|
||||||
|
}
|
||||||
|
require.Equal(s.T(), 1, val, "expected wait count 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
||||||
|
userID := int64(300)
|
||||||
|
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||||
|
|
||||||
|
// Test decrement on non-existent key - should not error and should not create negative value
|
||||||
|
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
|
||||||
|
|
||||||
|
// Verify no key was created or it's not negative
|
||||||
|
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||||
|
if !errors.Is(err, redis.Nil) {
|
||||||
|
require.NoError(s.T(), err, "Get waitKey")
|
||||||
|
}
|
||||||
|
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
|
||||||
|
|
||||||
|
// Set count to 1, then decrement twice
|
||||||
|
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
|
||||||
|
require.NoError(s.T(), err, "IncrementWaitCount")
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
// Decrement once (1 -> 0)
|
||||||
|
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||||
|
|
||||||
|
// Decrement again on 0 - should not go negative
|
||||||
|
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
|
||||||
|
|
||||||
|
// Verify count is 0, not negative
|
||||||
|
val, err = s.rdb.Get(s.ctx, waitKey).Int()
|
||||||
|
if !errors.Is(err, redis.Nil) {
|
||||||
|
require.NoError(s.T(), err, "Get waitKey after double decrement")
|
||||||
|
}
|
||||||
|
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||||
|
// When no slots exist, GetAccountConcurrency should return 0
|
||||||
|
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), 0, cur)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
|
||||||
|
// When no slots exist, GetUserConcurrency should return 0
|
||||||
|
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), 0, cur)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||||
|
}
|
||||||
92
backend/internal/repository/email_cache_integration_test.go
Normal file
92
backend/internal/repository/email_cache_integration_test.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EmailCacheSuite struct {
|
||||||
|
IntegrationRedisSuite
|
||||||
|
cache ports.EmailCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *EmailCacheSuite) SetupTest() {
|
||||||
|
s.IntegrationRedisSuite.SetupTest()
|
||||||
|
s.cache = NewEmailCache(s.rdb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
|
||||||
|
_, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com")
|
||||||
|
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
|
||||||
|
email := "a@example.com"
|
||||||
|
emailTTL := 2 * time.Minute
|
||||||
|
data := &ports.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||||
|
|
||||||
|
got, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||||
|
require.NoError(s.T(), err, "GetVerificationCode")
|
||||||
|
require.Equal(s.T(), "123456", got.Code)
|
||||||
|
require.Equal(s.T(), 1, got.Attempts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
|
||||||
|
email := "ttl@example.com"
|
||||||
|
emailTTL := 2 * time.Minute
|
||||||
|
data := &ports.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||||
|
|
||||||
|
emailKey := verifyCodeKeyPrefix + email
|
||||||
|
ttl, err := s.rdb.TTL(s.ctx, emailKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL emailKey")
|
||||||
|
s.AssertTTLWithin(ttl, 1*time.Second, emailTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
|
||||||
|
email := "delete@example.com"
|
||||||
|
data := &ports.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")
|
||||||
|
|
||||||
|
// Verify it exists
|
||||||
|
_, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||||
|
require.NoError(s.T(), err, "GetVerificationCode before delete")
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode")
|
||||||
|
|
||||||
|
// Verify it's gone
|
||||||
|
_, err = s.cache.GetVerificationCode(s.ctx, email)
|
||||||
|
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() {
|
||||||
|
// Deleting a non-existent key should not error
|
||||||
|
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() {
|
||||||
|
emailKey := verifyCodeKeyPrefix + "corrupted@example.com"
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON")
|
||||||
|
|
||||||
|
_, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com")
|
||||||
|
require.Error(s.T(), err, "expected error for corrupted JSON")
|
||||||
|
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmailCacheSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(EmailCacheSuite))
|
||||||
|
}
|
||||||
172
backend/internal/repository/fixtures_integration_test.go
Normal file
172
backend/internal/repository/fixtures_integration_test.go
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
|
||||||
|
t.Helper()
|
||||||
|
if u.PasswordHash == "" {
|
||||||
|
u.PasswordHash = "test-password-hash"
|
||||||
|
}
|
||||||
|
if u.Role == "" {
|
||||||
|
u.Role = model.RoleUser
|
||||||
|
}
|
||||||
|
if u.Status == "" {
|
||||||
|
u.Status = model.StatusActive
|
||||||
|
}
|
||||||
|
if u.CreatedAt.IsZero() {
|
||||||
|
u.CreatedAt = time.Now()
|
||||||
|
}
|
||||||
|
if u.UpdatedAt.IsZero() {
|
||||||
|
u.UpdatedAt = u.CreatedAt
|
||||||
|
}
|
||||||
|
require.NoError(t, db.Create(u).Error, "create user")
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
|
||||||
|
t.Helper()
|
||||||
|
if g.Platform == "" {
|
||||||
|
g.Platform = model.PlatformAnthropic
|
||||||
|
}
|
||||||
|
if g.Status == "" {
|
||||||
|
g.Status = model.StatusActive
|
||||||
|
}
|
||||||
|
if g.SubscriptionType == "" {
|
||||||
|
g.SubscriptionType = model.SubscriptionTypeStandard
|
||||||
|
}
|
||||||
|
if g.CreatedAt.IsZero() {
|
||||||
|
g.CreatedAt = time.Now()
|
||||||
|
}
|
||||||
|
if g.UpdatedAt.IsZero() {
|
||||||
|
g.UpdatedAt = g.CreatedAt
|
||||||
|
}
|
||||||
|
require.NoError(t, db.Create(g).Error, "create group")
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
|
||||||
|
t.Helper()
|
||||||
|
if p.Protocol == "" {
|
||||||
|
p.Protocol = "http"
|
||||||
|
}
|
||||||
|
if p.Host == "" {
|
||||||
|
p.Host = "127.0.0.1"
|
||||||
|
}
|
||||||
|
if p.Port == 0 {
|
||||||
|
p.Port = 8080
|
||||||
|
}
|
||||||
|
if p.Status == "" {
|
||||||
|
p.Status = model.StatusActive
|
||||||
|
}
|
||||||
|
if p.CreatedAt.IsZero() {
|
||||||
|
p.CreatedAt = time.Now()
|
||||||
|
}
|
||||||
|
if p.UpdatedAt.IsZero() {
|
||||||
|
p.UpdatedAt = p.CreatedAt
|
||||||
|
}
|
||||||
|
require.NoError(t, db.Create(p).Error, "create proxy")
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Account {
|
||||||
|
t.Helper()
|
||||||
|
if a.Platform == "" {
|
||||||
|
a.Platform = model.PlatformAnthropic
|
||||||
|
}
|
||||||
|
if a.Type == "" {
|
||||||
|
a.Type = model.AccountTypeOAuth
|
||||||
|
}
|
||||||
|
if a.Status == "" {
|
||||||
|
a.Status = model.StatusActive
|
||||||
|
}
|
||||||
|
if !a.Schedulable {
|
||||||
|
a.Schedulable = true
|
||||||
|
}
|
||||||
|
if a.Credentials == nil {
|
||||||
|
a.Credentials = model.JSONB{}
|
||||||
|
}
|
||||||
|
if a.Extra == nil {
|
||||||
|
a.Extra = model.JSONB{}
|
||||||
|
}
|
||||||
|
if a.CreatedAt.IsZero() {
|
||||||
|
a.CreatedAt = time.Now()
|
||||||
|
}
|
||||||
|
if a.UpdatedAt.IsZero() {
|
||||||
|
a.UpdatedAt = a.CreatedAt
|
||||||
|
}
|
||||||
|
require.NoError(t, db.Create(a).Error, "create account")
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey {
|
||||||
|
t.Helper()
|
||||||
|
if k.Status == "" {
|
||||||
|
k.Status = model.StatusActive
|
||||||
|
}
|
||||||
|
if k.CreatedAt.IsZero() {
|
||||||
|
k.CreatedAt = time.Now()
|
||||||
|
}
|
||||||
|
if k.UpdatedAt.IsZero() {
|
||||||
|
k.UpdatedAt = k.CreatedAt
|
||||||
|
}
|
||||||
|
require.NoError(t, db.Create(k).Error, "create api key")
|
||||||
|
return k
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model.RedeemCode {
|
||||||
|
t.Helper()
|
||||||
|
if c.Status == "" {
|
||||||
|
c.Status = model.StatusUnused
|
||||||
|
}
|
||||||
|
if c.Type == "" {
|
||||||
|
c.Type = model.RedeemTypeBalance
|
||||||
|
}
|
||||||
|
if c.CreatedAt.IsZero() {
|
||||||
|
c.CreatedAt = time.Now()
|
||||||
|
}
|
||||||
|
require.NoError(t, db.Create(c).Error, "create redeem code")
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription) *model.UserSubscription {
|
||||||
|
t.Helper()
|
||||||
|
if s.Status == "" {
|
||||||
|
s.Status = model.SubscriptionStatusActive
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if s.StartsAt.IsZero() {
|
||||||
|
s.StartsAt = now.Add(-1 * time.Hour)
|
||||||
|
}
|
||||||
|
if s.ExpiresAt.IsZero() {
|
||||||
|
s.ExpiresAt = now.Add(24 * time.Hour)
|
||||||
|
}
|
||||||
|
if s.AssignedAt.IsZero() {
|
||||||
|
s.AssignedAt = now
|
||||||
|
}
|
||||||
|
if s.CreatedAt.IsZero() {
|
||||||
|
s.CreatedAt = now
|
||||||
|
}
|
||||||
|
if s.UpdatedAt.IsZero() {
|
||||||
|
s.UpdatedAt = now
|
||||||
|
}
|
||||||
|
require.NoError(t, db.Create(s).Error, "create user subscription")
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
|
||||||
|
t.Helper()
|
||||||
|
require.NoError(t, db.Create(&model.AccountGroup{
|
||||||
|
AccountID: accountID,
|
||||||
|
GroupID: groupID,
|
||||||
|
Priority: priority,
|
||||||
|
}).Error, "create account_group")
|
||||||
|
}
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GatewayCacheSuite struct {
|
||||||
|
IntegrationRedisSuite
|
||||||
|
cache ports.GatewayCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayCacheSuite) SetupTest() {
|
||||||
|
s.IntegrationRedisSuite.SetupTest()
|
||||||
|
s.cache = NewGatewayCache(s.rdb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
|
||||||
|
_, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent")
|
||||||
|
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
|
||||||
|
sessionID := "s1"
|
||||||
|
accountID := int64(99)
|
||||||
|
sessionTTL := 1 * time.Minute
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||||
|
|
||||||
|
sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
|
||||||
|
require.NoError(s.T(), err, "GetSessionAccountID")
|
||||||
|
require.Equal(s.T(), accountID, sid, "session id mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
|
||||||
|
sessionID := "s2"
|
||||||
|
accountID := int64(100)
|
||||||
|
sessionTTL := 1 * time.Minute
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||||
|
|
||||||
|
sessionKey := stickySessionPrefix + sessionID
|
||||||
|
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL sessionKey after Set")
|
||||||
|
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
|
||||||
|
sessionID := "s3"
|
||||||
|
accountID := int64(101)
|
||||||
|
initialTTL := 1 * time.Minute
|
||||||
|
refreshTTL := 3 * time.Minute
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID")
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL")
|
||||||
|
|
||||||
|
sessionKey := stickySessionPrefix + sessionID
|
||||||
|
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL after Refresh")
|
||||||
|
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
|
||||||
|
// RefreshSessionTTL on a missing key should not error (no-op)
|
||||||
|
err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute)
|
||||||
|
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||||
|
sessionID := "corrupted"
|
||||||
|
sessionKey := stickySessionPrefix + sessionID
|
||||||
|
|
||||||
|
// Set a non-integer value
|
||||||
|
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
|
||||||
|
|
||||||
|
_, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
|
||||||
|
require.Error(s.T(), err, "expected error for corrupted value")
|
||||||
|
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayCacheSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(GatewayCacheSuite))
|
||||||
|
}
|
||||||
328
backend/internal/repository/github_release_service_test.go
Normal file
328
backend/internal/repository/github_release_service_test.go
Normal file
@@ -0,0 +1,328 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GitHubReleaseServiceSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
srv *httptest.Server
|
||||||
|
client *githubReleaseClient
|
||||||
|
tempDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
// testTransport redirects requests to the test server
|
||||||
|
type testTransport struct {
|
||||||
|
testServerURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
// Rewrite the URL to point to our test server
|
||||||
|
testURL := t.testServerURL + req.URL.Path
|
||||||
|
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, testURL, req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newReq.Header = req.Header
|
||||||
|
return http.DefaultTransport.RoundTrip(newReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) SetupTest() {
|
||||||
|
s.tempDir = s.T().TempDir()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TearDownTest() {
|
||||||
|
if s.srv != nil {
|
||||||
|
s.srv.Close()
|
||||||
|
s.srv = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Length", "100")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
|
||||||
|
}))
|
||||||
|
|
||||||
|
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
|
||||||
|
dest := filepath.Join(s.tempDir, "file1.bin")
|
||||||
|
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||||
|
require.Error(s.T(), err, "expected error for oversized download with Content-Length")
|
||||||
|
|
||||||
|
_, statErr := os.Stat(dest)
|
||||||
|
require.Error(s.T(), statErr, "expected file to not exist for rejected download")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if fl, ok := w.(http.Flusher); ok {
|
||||||
|
fl.Flush()
|
||||||
|
}
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
|
||||||
|
if fl, ok := w.(http.Flusher); ok {
|
||||||
|
fl.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
|
||||||
|
dest := filepath.Join(s.tempDir, "file2.bin")
|
||||||
|
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||||
|
require.Error(s.T(), err, "expected error for oversized chunked download")
|
||||||
|
|
||||||
|
_, statErr := os.Stat(dest)
|
||||||
|
require.Error(s.T(), statErr, "expected file to be cleaned up for oversized chunked download")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if fl, ok := w.(http.Flusher); ok {
|
||||||
|
fl.Flush()
|
||||||
|
}
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
|
||||||
|
if fl, ok := w.(http.Flusher); ok {
|
||||||
|
fl.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
|
||||||
|
dest := filepath.Join(s.tempDir, "file3.bin")
|
||||||
|
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
|
||||||
|
require.NoError(s.T(), err, "expected success")
|
||||||
|
|
||||||
|
b, err := os.ReadFile(dest)
|
||||||
|
require.NoError(s.T(), err, "read")
|
||||||
|
require.True(s.T(), strings.HasPrefix(string(b), "b"), "downloaded content should start with 'b'")
|
||||||
|
require.Len(s.T(), b, 100, "downloaded content length mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
|
||||||
|
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
|
||||||
|
dest := filepath.Join(s.tempDir, "notfound.bin")
|
||||||
|
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
|
||||||
|
require.Error(s.T(), err, "expected error for 404")
|
||||||
|
|
||||||
|
_, statErr := os.Stat(dest)
|
||||||
|
require.Error(s.T(), statErr, "expected file to not exist for 404")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("sum"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
|
||||||
|
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||||
|
require.NoError(s.T(), err, "FetchChecksumFile")
|
||||||
|
require.Equal(s.T(), "sum", string(body), "checksum body mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}))
|
||||||
|
|
||||||
|
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
|
||||||
|
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||||
|
require.Error(s.T(), err, "expected error for non-200")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
<-r.Context().Done()
|
||||||
|
}))
|
||||||
|
|
||||||
|
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
dest := filepath.Join(s.tempDir, "cancelled.bin")
|
||||||
|
err := s.client.DownloadFile(ctx, s.srv.URL, dest, 100)
|
||||||
|
require.Error(s.T(), err, "expected error for cancelled context")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
|
||||||
|
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
|
||||||
|
dest := filepath.Join(s.tempDir, "invalid.bin")
|
||||||
|
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
|
||||||
|
require.Error(s.T(), err, "expected error for invalid URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("content"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
|
||||||
|
// Use a path that cannot be created (directory doesn't exist)
|
||||||
|
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
|
||||||
|
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
|
||||||
|
require.Error(s.T(), err, "expected error for invalid destination path")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
|
||||||
|
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
|
||||||
|
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
|
||||||
|
require.Error(s.T(), err, "expected error for invalid URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
|
||||||
|
releaseJSON := `{
|
||||||
|
"tag_name": "v1.0.0",
|
||||||
|
"name": "Release 1.0.0",
|
||||||
|
"body": "Release notes",
|
||||||
|
"html_url": "https://github.com/test/repo/releases/v1.0.0",
|
||||||
|
"assets": [
|
||||||
|
{
|
||||||
|
"name": "app-linux-amd64.tar.gz",
|
||||||
|
"browser_download_url": "https://github.com/test/repo/releases/download/v1.0.0/app-linux-amd64.tar.gz"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
|
||||||
|
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
|
||||||
|
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(releaseJSON))
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Use custom transport to redirect requests to test server
|
||||||
|
s.client = &githubReleaseClient{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), "v1.0.0", release.TagName)
|
||||||
|
require.Equal(s.T(), "Release 1.0.0", release.Name)
|
||||||
|
require.Len(s.T(), release.Assets, 1)
|
||||||
|
require.Equal(s.T(), "app-linux-amd64.tar.gz", release.Assets[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
|
||||||
|
s.client = &githubReleaseClient{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
require.Contains(s.T(), err.Error(), "404")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("not valid json"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
s.client = &githubReleaseClient{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
<-r.Context().Done()
|
||||||
|
}))
|
||||||
|
|
||||||
|
s.client = &githubReleaseClient{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
_, err := s.client.FetchLatestRelease(ctx, "test/repo")
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
|
||||||
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
<-r.Context().Done()
|
||||||
|
}))
|
||||||
|
|
||||||
|
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
_, err := s.client.FetchChecksumFile(ctx, s.srv.URL)
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitHubReleaseServiceSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(GitHubReleaseServiceSuite))
|
||||||
|
}
|
||||||
244
backend/internal/repository/group_repo_integration_test.go
Normal file
244
backend/internal/repository/group_repo_integration_test.go
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GroupRepoSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
db *gorm.DB
|
||||||
|
repo *GroupRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.db = testTx(s.T())
|
||||||
|
s.repo = NewGroupRepository(s.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupRepoSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(GroupRepoSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Create / GetByID / Update / Delete ---
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestCreate() {
|
||||||
|
group := &model.Group{
|
||||||
|
Name: "test-create",
|
||||||
|
Platform: model.PlatformAnthropic,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.repo.Create(s.ctx, group)
|
||||||
|
s.Require().NoError(err, "Create")
|
||||||
|
s.Require().NotZero(group.ID, "expected ID to be set")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal("test-create", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestGetByID_NotFound() {
|
||||||
|
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||||
|
s.Require().Error(err, "expected error for non-existent ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestUpdate() {
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "original"})
|
||||||
|
|
||||||
|
group.Name = "updated"
|
||||||
|
err := s.repo.Update(s.ctx, group)
|
||||||
|
s.Require().NoError(err, "Update")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "GetByID after update")
|
||||||
|
s.Require().Equal("updated", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestDelete() {
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "to-delete"})
|
||||||
|
|
||||||
|
err := s.repo.Delete(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "Delete")
|
||||||
|
|
||||||
|
_, err = s.repo.GetByID(s.ctx, group.ID)
|
||||||
|
s.Require().Error(err, "expected error after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- List / ListWithFilters ---
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestList() {
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
|
||||||
|
|
||||||
|
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
|
s.Require().NoError(err, "List")
|
||||||
|
s.Require().Len(groups, 2)
|
||||||
|
s.Require().Equal(int64(2), page.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic})
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI})
|
||||||
|
|
||||||
|
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformOpenAI, "", nil)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(groups, 1)
|
||||||
|
s.Require().Equal(model.PlatformOpenAI, groups[0].Platform)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestListWithFilters_Status() {
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Status: model.StatusActive})
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Status: model.StatusDisabled})
|
||||||
|
|
||||||
|
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, nil)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(groups, 1)
|
||||||
|
s.Require().Equal(model.StatusDisabled, groups[0].Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", IsExclusive: false})
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", IsExclusive: true})
|
||||||
|
|
||||||
|
isExclusive := true
|
||||||
|
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(groups, 1)
|
||||||
|
s.Require().True(groups[0].IsExclusive)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||||
|
g1 := mustCreateGroup(s.T(), s.db, &model.Group{
|
||||||
|
Name: "g1",
|
||||||
|
Platform: model.PlatformAnthropic,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
})
|
||||||
|
g2 := mustCreateGroup(s.T(), s.db, &model.Group{
|
||||||
|
Name: "g2",
|
||||||
|
Platform: model.PlatformAnthropic,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
IsExclusive: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
|
||||||
|
|
||||||
|
isExclusive := true
|
||||||
|
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformAnthropic, model.StatusActive, &isExclusive)
|
||||||
|
s.Require().NoError(err, "ListWithFilters")
|
||||||
|
s.Require().Equal(int64(1), page.Total)
|
||||||
|
s.Require().Len(groups, 1)
|
||||||
|
s.Require().Equal(g2.ID, groups[0].ID, "ListWithFilters returned wrong group")
|
||||||
|
s.Require().Equal(int64(1), groups[0].AccountCount, "AccountCount mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListActive / ListActiveByPlatform ---
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestListActive() {
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "active1", Status: model.StatusActive})
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "inactive1", Status: model.StatusDisabled})
|
||||||
|
|
||||||
|
groups, err := s.repo.ListActive(s.ctx)
|
||||||
|
s.Require().NoError(err, "ListActive")
|
||||||
|
s.Require().Len(groups, 1)
|
||||||
|
s.Require().Equal("active1", groups[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestListActiveByPlatform() {
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g3", Platform: model.PlatformAnthropic, Status: model.StatusDisabled})
|
||||||
|
|
||||||
|
groups, err := s.repo.ListActiveByPlatform(s.ctx, model.PlatformAnthropic)
|
||||||
|
s.Require().NoError(err, "ListActiveByPlatform")
|
||||||
|
s.Require().Len(groups, 1)
|
||||||
|
s.Require().Equal("g1", groups[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ExistsByName ---
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestExistsByName() {
|
||||||
|
mustCreateGroup(s.T(), s.db, &model.Group{Name: "existing-group"})
|
||||||
|
|
||||||
|
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
|
||||||
|
s.Require().NoError(err, "ExistsByName")
|
||||||
|
s.Require().True(exists)
|
||||||
|
|
||||||
|
notExists, err := s.repo.ExistsByName(s.ctx, "non-existing")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().False(notExists)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetAccountCount ---
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestGetAccountCount() {
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
|
||||||
|
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
|
||||||
|
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
|
||||||
|
|
||||||
|
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "GetAccountCount")
|
||||||
|
s.Require().Equal(int64(2), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
|
||||||
|
|
||||||
|
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Zero(count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- DeleteAccountGroupsByGroupID ---
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
||||||
|
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
|
||||||
|
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
|
||||||
|
|
||||||
|
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
|
||||||
|
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
|
||||||
|
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||||
|
|
||||||
|
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||||
|
s.Require().NoError(err, "GetAccountCount")
|
||||||
|
s.Require().Equal(int64(0), count, "expected 0 account groups")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
|
||||||
|
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-multi"})
|
||||||
|
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
|
||||||
|
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
|
||||||
|
a3 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"})
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
|
||||||
|
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)
|
||||||
|
|
||||||
|
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(int64(3), affected)
|
||||||
|
|
||||||
|
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||||
|
s.Require().Zero(count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- DB ---
|
||||||
|
|
||||||
|
func (s *GroupRepoSuite) TestDB() {
|
||||||
|
db := s.repo.DB()
|
||||||
|
s.Require().NotNil(db, "DB should return non-nil")
|
||||||
|
s.Require().Equal(s.db, db, "DB should return the underlying gorm.DB")
|
||||||
|
}
|
||||||
115
backend/internal/repository/http_upstream_test.go
Normal file
115
backend/internal/repository/http_upstream_test.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HTTPUpstreamSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HTTPUpstreamSuite) SetupTest() {
|
||||||
|
s.cfg = &config.Config{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
|
||||||
|
up := NewHTTPUpstream(s.cfg)
|
||||||
|
svc, ok := up.(*httpUpstreamService)
|
||||||
|
require.True(s.T(), ok, "expected *httpUpstreamService")
|
||||||
|
transport, ok := svc.defaultClient.Transport.(*http.Transport)
|
||||||
|
require.True(s.T(), ok, "expected *http.Transport")
|
||||||
|
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
|
||||||
|
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
|
||||||
|
up := NewHTTPUpstream(s.cfg)
|
||||||
|
svc, ok := up.(*httpUpstreamService)
|
||||||
|
require.True(s.T(), ok, "expected *httpUpstreamService")
|
||||||
|
transport, ok := svc.defaultClient.Transport.(*http.Transport)
|
||||||
|
require.True(s.T(), ok, "expected *http.Transport")
|
||||||
|
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() {
|
||||||
|
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5}
|
||||||
|
up := NewHTTPUpstream(s.cfg)
|
||||||
|
svc, ok := up.(*httpUpstreamService)
|
||||||
|
require.True(s.T(), ok, "expected *httpUpstreamService")
|
||||||
|
|
||||||
|
got := svc.createProxyClient("://bad-proxy-url")
|
||||||
|
require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "direct")
|
||||||
|
}))
|
||||||
|
s.T().Cleanup(upstream.Close)
|
||||||
|
|
||||||
|
up := NewHTTPUpstream(s.cfg)
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil)
|
||||||
|
require.NoError(s.T(), err, "NewRequest")
|
||||||
|
resp, err := up.Do(req, "")
|
||||||
|
require.NoError(s.T(), err, "Do")
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
require.Equal(s.T(), "direct", string(b), "unexpected body")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
|
||||||
|
seen := make(chan string, 1)
|
||||||
|
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
seen <- r.RequestURI
|
||||||
|
_, _ = io.WriteString(w, "proxied")
|
||||||
|
}))
|
||||||
|
s.T().Cleanup(proxySrv.Close)
|
||||||
|
|
||||||
|
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1}
|
||||||
|
up := NewHTTPUpstream(s.cfg)
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
|
||||||
|
require.NoError(s.T(), err, "NewRequest")
|
||||||
|
resp, err := up.Do(req, proxySrv.URL)
|
||||||
|
require.NoError(s.T(), err, "Do")
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
require.Equal(s.T(), "proxied", string(b), "unexpected body")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case uri := <-seen:
|
||||||
|
require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI")
|
||||||
|
default:
|
||||||
|
require.Fail(s.T(), "expected proxy to receive request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
|
||||||
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "direct-empty")
|
||||||
|
}))
|
||||||
|
s.T().Cleanup(upstream.Close)
|
||||||
|
|
||||||
|
up := NewHTTPUpstream(s.cfg)
|
||||||
|
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil)
|
||||||
|
require.NoError(s.T(), err, "NewRequest")
|
||||||
|
resp, err := up.Do(req, "")
|
||||||
|
require.NoError(s.T(), err, "Do with empty proxy")
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
require.Equal(s.T(), "direct-empty", string(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPUpstreamSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(HTTPUpstreamSuite))
|
||||||
|
}
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service/ports"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IdentityCacheSuite struct {
|
||||||
|
IntegrationRedisSuite
|
||||||
|
cache *identityCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *IdentityCacheSuite) SetupTest() {
|
||||||
|
s.IntegrationRedisSuite.SetupTest()
|
||||||
|
s.cache = NewIdentityCache(s.rdb).(*identityCache)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *IdentityCacheSuite) TestGetFingerprint_Missing() {
|
||||||
|
_, err := s.cache.GetFingerprint(s.ctx, 1)
|
||||||
|
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing fingerprint")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
|
||||||
|
fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||||
|
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint")
|
||||||
|
gotFP, err := s.cache.GetFingerprint(s.ctx, 1)
|
||||||
|
require.NoError(s.T(), err, "GetFingerprint")
|
||||||
|
require.Equal(s.T(), "c1", gotFP.ClientID)
|
||||||
|
require.Equal(s.T(), "ua", gotFP.UserAgent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *IdentityCacheSuite) TestFingerprint_TTL() {
|
||||||
|
fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||||
|
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp))
|
||||||
|
|
||||||
|
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2)
|
||||||
|
ttl, err := s.rdb.TTL(s.ctx, fpKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL fpKey")
|
||||||
|
s.AssertTTLWithin(ttl, 1*time.Second, fingerprintTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *IdentityCacheSuite) TestGetFingerprint_JSONCorruption() {
|
||||||
|
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 999)
|
||||||
|
require.NoError(s.T(), s.rdb.Set(s.ctx, fpKey, "invalid-json-data", 1*time.Minute).Err(), "Set invalid JSON")
|
||||||
|
|
||||||
|
_, err := s.cache.GetFingerprint(s.ctx, 999)
|
||||||
|
require.Error(s.T(), err, "expected error for corrupted JSON")
|
||||||
|
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *IdentityCacheSuite) TestSetFingerprint_Nil() {
|
||||||
|
err := s.cache.SetFingerprint(s.ctx, 100, nil)
|
||||||
|
require.NoError(s.T(), err, "SetFingerprint(nil) should succeed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIdentityCacheSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(IdentityCacheSuite))
|
||||||
|
}
|
||||||
369
backend/internal/repository/integration_harness_test.go
Normal file
369
backend/internal/repository/integration_harness_test.go
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
|
||||||
|
redisclient "github.com/redis/go-redis/v9"
|
||||||
|
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||||
|
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||||
|
gormpostgres "gorm.io/driver/postgres"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
redisImageTag = "redis:8.4-alpine"
|
||||||
|
postgresImageTag = "postgres:18.1-alpine3.23"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
integrationDB *gorm.DB
|
||||||
|
integrationRedis *redisclient.Client
|
||||||
|
|
||||||
|
redisNamespaceSeq uint64
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
if err := timezone.Init("UTC"); err != nil {
|
||||||
|
log.Printf("failed to init timezone: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !dockerIsAvailable(ctx) {
|
||||||
|
// In CI we expect Docker to be available so integration tests should fail loudly.
|
||||||
|
if os.Getenv("CI") != "" {
|
||||||
|
log.Printf("docker is not available (CI=true); failing integration tests")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
log.Printf("docker is not available; skipping integration tests (start Docker to enable)")
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
postgresImage := selectDockerImage(ctx, postgresImageTag)
|
||||||
|
pgContainer, err := tcpostgres.Run(
|
||||||
|
ctx,
|
||||||
|
postgresImage,
|
||||||
|
tcpostgres.WithDatabase("sub2api_test"),
|
||||||
|
tcpostgres.WithUsername("postgres"),
|
||||||
|
tcpostgres.WithPassword("postgres"),
|
||||||
|
tcpostgres.BasicWaitStrategies(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to start postgres container: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer func() { _ = pgContainer.Terminate(ctx) }()
|
||||||
|
|
||||||
|
redisContainer, err := tcredis.Run(
|
||||||
|
ctx,
|
||||||
|
redisImageTag,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to start redis container: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer func() { _ = redisContainer.Terminate(ctx) }()
|
||||||
|
|
||||||
|
dsn, err := pgContainer.ConnectionString(ctx, "sslmode=disable", "TimeZone=UTC")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to get postgres dsn: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
integrationDB, err = openGormWithRetry(ctx, dsn, 30*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to open gorm db: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
if err := model.AutoMigrate(integrationDB); err != nil {
|
||||||
|
log.Printf("failed to automigrate db: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
redisHost, err := redisContainer.Host(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to get redis host: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to get redis port: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
integrationRedis = redisclient.NewClient(&redisclient.Options{
|
||||||
|
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
|
||||||
|
DB: 0,
|
||||||
|
})
|
||||||
|
if err := integrationRedis.Ping(ctx).Err(); err != nil {
|
||||||
|
log.Printf("failed to ping redis: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := m.Run()
|
||||||
|
|
||||||
|
_ = integrationRedis.Close()
|
||||||
|
|
||||||
|
os.Exit(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func dockerIsAvailable(ctx context.Context) bool {
|
||||||
|
cmd := exec.CommandContext(ctx, "docker", "info")
|
||||||
|
cmd.Env = os.Environ()
|
||||||
|
return cmd.Run() == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectDockerImage(ctx context.Context, preferred string) string {
|
||||||
|
if dockerImageExists(ctx, preferred) {
|
||||||
|
return preferred
|
||||||
|
}
|
||||||
|
|
||||||
|
return preferred
|
||||||
|
}
|
||||||
|
|
||||||
|
func dockerImageExists(ctx context.Context, image string) bool {
|
||||||
|
cmd := exec.CommandContext(ctx, "docker", "image", "inspect", image)
|
||||||
|
cmd.Env = os.Environ()
|
||||||
|
cmd.Stdout = nil
|
||||||
|
cmd.Stderr = nil
|
||||||
|
return cmd.Run() == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func openGormWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*gorm.DB, error) {
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
db, err := gorm.Open(gormpostgres.Open(dsn), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pingWithTimeout(ctx, sqlDB, 2*time.Second); err != nil {
|
||||||
|
lastErr = err
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("db not ready after %s: %w", timeout, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pingWithTimeout(ctx context.Context, db *sql.DB, timeout time.Duration) error {
|
||||||
|
pingCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
return db.PingContext(pingCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testTx(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tx := integrationDB.Begin()
|
||||||
|
require.NoError(t, tx.Error, "begin tx")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = tx.Rollback().Error
|
||||||
|
})
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
|
||||||
|
func testRedis(t *testing.T) *redisclient.Client {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
prefix := fmt.Sprintf(
|
||||||
|
"it:%s:%d:%d:",
|
||||||
|
sanitizeRedisNamespace(t.Name()),
|
||||||
|
time.Now().UnixNano(),
|
||||||
|
atomic.AddUint64(&redisNamespaceSeq, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
opts := *integrationRedis.Options()
|
||||||
|
rdb := redisclient.NewClient(&opts)
|
||||||
|
rdb.AddHook(prefixHook{prefix: prefix})
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
var cursor uint64
|
||||||
|
for {
|
||||||
|
keys, nextCursor, err := integrationRedis.Scan(ctx, cursor, prefix+"*", 500).Result()
|
||||||
|
require.NoError(t, err, "scan redis keys for cleanup")
|
||||||
|
if len(keys) > 0 {
|
||||||
|
require.NoError(t, integrationRedis.Unlink(ctx, keys...).Err(), "unlink redis keys for cleanup")
|
||||||
|
}
|
||||||
|
|
||||||
|
cursor = nextCursor
|
||||||
|
if cursor == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = rdb.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return rdb
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertTTLWithin(t *testing.T, ttl time.Duration, min, max time.Duration) {
|
||||||
|
t.Helper()
|
||||||
|
require.GreaterOrEqual(t, ttl, min, "ttl should be >= min")
|
||||||
|
require.LessOrEqual(t, ttl, max, "ttl should be <= max")
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeRedisNamespace(name string) string {
|
||||||
|
name = strings.ReplaceAll(name, "/", "_")
|
||||||
|
name = strings.ReplaceAll(name, " ", "_")
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
type prefixHook struct {
|
||||||
|
prefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h prefixHook) DialHook(next redisclient.DialHook) redisclient.DialHook { return next }
|
||||||
|
|
||||||
|
func (h prefixHook) ProcessHook(next redisclient.ProcessHook) redisclient.ProcessHook {
|
||||||
|
return func(ctx context.Context, cmd redisclient.Cmder) error {
|
||||||
|
h.prefixCmd(cmd)
|
||||||
|
return next(ctx, cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h prefixHook) ProcessPipelineHook(next redisclient.ProcessPipelineHook) redisclient.ProcessPipelineHook {
|
||||||
|
return func(ctx context.Context, cmds []redisclient.Cmder) error {
|
||||||
|
for _, cmd := range cmds {
|
||||||
|
h.prefixCmd(cmd)
|
||||||
|
}
|
||||||
|
return next(ctx, cmds)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
|
||||||
|
args := cmd.Args()
|
||||||
|
if len(args) < 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prefixOne := func(i int) {
|
||||||
|
if i < 0 || i >= len(args) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := args[i].(type) {
|
||||||
|
case string:
|
||||||
|
if v != "" && !strings.HasPrefix(v, h.prefix) {
|
||||||
|
args[i] = h.prefix + v
|
||||||
|
}
|
||||||
|
case []byte:
|
||||||
|
s := string(v)
|
||||||
|
if s != "" && !strings.HasPrefix(s, h.prefix) {
|
||||||
|
args[i] = []byte(h.prefix + s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch strings.ToLower(cmd.Name()) {
|
||||||
|
case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl",
|
||||||
|
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists":
|
||||||
|
prefixOne(1)
|
||||||
|
case "del", "unlink":
|
||||||
|
for i := 1; i < len(args); i++ {
|
||||||
|
prefixOne(i)
|
||||||
|
}
|
||||||
|
case "eval", "evalsha", "eval_ro", "evalsha_ro":
|
||||||
|
if len(args) < 3 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
numKeys, err := strconv.Atoi(fmt.Sprint(args[2]))
|
||||||
|
if err != nil || numKeys <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := 0; i < numKeys && 3+i < len(args); i++ {
|
||||||
|
prefixOne(3 + i)
|
||||||
|
}
|
||||||
|
case "scan":
|
||||||
|
for i := 2; i+1 < len(args); i++ {
|
||||||
|
if strings.EqualFold(fmt.Sprint(args[i]), "match") {
|
||||||
|
prefixOne(i + 1)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationRedisSuite provides a base suite for Redis integration tests.
|
||||||
|
// Embedding suites should call SetupTest to initialize ctx and rdb.
|
||||||
|
type IntegrationRedisSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
rdb *redisclient.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupTest initializes ctx and rdb for each test method.
|
||||||
|
func (s *IntegrationRedisSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.rdb = testRedis(s.T())
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireNoError is a convenience method wrapping require.NoError with s.T().
|
||||||
|
func (s *IntegrationRedisSuite) RequireNoError(err error, msgAndArgs ...any) {
|
||||||
|
s.T().Helper()
|
||||||
|
require.NoError(s.T(), err, msgAndArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssertTTLWithin asserts that ttl is within [min, max].
|
||||||
|
func (s *IntegrationRedisSuite) AssertTTLWithin(ttl, min, max time.Duration) {
|
||||||
|
s.T().Helper()
|
||||||
|
assertTTLWithin(s.T(), ttl, min, max)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntegrationDBSuite provides a base suite for DB (Gorm) integration tests.
|
||||||
|
// Embedding suites should call SetupTest to initialize ctx and db.
|
||||||
|
type IntegrationDBSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupTest initializes ctx and db for each test method.
|
||||||
|
func (s *IntegrationDBSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.db = testTx(s.T())
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireNoError is a convenience method wrapping require.NoError with s.T().
|
||||||
|
func (s *IntegrationDBSuite) RequireNoError(err error, msgAndArgs ...any) {
|
||||||
|
s.T().Helper()
|
||||||
|
require.NoError(s.T(), err, msgAndArgs...)
|
||||||
|
}
|
||||||
@@ -12,11 +12,13 @@ import (
|
|||||||
"github.com/imroc/req/v3"
|
"github.com/imroc/req/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type openaiOAuthService struct{}
|
|
||||||
|
|
||||||
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
|
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
|
||||||
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient {
|
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient {
|
||||||
return &openaiOAuthService{}
|
return &openaiOAuthService{tokenURL: openai.TokenURL}
|
||||||
|
}
|
||||||
|
|
||||||
|
type openaiOAuthService struct {
|
||||||
|
tokenURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
@@ -39,7 +41,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
|||||||
SetContext(ctx).
|
SetContext(ctx).
|
||||||
SetFormDataFromValues(formData).
|
SetFormDataFromValues(formData).
|
||||||
SetSuccessResult(&tokenResp).
|
SetSuccessResult(&tokenResp).
|
||||||
Post(openai.TokenURL)
|
Post(s.tokenURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("request failed: %w", err)
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
@@ -67,7 +69,7 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
|||||||
SetContext(ctx).
|
SetContext(ctx).
|
||||||
SetFormDataFromValues(formData).
|
SetFormDataFromValues(formData).
|
||||||
SetSuccessResult(&tokenResp).
|
SetSuccessResult(&tokenResp).
|
||||||
Post(openai.TokenURL)
|
Post(s.tokenURL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("request failed: %w", err)
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
|||||||
249
backend/internal/repository/openai_oauth_service_test.go
Normal file
249
backend/internal/repository/openai_oauth_service_test.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OpenAIOAuthServiceSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
srv *httptest.Server
|
||||||
|
svc *openaiOAuthService
|
||||||
|
received chan url.Values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.received = make(chan url.Values, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TearDownTest() {
|
||||||
|
if s.srv != nil {
|
||||||
|
s.srv.Close()
|
||||||
|
s.srv = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||||
|
s.srv = httptest.NewServer(handler)
|
||||||
|
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
|
||||||
|
errCh := make(chan string, 1)
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
errCh <- "method mismatch"
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
errCh <- "ParseForm failed"
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got := r.PostForm.Get("grant_type"); got != "authorization_code" {
|
||||||
|
errCh <- "grant_type mismatch"
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
|
||||||
|
errCh <- "client_id mismatch"
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got := r.PostForm.Get("code"); got != "code" {
|
||||||
|
errCh <- "code mismatch"
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got := r.PostForm.Get("redirect_uri"); got != openai.DefaultRedirectURI {
|
||||||
|
errCh <- "redirect_uri mismatch"
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got := r.PostForm.Get("code_verifier"); got != "ver" {
|
||||||
|
errCh <- "code_verifier mismatch"
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
|
||||||
|
}))
|
||||||
|
|
||||||
|
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
|
||||||
|
require.NoError(s.T(), err, "ExchangeCode")
|
||||||
|
select {
|
||||||
|
case msg := <-errCh:
|
||||||
|
require.Fail(s.T(), msg)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
require.Equal(s.T(), "at", resp.AccessToken)
|
||||||
|
require.Equal(s.T(), "rt", resp.RefreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
|
||||||
|
errCh := make(chan string, 1)
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
errCh <- "ParseForm failed"
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got := r.PostForm.Get("grant_type"); got != "refresh_token" {
|
||||||
|
errCh <- "grant_type mismatch"
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got := r.PostForm.Get("refresh_token"); got != "rt" {
|
||||||
|
errCh <- "refresh_token mismatch"
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
|
||||||
|
errCh <- "client_id mismatch"
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got := r.PostForm.Get("scope"); got != openai.RefreshScopes {
|
||||||
|
errCh <- "scope mismatch"
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, `{"access_token":"at2","refresh_token":"rt2","token_type":"bearer","expires_in":3600}`)
|
||||||
|
}))
|
||||||
|
|
||||||
|
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||||
|
require.NoError(s.T(), err, "RefreshToken")
|
||||||
|
select {
|
||||||
|
case msg := <-errCh:
|
||||||
|
require.Fail(s.T(), msg)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
require.Equal(s.T(), "at2", resp.AccessToken)
|
||||||
|
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = io.WriteString(w, "bad")
|
||||||
|
}))
|
||||||
|
|
||||||
|
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
require.ErrorContains(s.T(), err, "status 400")
|
||||||
|
require.ErrorContains(s.T(), err, "bad")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
s.srv.Close()
|
||||||
|
|
||||||
|
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
require.ErrorContains(s.T(), err, "request failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
|
||||||
|
started := make(chan struct{})
|
||||||
|
block := make(chan struct{})
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
close(started)
|
||||||
|
<-block
|
||||||
|
}))
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(s.ctx)
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-started
|
||||||
|
cancel()
|
||||||
|
close(block)
|
||||||
|
|
||||||
|
err := <-done
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
|
||||||
|
want := "http://localhost:9999/cb"
|
||||||
|
errCh := make(chan string, 1)
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = r.ParseForm()
|
||||||
|
if got := r.PostForm.Get("redirect_uri"); got != want {
|
||||||
|
errCh <- "redirect_uri mismatch"
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
||||||
|
}))
|
||||||
|
|
||||||
|
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
|
||||||
|
require.NoError(s.T(), err, "ExchangeCode")
|
||||||
|
select {
|
||||||
|
case msg := <-errCh:
|
||||||
|
require.Fail(s.T(), msg)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = r.ParseForm()
|
||||||
|
s.received <- r.PostForm
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
||||||
|
}))
|
||||||
|
s.svc.tokenURL = s.srv.URL + "?x=1"
|
||||||
|
|
||||||
|
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||||
|
require.NoError(s.T(), err, "ExchangeCode")
|
||||||
|
select {
|
||||||
|
case <-s.received:
|
||||||
|
default:
|
||||||
|
require.Fail(s.T(), "expected server to receive request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = io.WriteString(w, "not-valid-json")
|
||||||
|
}))
|
||||||
|
|
||||||
|
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||||
|
require.Error(s.T(), err, "expected error for invalid JSON response")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_, _ = io.WriteString(w, "unauthorized")
|
||||||
|
}))
|
||||||
|
|
||||||
|
_, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||||
|
require.Error(s.T(), err, "expected error for non-2xx status")
|
||||||
|
require.ErrorContains(s.T(), err, "status 401")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthServiceSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(OpenAIOAuthServiceSuite))
|
||||||
|
}
|
||||||
147
backend/internal/repository/pricing_service_test.go
Normal file
147
backend/internal/repository/pricing_service_test.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PricingServiceSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
srv *httptest.Server
|
||||||
|
client *pricingRemoteClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PricingServiceSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
client, ok := NewPricingRemoteClient().(*pricingRemoteClient)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.client = client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PricingServiceSuite) TearDownTest() {
|
||||||
|
if s.srv != nil {
|
||||||
|
s.srv.Close()
|
||||||
|
s.srv = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||||
|
s.srv = httptest.NewServer(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/ok" {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}))
|
||||||
|
|
||||||
|
body, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/ok")
|
||||||
|
require.NoError(s.T(), err, "FetchPricingJSON")
|
||||||
|
require.Equal(s.T(), `{"ok":true}`, string(body), "body mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PricingServiceSuite) TestFetchPricingJSON_NonOKStatus() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}))
|
||||||
|
|
||||||
|
_, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/err")
|
||||||
|
require.Error(s.T(), err, "expected error for non-200 status")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PricingServiceSuite) TestFetchHashText_ParsesFields() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/hashfile":
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("abc123 model_prices.json\n"))
|
||||||
|
case "/hashonly":
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("def456\n"))
|
||||||
|
default:
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashfile")
|
||||||
|
require.NoError(s.T(), err, "FetchHashText")
|
||||||
|
require.Equal(s.T(), "abc123", hash, "hash mismatch")
|
||||||
|
|
||||||
|
hash2, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashonly")
|
||||||
|
require.NoError(s.T(), err, "FetchHashText")
|
||||||
|
require.Equal(s.T(), "def456", hash2, "hash mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PricingServiceSuite) TestFetchHashText_NonOKStatus() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
|
||||||
|
_, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/nope")
|
||||||
|
require.Error(s.T(), err, "expected error for non-200 status")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PricingServiceSuite) TestFetchPricingJSON_InvalidURL() {
|
||||||
|
_, err := s.client.FetchPricingJSON(s.ctx, "://invalid-url")
|
||||||
|
require.Error(s.T(), err, "expected error for invalid URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PricingServiceSuite) TestFetchHashText_EmptyBody() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
// empty body
|
||||||
|
}))
|
||||||
|
|
||||||
|
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/empty")
|
||||||
|
require.NoError(s.T(), err, "FetchHashText empty body should not error")
|
||||||
|
require.Equal(s.T(), "", hash, "expected empty hash")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(" \n"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/ws")
|
||||||
|
require.NoError(s.T(), err, "FetchHashText whitespace body should not error")
|
||||||
|
require.Equal(s.T(), "", hash, "expected empty hash after trimming")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
|
||||||
|
started := make(chan struct{})
|
||||||
|
block := make(chan struct{})
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
close(started)
|
||||||
|
<-block
|
||||||
|
}))
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(s.ctx)
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := s.client.FetchPricingJSON(ctx, s.srv.URL+"/block")
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-started
|
||||||
|
cancel()
|
||||||
|
close(block)
|
||||||
|
|
||||||
|
err := <-done
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPricingServiceSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(PricingServiceSuite))
|
||||||
|
}
|
||||||
@@ -16,10 +16,14 @@ import (
|
|||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
type proxyProbeService struct{}
|
|
||||||
|
|
||||||
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
|
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
|
||||||
return &proxyProbeService{}
|
return &proxyProbeService{ipInfoURL: defaultIPInfoURL}
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultIPInfoURL = "https://ipinfo.io/json"
|
||||||
|
|
||||||
|
type proxyProbeService struct {
|
||||||
|
ipInfoURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||||
@@ -34,7 +38,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
|||||||
}
|
}
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to create request: %w", err)
|
return nil, 0, fmt.Errorf("failed to create request: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
121
backend/internal/repository/proxy_probe_service_test.go
Normal file
121
backend/internal/repository/proxy_probe_service_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProxyProbeServiceSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
proxySrv *httptest.Server
|
||||||
|
prober *proxyProbeService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TearDownTest() {
|
||||||
|
if s.proxySrv != nil {
|
||||||
|
s.proxySrv.Close()
|
||||||
|
s.proxySrv = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
|
||||||
|
s.proxySrv = httptest.NewServer(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() {
|
||||||
|
_, err := createProxyTransport("://bad")
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
require.ErrorContains(s.T(), err, "invalid proxy URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() {
|
||||||
|
_, err := createProxyTransport("ftp://127.0.0.1:1")
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
require.ErrorContains(s.T(), err, "unsupported proxy protocol")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() {
|
||||||
|
tr, err := createProxyTransport("socks5://127.0.0.1:1080")
|
||||||
|
require.NoError(s.T(), err, "createProxyTransport")
|
||||||
|
require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
||||||
|
seen := make(chan string, 1)
|
||||||
|
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
seen <- r.RequestURI
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
|
||||||
|
}))
|
||||||
|
|
||||||
|
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||||
|
require.NoError(s.T(), err, "ProbeProxy")
|
||||||
|
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
|
||||||
|
require.Equal(s.T(), "1.2.3.4", info.IP)
|
||||||
|
require.Equal(s.T(), "c", info.City)
|
||||||
|
require.Equal(s.T(), "r", info.Region)
|
||||||
|
require.Equal(s.T(), "cc", info.Country)
|
||||||
|
|
||||||
|
// Verify proxy received the request
|
||||||
|
select {
|
||||||
|
case uri := <-seen:
|
||||||
|
require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
|
||||||
|
default:
|
||||||
|
require.Fail(s.T(), "expected proxy to receive request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
|
||||||
|
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
|
||||||
|
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
require.ErrorContains(s.T(), err, "status: 503")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
|
||||||
|
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, "not-json")
|
||||||
|
}))
|
||||||
|
|
||||||
|
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||||
|
require.Error(s.T(), err)
|
||||||
|
require.ErrorContains(s.T(), err, "failed to parse response")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
|
||||||
|
s.prober.ipInfoURL = "://invalid-url"
|
||||||
|
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||||
|
require.Error(s.T(), err, "expected error for invalid ipInfoURL")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
|
||||||
|
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
s.proxySrv.Close()
|
||||||
|
|
||||||
|
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||||
|
require.Error(s.T(), err, "expected error when proxy server is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyProbeServiceSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(ProxyProbeServiceSuite))
|
||||||
|
}
|
||||||
302
backend/internal/repository/proxy_repo_integration_test.go
Normal file
302
backend/internal/repository/proxy_repo_integration_test.go
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProxyRepoSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
db *gorm.DB
|
||||||
|
repo *ProxyRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.db = testTx(s.T())
|
||||||
|
s.repo = NewProxyRepository(s.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyRepoSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(ProxyRepoSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Create / GetByID / Update / Delete ---
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestCreate() {
|
||||||
|
proxy := &model.Proxy{
|
||||||
|
Name: "test-create",
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: 8080,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.repo.Create(s.ctx, proxy)
|
||||||
|
s.Require().NoError(err, "Create")
|
||||||
|
s.Require().NotZero(proxy.ID, "expected ID to be set")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, proxy.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal("test-create", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestGetByID_NotFound() {
|
||||||
|
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||||
|
s.Require().Error(err, "expected error for non-existent ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestUpdate() {
|
||||||
|
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"})
|
||||||
|
|
||||||
|
proxy.Name = "updated"
|
||||||
|
err := s.repo.Update(s.ctx, proxy)
|
||||||
|
s.Require().NoError(err, "Update")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, proxy.ID)
|
||||||
|
s.Require().NoError(err, "GetByID after update")
|
||||||
|
s.Require().Equal("updated", got.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestDelete() {
|
||||||
|
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"})
|
||||||
|
|
||||||
|
err := s.repo.Delete(s.ctx, proxy.ID)
|
||||||
|
s.Require().NoError(err, "Delete")
|
||||||
|
|
||||||
|
_, err = s.repo.GetByID(s.ctx, proxy.ID)
|
||||||
|
s.Require().Error(err, "expected error after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- List / ListWithFilters ---
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestList() {
|
||||||
|
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
|
||||||
|
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
|
||||||
|
|
||||||
|
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
|
s.Require().NoError(err, "List")
|
||||||
|
s.Require().Len(proxies, 2)
|
||||||
|
s.Require().Equal(int64(2), page.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
|
||||||
|
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"})
|
||||||
|
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"})
|
||||||
|
|
||||||
|
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(proxies, 1)
|
||||||
|
s.Require().Equal("socks5", proxies[0].Protocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
|
||||||
|
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive})
|
||||||
|
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled})
|
||||||
|
|
||||||
|
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(proxies, 1)
|
||||||
|
s.Require().Equal(model.StatusDisabled, proxies[0].Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
|
||||||
|
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"})
|
||||||
|
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"})
|
||||||
|
|
||||||
|
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(proxies, 1)
|
||||||
|
s.Require().Contains(proxies[0].Name, "production")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListActive ---
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestListActive() {
|
||||||
|
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive})
|
||||||
|
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled})
|
||||||
|
|
||||||
|
proxies, err := s.repo.ListActive(s.ctx)
|
||||||
|
s.Require().NoError(err, "ListActive")
|
||||||
|
s.Require().Len(proxies, 1)
|
||||||
|
s.Require().Equal("active1", proxies[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ExistsByHostPortAuth ---
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
|
||||||
|
mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||||
|
Name: "p1",
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "1.2.3.4",
|
||||||
|
Port: 8080,
|
||||||
|
Username: "user",
|
||||||
|
Password: "pass",
|
||||||
|
})
|
||||||
|
|
||||||
|
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "user", "pass")
|
||||||
|
s.Require().NoError(err, "ExistsByHostPortAuth")
|
||||||
|
s.Require().True(exists)
|
||||||
|
|
||||||
|
notExists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "wrong", "creds")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().False(notExists)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
|
||||||
|
mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||||
|
Name: "p-noauth",
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "5.6.7.8",
|
||||||
|
Port: 8081,
|
||||||
|
Username: "",
|
||||||
|
Password: "",
|
||||||
|
})
|
||||||
|
|
||||||
|
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "5.6.7.8", 8081, "", "")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().True(exists)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- CountAccountsByProxyID ---
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
|
||||||
|
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy
|
||||||
|
|
||||||
|
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
|
||||||
|
s.Require().NoError(err, "CountAccountsByProxyID")
|
||||||
|
s.Require().Equal(int64(2), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
|
||||||
|
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"})
|
||||||
|
|
||||||
|
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Zero(count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetAccountCountsForProxies ---
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
|
||||||
|
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
|
||||||
|
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
|
||||||
|
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
|
||||||
|
|
||||||
|
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||||
|
s.Require().NoError(err, "GetAccountCountsForProxies")
|
||||||
|
s.Require().Equal(int64(2), counts[p1.ID])
|
||||||
|
s.Require().Equal(int64(1), counts[p2.ID])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
|
||||||
|
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Empty(counts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListActiveWithAccountCount ---
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
|
||||||
|
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||||
|
Name: "p1",
|
||||||
|
Status: model.StatusActive,
|
||||||
|
CreatedAt: base.Add(-1 * time.Hour),
|
||||||
|
})
|
||||||
|
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||||
|
Name: "p2",
|
||||||
|
Status: model.StatusActive,
|
||||||
|
CreatedAt: base,
|
||||||
|
})
|
||||||
|
mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||||
|
Name: "p3-inactive",
|
||||||
|
Status: model.StatusDisabled,
|
||||||
|
})
|
||||||
|
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
|
||||||
|
|
||||||
|
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
|
||||||
|
s.Require().NoError(err, "ListActiveWithAccountCount")
|
||||||
|
s.Require().Len(withCounts, 2, "expected 2 active proxies")
|
||||||
|
|
||||||
|
// Sorted by created_at DESC, so p2 first
|
||||||
|
s.Require().Equal(p2.ID, withCounts[0].ID)
|
||||||
|
s.Require().Equal(int64(1), withCounts[0].AccountCount)
|
||||||
|
s.Require().Equal(p1.ID, withCounts[1].ID)
|
||||||
|
s.Require().Equal(int64(2), withCounts[1].AccountCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Combined original test ---
|
||||||
|
|
||||||
|
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
|
||||||
|
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||||
|
Name: "p1",
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "1.2.3.4",
|
||||||
|
Port: 8080,
|
||||||
|
Username: "u",
|
||||||
|
Password: "p",
|
||||||
|
CreatedAt: time.Now().Add(-1 * time.Hour),
|
||||||
|
UpdatedAt: time.Now().Add(-1 * time.Hour),
|
||||||
|
})
|
||||||
|
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
|
||||||
|
Name: "p2",
|
||||||
|
Protocol: "http",
|
||||||
|
Host: "5.6.7.8",
|
||||||
|
Port: 8081,
|
||||||
|
Username: "",
|
||||||
|
Password: "",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "u", "p")
|
||||||
|
s.Require().NoError(err, "ExistsByHostPortAuth")
|
||||||
|
s.Require().True(exists, "expected proxy to exist")
|
||||||
|
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
|
||||||
|
|
||||||
|
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
|
||||||
|
s.Require().NoError(err, "CountAccountsByProxyID")
|
||||||
|
s.Require().Equal(int64(2), count1, "expected 2 accounts for p1")
|
||||||
|
|
||||||
|
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||||
|
s.Require().NoError(err, "GetAccountCountsForProxies")
|
||||||
|
s.Require().Equal(int64(2), counts[p1.ID])
|
||||||
|
s.Require().Equal(int64(1), counts[p2.ID])
|
||||||
|
|
||||||
|
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
|
||||||
|
s.Require().NoError(err, "ListActiveWithAccountCount")
|
||||||
|
s.Require().Len(withCounts, 2, "expected 2 proxies")
|
||||||
|
for _, pc := range withCounts {
|
||||||
|
switch pc.ID {
|
||||||
|
case p1.ID:
|
||||||
|
s.Require().Equal(int64(2), pc.AccountCount, "p1 count mismatch")
|
||||||
|
case p2.ID:
|
||||||
|
s.Require().Equal(int64(1), pc.AccountCount, "p2 count mismatch")
|
||||||
|
default:
|
||||||
|
s.Require().Fail("unexpected proxy id", pc.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
105
backend/internal/repository/redeem_cache_integration_test.go
Normal file
105
backend/internal/repository/redeem_cache_integration_test.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RedeemCacheSuite struct {
|
||||||
|
IntegrationRedisSuite
|
||||||
|
cache *redeemCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCacheSuite) SetupTest() {
|
||||||
|
s.IntegrationRedisSuite.SetupTest()
|
||||||
|
s.cache = NewRedeemCache(s.rdb).(*redeemCache)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() {
|
||||||
|
missingUserID := int64(99999)
|
||||||
|
_, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
|
||||||
|
require.Error(s.T(), err, "expected redis.Nil for missing rate-limit key")
|
||||||
|
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() {
|
||||||
|
userID := int64(1)
|
||||||
|
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID), "IncrementRedeemAttemptCount")
|
||||||
|
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
|
||||||
|
require.NoError(s.T(), err, "GetRedeemAttemptCount")
|
||||||
|
require.Equal(s.T(), 1, count, "count mismatch")
|
||||||
|
|
||||||
|
ttl, err := s.rdb.TTL(s.ctx, key).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL")
|
||||||
|
s.AssertTTLWithin(ttl, 1*time.Second, redeemRateLimitDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCacheSuite) TestMultipleIncrements() {
|
||||||
|
userID := int64(2)
|
||||||
|
|
||||||
|
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
|
||||||
|
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
|
||||||
|
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
|
||||||
|
|
||||||
|
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), 3, count, "count after 3 increments")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCacheSuite) TestAcquireAndReleaseRedeemLock() {
|
||||||
|
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
|
||||||
|
require.NoError(s.T(), err, "AcquireRedeemLock")
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
// Second acquire should fail
|
||||||
|
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
|
||||||
|
require.NoError(s.T(), err, "AcquireRedeemLock 2")
|
||||||
|
require.False(s.T(), ok, "expected lock to be held")
|
||||||
|
|
||||||
|
// Release
|
||||||
|
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "CODE"), "ReleaseRedeemLock")
|
||||||
|
|
||||||
|
// Now acquire should succeed
|
||||||
|
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
|
||||||
|
require.NoError(s.T(), err, "AcquireRedeemLock after release")
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCacheSuite) TestAcquireRedeemLock_TTL() {
|
||||||
|
lockKey := redeemLockKeyPrefix + "CODE2"
|
||||||
|
lockTTL := 15 * time.Second
|
||||||
|
|
||||||
|
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE2", lockTTL)
|
||||||
|
require.NoError(s.T(), err, "AcquireRedeemLock CODE2")
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
|
||||||
|
ttl, err := s.rdb.TTL(s.ctx, lockKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL lock key")
|
||||||
|
s.AssertTTLWithin(ttl, 1*time.Second, lockTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCacheSuite) TestReleaseRedeemLock_Idempotent() {
|
||||||
|
// Release a lock that doesn't exist should not error
|
||||||
|
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "NONEXISTENT"))
|
||||||
|
|
||||||
|
// Acquire, release, release again
|
||||||
|
ok, err := s.cache.AcquireRedeemLock(s.ctx, "IDEMPOTENT", 10*time.Second)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.True(s.T(), ok)
|
||||||
|
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"))
|
||||||
|
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"), "second release should be idempotent")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedeemCacheSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(RedeemCacheSuite))
|
||||||
|
}
|
||||||
315
backend/internal/repository/redeem_code_repo_integration_test.go
Normal file
315
backend/internal/repository/redeem_code_repo_integration_test.go
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RedeemCodeRepoSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
db *gorm.DB
|
||||||
|
repo *RedeemCodeRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.db = testTx(s.T())
|
||||||
|
s.repo = NewRedeemCodeRepository(s.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedeemCodeRepoSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(RedeemCodeRepoSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Create / CreateBatch / GetByID / GetByCode ---
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestCreate() {
|
||||||
|
code := &model.RedeemCode{
|
||||||
|
Code: "TEST-CREATE",
|
||||||
|
Type: model.RedeemTypeBalance,
|
||||||
|
Value: 100,
|
||||||
|
Status: model.StatusUnused,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.repo.Create(s.ctx, code)
|
||||||
|
s.Require().NoError(err, "Create")
|
||||||
|
s.Require().NotZero(code.ID, "expected ID to be set")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, code.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal("TEST-CREATE", got.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
|
||||||
|
codes := []model.RedeemCode{
|
||||||
|
{Code: "BATCH-1", Type: model.RedeemTypeBalance, Value: 10, Status: model.StatusUnused},
|
||||||
|
{Code: "BATCH-2", Type: model.RedeemTypeBalance, Value: 20, Status: model.StatusUnused},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.repo.CreateBatch(s.ctx, codes)
|
||||||
|
s.Require().NoError(err, "CreateBatch")
|
||||||
|
|
||||||
|
got1, err := s.repo.GetByCode(s.ctx, "BATCH-1")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(float64(10), got1.Value)
|
||||||
|
|
||||||
|
got2, err := s.repo.GetByCode(s.ctx, "BATCH-2")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(float64(20), got2.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
|
||||||
|
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||||
|
s.Require().Error(err, "expected error for non-existent ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestGetByCode() {
|
||||||
|
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "GET-BY-CODE", Type: model.RedeemTypeBalance})
|
||||||
|
|
||||||
|
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
|
||||||
|
s.Require().NoError(err, "GetByCode")
|
||||||
|
s.Require().Equal("GET-BY-CODE", got.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
|
||||||
|
_, err := s.repo.GetByCode(s.ctx, "NON-EXISTENT")
|
||||||
|
s.Require().Error(err, "expected error for non-existent code")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Delete ---
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestDelete() {
|
||||||
|
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TO-DELETE", Type: model.RedeemTypeBalance})
|
||||||
|
|
||||||
|
err := s.repo.Delete(s.ctx, code.ID)
|
||||||
|
s.Require().NoError(err, "Delete")
|
||||||
|
|
||||||
|
_, err = s.repo.GetByID(s.ctx, code.ID)
|
||||||
|
s.Require().Error(err, "expected error after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- List / ListWithFilters ---
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestList() {
|
||||||
|
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-1", Type: model.RedeemTypeBalance})
|
||||||
|
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-2", Type: model.RedeemTypeBalance})
|
||||||
|
|
||||||
|
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
|
s.Require().NoError(err, "List")
|
||||||
|
s.Require().Len(codes, 2)
|
||||||
|
s.Require().Equal(int64(2), page.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
|
||||||
|
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-BAL", Type: model.RedeemTypeBalance})
|
||||||
|
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-SUB", Type: model.RedeemTypeSubscription})
|
||||||
|
|
||||||
|
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, "", "")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(codes, 1)
|
||||||
|
s.Require().Equal(model.RedeemTypeSubscription, codes[0].Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
|
||||||
|
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-UNUSED", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
|
||||||
|
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
|
||||||
|
|
||||||
|
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusUsed, "")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(codes, 1)
|
||||||
|
s.Require().Equal(model.StatusUsed, codes[0].Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
|
||||||
|
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALPHA-CODE", Type: model.RedeemTypeBalance})
|
||||||
|
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "BETA-CODE", Type: model.RedeemTypeBalance})
|
||||||
|
|
||||||
|
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(codes, 1)
|
||||||
|
s.Require().Contains(codes[0].Code, "ALPHA")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
|
||||||
|
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||||
|
Code: "WITH-GROUP",
|
||||||
|
Type: model.RedeemTypeSubscription,
|
||||||
|
GroupID: &group.ID,
|
||||||
|
})
|
||||||
|
|
||||||
|
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(codes, 1)
|
||||||
|
s.Require().NotNil(codes[0].Group, "expected Group preload")
|
||||||
|
s.Require().Equal(group.ID, codes[0].Group.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Update ---
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestUpdate() {
|
||||||
|
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "UPDATE-ME", Type: model.RedeemTypeBalance, Value: 10})
|
||||||
|
|
||||||
|
code.Value = 50
|
||||||
|
err := s.repo.Update(s.ctx, code)
|
||||||
|
s.Require().NoError(err, "Update")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, code.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(float64(50), got.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Use ---
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestUse() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "use@test.com"})
|
||||||
|
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "USE-ME", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
|
||||||
|
|
||||||
|
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||||
|
s.Require().NoError(err, "Use")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, code.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(model.StatusUsed, got.Status)
|
||||||
|
s.Require().NotNil(got.UsedBy)
|
||||||
|
s.Require().Equal(user.ID, *got.UsedBy)
|
||||||
|
s.Require().NotNil(got.UsedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "idem@test.com"})
|
||||||
|
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "IDEM-CODE", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
|
||||||
|
|
||||||
|
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||||
|
s.Require().NoError(err, "Use first time")
|
||||||
|
|
||||||
|
// Second use should fail
|
||||||
|
err = s.repo.Use(s.ctx, code.ID, user.ID)
|
||||||
|
s.Require().Error(err, "Use expected error on second call")
|
||||||
|
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "already@test.com"})
|
||||||
|
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALREADY-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
|
||||||
|
|
||||||
|
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||||
|
s.Require().Error(err, "expected error for already used code")
|
||||||
|
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListByUser ---
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestListByUser() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
|
||||||
|
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
// Create codes with explicit used_at for ordering
|
||||||
|
c1 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||||
|
Code: "USER-1",
|
||||||
|
Type: model.RedeemTypeBalance,
|
||||||
|
Status: model.StatusUsed,
|
||||||
|
UsedBy: &user.ID,
|
||||||
|
})
|
||||||
|
s.db.Model(c1).Update("used_at", base)
|
||||||
|
|
||||||
|
c2 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||||
|
Code: "USER-2",
|
||||||
|
Type: model.RedeemTypeBalance,
|
||||||
|
Status: model.StatusUsed,
|
||||||
|
UsedBy: &user.ID,
|
||||||
|
})
|
||||||
|
s.db.Model(c2).Update("used_at", base.Add(1*time.Hour))
|
||||||
|
|
||||||
|
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
|
||||||
|
s.Require().NoError(err, "ListByUser")
|
||||||
|
s.Require().Len(codes, 2)
|
||||||
|
// Ordered by used_at DESC, so USER-2 first
|
||||||
|
s.Require().Equal("USER-2", codes[0].Code)
|
||||||
|
s.Require().Equal("USER-1", codes[1].Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grp@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listby"})
|
||||||
|
|
||||||
|
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||||
|
Code: "WITH-GRP",
|
||||||
|
Type: model.RedeemTypeSubscription,
|
||||||
|
Status: model.StatusUsed,
|
||||||
|
UsedBy: &user.ID,
|
||||||
|
GroupID: &group.ID,
|
||||||
|
})
|
||||||
|
s.db.Model(c).Update("used_at", time.Now())
|
||||||
|
|
||||||
|
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(codes, 1)
|
||||||
|
s.Require().NotNil(codes[0].Group)
|
||||||
|
s.Require().Equal(group.ID, codes[0].Group.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deflimit@test.com"})
|
||||||
|
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
|
||||||
|
Code: "DEF-LIM",
|
||||||
|
Type: model.RedeemTypeBalance,
|
||||||
|
Status: model.StatusUsed,
|
||||||
|
UsedBy: &user.ID,
|
||||||
|
})
|
||||||
|
s.db.Model(c).Update("used_at", time.Now())
|
||||||
|
|
||||||
|
// limit <= 0 should default to 10
|
||||||
|
codes, err := s.repo.ListByUser(s.ctx, user.ID, 0)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(codes, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Combined original test ---
|
||||||
|
|
||||||
|
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "rc@example.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-rc"})
|
||||||
|
|
||||||
|
codes := []model.RedeemCode{
|
||||||
|
{Code: "CODEA", Type: model.RedeemTypeBalance, Value: 1, Status: model.StatusUnused, CreatedAt: time.Now()},
|
||||||
|
{Code: "CODEB", Type: model.RedeemTypeSubscription, Value: 0, Status: model.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
|
||||||
|
|
||||||
|
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, model.StatusUnused, "code")
|
||||||
|
s.Require().NoError(err, "ListWithFilters")
|
||||||
|
s.Require().Equal(int64(1), page.Total)
|
||||||
|
s.Require().Len(list, 1)
|
||||||
|
s.Require().NotNil(list[0].Group, "expected Group preload")
|
||||||
|
s.Require().Equal(group.ID, list[0].Group.ID)
|
||||||
|
|
||||||
|
codeB, err := s.repo.GetByCode(s.ctx, "CODEB")
|
||||||
|
s.Require().NoError(err, "GetByCode")
|
||||||
|
s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use")
|
||||||
|
err = s.repo.Use(s.ctx, codeB.ID, user.ID)
|
||||||
|
s.Require().Error(err, "Use expected error on second call")
|
||||||
|
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||||
|
|
||||||
|
codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
|
||||||
|
s.Require().NoError(err, "GetByCode")
|
||||||
|
|
||||||
|
// Use fixed time instead of time.Sleep for deterministic ordering
|
||||||
|
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
|
||||||
|
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
|
||||||
|
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
|
||||||
|
|
||||||
|
used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
|
||||||
|
s.Require().NoError(err, "ListByUser")
|
||||||
|
s.Require().Len(used, 2, "expected 2 used codes")
|
||||||
|
s.Require().Equal("CODEA", used[0].Code, "expected newest used code first")
|
||||||
|
}
|
||||||
108
backend/internal/repository/setting_repo_integration_test.go
Normal file
108
backend/internal/repository/setting_repo_integration_test.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SettingRepoSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
db *gorm.DB
|
||||||
|
repo *SettingRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingRepoSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.db = testTx(s.T())
|
||||||
|
s.repo = NewSettingRepository(s.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSettingRepoSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(SettingRepoSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingRepoSuite) TestSetAndGetValue() {
|
||||||
|
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
|
||||||
|
got, err := s.repo.GetValue(s.ctx, "k1")
|
||||||
|
s.Require().NoError(err, "GetValue")
|
||||||
|
s.Require().Equal("v1", got, "GetValue mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingRepoSuite) TestSet_Upsert() {
|
||||||
|
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
|
||||||
|
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v2"), "Set upsert")
|
||||||
|
got, err := s.repo.GetValue(s.ctx, "k1")
|
||||||
|
s.Require().NoError(err, "GetValue after upsert")
|
||||||
|
s.Require().Equal("v2", got, "upsert mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingRepoSuite) TestGetValue_Missing() {
|
||||||
|
_, err := s.repo.GetValue(s.ctx, "nonexistent")
|
||||||
|
s.Require().Error(err, "expected error for missing key")
|
||||||
|
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() {
|
||||||
|
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"k2": "v2", "k3": "v3"}), "SetMultiple")
|
||||||
|
m, err := s.repo.GetMultiple(s.ctx, []string{"k2", "k3"})
|
||||||
|
s.Require().NoError(err, "GetMultiple")
|
||||||
|
s.Require().Equal("v2", m["k2"])
|
||||||
|
s.Require().Equal("v3", m["k3"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingRepoSuite) TestGetMultiple_EmptyKeys() {
|
||||||
|
m, err := s.repo.GetMultiple(s.ctx, []string{})
|
||||||
|
s.Require().NoError(err, "GetMultiple with empty keys")
|
||||||
|
s.Require().Empty(m, "expected empty map")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingRepoSuite) TestGetMultiple_Subset() {
|
||||||
|
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"a": "1", "b": "2", "c": "3"}))
|
||||||
|
m, err := s.repo.GetMultiple(s.ctx, []string{"a", "c", "nonexistent"})
|
||||||
|
s.Require().NoError(err, "GetMultiple subset")
|
||||||
|
s.Require().Equal("1", m["a"])
|
||||||
|
s.Require().Equal("3", m["c"])
|
||||||
|
_, exists := m["nonexistent"]
|
||||||
|
s.Require().False(exists, "nonexistent key should not be in map")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingRepoSuite) TestGetAll() {
|
||||||
|
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"x": "1", "y": "2"}))
|
||||||
|
all, err := s.repo.GetAll(s.ctx)
|
||||||
|
s.Require().NoError(err, "GetAll")
|
||||||
|
s.Require().GreaterOrEqual(len(all), 2, "expected at least 2 settings")
|
||||||
|
s.Require().Equal("1", all["x"])
|
||||||
|
s.Require().Equal("2", all["y"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingRepoSuite) TestDelete() {
|
||||||
|
s.Require().NoError(s.repo.Set(s.ctx, "todelete", "val"))
|
||||||
|
s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete")
|
||||||
|
_, err := s.repo.GetValue(s.ctx, "todelete")
|
||||||
|
s.Require().Error(err, "expected missing key error after Delete")
|
||||||
|
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingRepoSuite) TestDelete_Idempotent() {
|
||||||
|
// Delete a key that doesn't exist should not error
|
||||||
|
s.Require().NoError(s.repo.Delete(s.ctx, "nonexistent_delete"), "Delete nonexistent should be idempotent")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SettingRepoSuite) TestSetMultiple_Upsert() {
|
||||||
|
s.Require().NoError(s.repo.Set(s.ctx, "upsert_key", "old_value"))
|
||||||
|
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"upsert_key": "new_value", "new_key": "new_val"}))
|
||||||
|
|
||||||
|
got, err := s.repo.GetValue(s.ctx, "upsert_key")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal("new_value", got, "SetMultiple should upsert existing key")
|
||||||
|
|
||||||
|
got2, err := s.repo.GetValue(s.ctx, "new_key")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal("new_val", got2)
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/sitev
|
|||||||
|
|
||||||
type turnstileVerifier struct {
|
type turnstileVerifier struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
|
verifyURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTurnstileVerifier() service.TurnstileVerifier {
|
func NewTurnstileVerifier() service.TurnstileVerifier {
|
||||||
@@ -23,6 +24,7 @@ func NewTurnstileVerifier() service.TurnstileVerifier {
|
|||||||
httpClient: &http.Client{
|
httpClient: &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
},
|
},
|
||||||
|
verifyURL: turnstileVerifyURL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -34,7 +36,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r
|
|||||||
formData.Set("remoteip", remoteIP)
|
formData.Set("remoteip", remoteIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, v.verifyURL, strings.NewReader(formData.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create request: %w", err)
|
return nil, fmt.Errorf("create request: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
143
backend/internal/repository/turnstile_service_test.go
Normal file
143
backend/internal/repository/turnstile_service_test.go
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TurnstileServiceSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
srv *httptest.Server
|
||||||
|
verifier *turnstileVerifier
|
||||||
|
received chan url.Values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TurnstileServiceSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.received = make(chan url.Values, 1)
|
||||||
|
verifier, ok := NewTurnstileVerifier().(*turnstileVerifier)
|
||||||
|
require.True(s.T(), ok, "type assertion failed")
|
||||||
|
s.verifier = verifier
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TurnstileServiceSuite) TearDownTest() {
|
||||||
|
if s.srv != nil {
|
||||||
|
s.srv.Close()
|
||||||
|
s.srv = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||||
|
s.srv = httptest.NewServer(handler)
|
||||||
|
s.verifier.verifyURL = s.srv.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Capture form data in main goroutine context later
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
values, _ := url.ParseQuery(string(body))
|
||||||
|
s.received <- values
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
|
||||||
|
}))
|
||||||
|
|
||||||
|
resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
|
||||||
|
require.NoError(s.T(), err, "VerifyToken")
|
||||||
|
require.NotNil(s.T(), resp)
|
||||||
|
require.True(s.T(), resp.Success, "expected success response")
|
||||||
|
|
||||||
|
// Assert form fields in main goroutine
|
||||||
|
select {
|
||||||
|
case values := <-s.received:
|
||||||
|
require.Equal(s.T(), "sk", values.Get("secret"))
|
||||||
|
require.Equal(s.T(), "token", values.Get("response"))
|
||||||
|
require.Equal(s.T(), "1.1.1.1", values.Get("remoteip"))
|
||||||
|
default:
|
||||||
|
require.Fail(s.T(), "expected server to receive request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
|
||||||
|
var contentType string
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
contentType = r.Header.Get("Content-Type")
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
|
||||||
|
}))
|
||||||
|
|
||||||
|
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.True(s.T(), strings.HasPrefix(contentType, "application/x-www-form-urlencoded"), "unexpected content-type: %s", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
values, _ := url.ParseQuery(string(body))
|
||||||
|
s.received <- values
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
|
||||||
|
}))
|
||||||
|
|
||||||
|
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "")
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case values := <-s.received:
|
||||||
|
require.Equal(s.T(), "", values.Get("remoteip"), "remoteip should be empty or not sent")
|
||||||
|
default:
|
||||||
|
require.Fail(s.T(), "expected server to receive request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
s.srv.Close()
|
||||||
|
|
||||||
|
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
|
||||||
|
require.Error(s.T(), err, "expected error when server is closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = io.WriteString(w, "not-valid-json")
|
||||||
|
}))
|
||||||
|
|
||||||
|
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
|
||||||
|
require.Error(s.T(), err, "expected error for invalid JSON response")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
|
||||||
|
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
|
||||||
|
Success: false,
|
||||||
|
ErrorCodes: []string{"invalid-input-response"},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
|
resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
|
||||||
|
require.NoError(s.T(), err, "VerifyToken should not error on success=false")
|
||||||
|
require.NotNil(s.T(), resp)
|
||||||
|
require.False(s.T(), resp.Success)
|
||||||
|
require.Contains(s.T(), resp.ErrorCodes, "invalid-input-response")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTurnstileServiceSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(TurnstileServiceSuite))
|
||||||
|
}
|
||||||
73
backend/internal/repository/update_cache_integration_test.go
Normal file
73
backend/internal/repository/update_cache_integration_test.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UpdateCacheSuite struct {
|
||||||
|
IntegrationRedisSuite
|
||||||
|
cache *updateCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UpdateCacheSuite) SetupTest() {
|
||||||
|
s.IntegrationRedisSuite.SetupTest()
|
||||||
|
s.cache = NewUpdateCache(s.rdb).(*updateCache)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UpdateCacheSuite) TestGetUpdateInfo_Missing() {
|
||||||
|
_, err := s.cache.GetUpdateInfo(s.ctx)
|
||||||
|
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing update info")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UpdateCacheSuite) TestSetAndGetUpdateInfo() {
|
||||||
|
updateTTL := 5 * time.Minute
|
||||||
|
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL), "SetUpdateInfo")
|
||||||
|
|
||||||
|
info, err := s.cache.GetUpdateInfo(s.ctx)
|
||||||
|
require.NoError(s.T(), err, "GetUpdateInfo")
|
||||||
|
require.Equal(s.T(), "v1.2.3", info, "update info mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UpdateCacheSuite) TestSetUpdateInfo_TTL() {
|
||||||
|
updateTTL := 5 * time.Minute
|
||||||
|
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL))
|
||||||
|
|
||||||
|
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
|
||||||
|
require.NoError(s.T(), err, "TTL updateCacheKey")
|
||||||
|
s.AssertTTLWithin(ttl, 1*time.Second, updateTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UpdateCacheSuite) TestSetUpdateInfo_Overwrite() {
|
||||||
|
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.0.0", 5*time.Minute))
|
||||||
|
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v2.0.0", 5*time.Minute))
|
||||||
|
|
||||||
|
info, err := s.cache.GetUpdateInfo(s.ctx)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), "v2.0.0", info, "expected overwritten value")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UpdateCacheSuite) TestSetUpdateInfo_ZeroTTL() {
|
||||||
|
// TTL=0 means persist forever (no expiry) in Redis SET command
|
||||||
|
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v0.0.0", 0))
|
||||||
|
|
||||||
|
info, err := s.cache.GetUpdateInfo(s.ctx)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), "v0.0.0", info)
|
||||||
|
|
||||||
|
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
// TTL=-1 means no expiry, TTL=-2 means key doesn't exist
|
||||||
|
require.Equal(s.T(), time.Duration(-1), ttl, "expected TTL=-1 for key with no expiry")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateCacheSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(UpdateCacheSuite))
|
||||||
|
}
|
||||||
890
backend/internal/repository/usage_log_repo_integration_test.go
Normal file
890
backend/internal/repository/usage_log_repo_integration_test.go
Normal file
@@ -0,0 +1,890 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UsageLogRepoSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
db *gorm.DB
|
||||||
|
repo *UsageLogRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.db = testTx(s.T())
|
||||||
|
s.repo = NewUsageLogRepository(s.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepoSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(UsageLogRepoSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKey, account *model.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *model.UsageLog {
|
||||||
|
log := &model.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
ApiKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: inputTokens,
|
||||||
|
OutputTokens: outputTokens,
|
||||||
|
TotalCost: cost,
|
||||||
|
ActualCost: cost,
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, log))
|
||||||
|
return log
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Create / GetByID ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestCreate() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-create"})
|
||||||
|
|
||||||
|
log := &model.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
ApiKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.4,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.repo.Create(s.ctx, log)
|
||||||
|
s.Require().NoError(err, "Create")
|
||||||
|
s.Require().NotZero(log.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetByID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbyid@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-getbyid"})
|
||||||
|
|
||||||
|
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, log.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal(log.ID, got.ID)
|
||||||
|
s.Require().Equal(10, got.InputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
|
||||||
|
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||||
|
s.Require().Error(err, "expected error for non-existent ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Delete ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestDelete() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-delete"})
|
||||||
|
|
||||||
|
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||||
|
|
||||||
|
err := s.repo.Delete(s.ctx, log.ID)
|
||||||
|
s.Require().NoError(err, "Delete")
|
||||||
|
|
||||||
|
_, err = s.repo.GetByID(s.ctx, log.ID)
|
||||||
|
s.Require().Error(err, "expected error after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListByUser ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestListByUser() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyuser"})
|
||||||
|
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
|
||||||
|
|
||||||
|
logs, page, err := s.repo.ListByUser(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
|
s.Require().NoError(err, "ListByUser")
|
||||||
|
s.Require().Len(logs, 2)
|
||||||
|
s.Require().Equal(int64(2), page.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListByApiKey ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestListByApiKey() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyapikey@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyapikey"})
|
||||||
|
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
|
||||||
|
|
||||||
|
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
|
s.Require().NoError(err, "ListByApiKey")
|
||||||
|
s.Require().Len(logs, 2)
|
||||||
|
s.Require().Equal(int64(2), page.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListByAccount ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestListByAccount() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyaccount@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyaccount"})
|
||||||
|
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||||
|
|
||||||
|
logs, page, err := s.repo.ListByAccount(s.ctx, account.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
|
s.Require().NoError(err, "ListByAccount")
|
||||||
|
s.Require().Len(logs, 1)
|
||||||
|
s.Require().Equal(int64(1), page.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetUserStats ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetUserStats() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userstats@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userstats"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(2 * time.Hour)
|
||||||
|
stats, err := s.repo.GetUserStats(s.ctx, user.ID, startTime, endTime)
|
||||||
|
s.Require().NoError(err, "GetUserStats")
|
||||||
|
s.Require().Equal(int64(2), stats.TotalRequests)
|
||||||
|
s.Require().Equal(int64(25), stats.InputTokens)
|
||||||
|
s.Require().Equal(int64(45), stats.OutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListWithFilters ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestListWithFilters() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filters@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filters"})
|
||||||
|
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||||
|
|
||||||
|
filters := usagestats.UsageLogFilters{UserID: user.ID}
|
||||||
|
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
|
||||||
|
s.Require().NoError(err, "ListWithFilters")
|
||||||
|
s.Require().Len(logs, 1)
|
||||||
|
s.Require().Equal(int64(1), page.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetDashboardStats ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||||
|
now := time.Now()
|
||||||
|
todayStart := timezone.Today()
|
||||||
|
|
||||||
|
userToday := mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "today@example.com",
|
||||||
|
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
|
||||||
|
UpdatedAt: now,
|
||||||
|
})
|
||||||
|
userOld := mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "old@example.com",
|
||||||
|
CreatedAt: todayStart.Add(-24 * time.Hour),
|
||||||
|
UpdatedAt: todayStart.Add(-24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-ul"})
|
||||||
|
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
|
||||||
|
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: model.StatusDisabled})
|
||||||
|
|
||||||
|
resetAt := now.Add(10 * time.Minute)
|
||||||
|
accNormal := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-normal", Schedulable: true})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-error", Status: model.StatusError, Schedulable: true})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
|
||||||
|
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
|
||||||
|
|
||||||
|
d1, d2, d3 := 100, 200, 300
|
||||||
|
logToday := &model.UsageLog{
|
||||||
|
UserID: userToday.ID,
|
||||||
|
ApiKeyID: apiKey1.ID,
|
||||||
|
AccountID: accNormal.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
GroupID: &group.ID,
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
CacheCreationTokens: 3,
|
||||||
|
CacheReadTokens: 4,
|
||||||
|
TotalCost: 1.5,
|
||||||
|
ActualCost: 1.2,
|
||||||
|
DurationMs: &d1,
|
||||||
|
CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
|
||||||
|
|
||||||
|
logOld := &model.UsageLog{
|
||||||
|
UserID: userOld.ID,
|
||||||
|
ApiKeyID: apiKey1.ID,
|
||||||
|
AccountID: accNormal.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 5,
|
||||||
|
OutputTokens: 6,
|
||||||
|
TotalCost: 0.7,
|
||||||
|
ActualCost: 0.7,
|
||||||
|
DurationMs: &d2,
|
||||||
|
CreatedAt: todayStart.Add(-1 * time.Hour),
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
|
||||||
|
|
||||||
|
logPerf := &model.UsageLog{
|
||||||
|
UserID: userToday.ID,
|
||||||
|
ApiKeyID: apiKey1.ID,
|
||||||
|
AccountID: accNormal.ID,
|
||||||
|
Model: "claude-3",
|
||||||
|
InputTokens: 1,
|
||||||
|
OutputTokens: 2,
|
||||||
|
TotalCost: 0.1,
|
||||||
|
ActualCost: 0.1,
|
||||||
|
DurationMs: &d3,
|
||||||
|
CreatedAt: now.Add(-30 * time.Second),
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, logPerf), "Create logPerf")
|
||||||
|
|
||||||
|
stats, err := s.repo.GetDashboardStats(s.ctx)
|
||||||
|
s.Require().NoError(err, "GetDashboardStats")
|
||||||
|
|
||||||
|
s.Require().Equal(int64(2), stats.TotalUsers, "TotalUsers mismatch")
|
||||||
|
s.Require().Equal(int64(1), stats.TodayNewUsers, "TodayNewUsers mismatch")
|
||||||
|
s.Require().Equal(int64(1), stats.ActiveUsers, "ActiveUsers mismatch")
|
||||||
|
s.Require().Equal(int64(2), stats.TotalApiKeys, "TotalApiKeys mismatch")
|
||||||
|
s.Require().Equal(int64(1), stats.ActiveApiKeys, "ActiveApiKeys mismatch")
|
||||||
|
s.Require().Equal(int64(4), stats.TotalAccounts, "TotalAccounts mismatch")
|
||||||
|
s.Require().Equal(int64(1), stats.ErrorAccounts, "ErrorAccounts mismatch")
|
||||||
|
s.Require().Equal(int64(1), stats.RateLimitAccounts, "RateLimitAccounts mismatch")
|
||||||
|
s.Require().Equal(int64(1), stats.OverloadAccounts, "OverloadAccounts mismatch")
|
||||||
|
|
||||||
|
s.Require().Equal(int64(3), stats.TotalRequests, "TotalRequests mismatch")
|
||||||
|
s.Require().Equal(int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch")
|
||||||
|
s.Require().Equal(int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch")
|
||||||
|
s.Require().Equal(int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch")
|
||||||
|
s.Require().Equal(int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch")
|
||||||
|
s.Require().Equal(int64(51), stats.TotalTokens, "TotalTokens mismatch")
|
||||||
|
s.Require().Equal(2.3, stats.TotalCost, "TotalCost mismatch")
|
||||||
|
s.Require().Equal(2.0, stats.TotalActualCost, "TotalActualCost mismatch")
|
||||||
|
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
|
||||||
|
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
|
||||||
|
|
||||||
|
wantRpm, wantTpm := s.repo.getPerformanceStats(s.ctx, 0)
|
||||||
|
s.Require().Equal(wantRpm, stats.Rpm, "Rpm mismatch")
|
||||||
|
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetUserDashboardStats ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userdash@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userdash"})
|
||||||
|
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||||
|
|
||||||
|
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err, "GetUserDashboardStats")
|
||||||
|
s.Require().Equal(int64(1), stats.TotalApiKeys)
|
||||||
|
s.Require().Equal(int64(1), stats.TotalRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetAccountTodayStats ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctoday@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-today"})
|
||||||
|
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||||
|
|
||||||
|
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
|
||||||
|
s.Require().NoError(err, "GetAccountTodayStats")
|
||||||
|
s.Require().Equal(int64(1), stats.Requests)
|
||||||
|
s.Require().Equal(int64(30), stats.Tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetBatchUserUsageStats ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||||
|
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch1@test.com"})
|
||||||
|
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch2@test.com"})
|
||||||
|
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
|
||||||
|
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batch"})
|
||||||
|
|
||||||
|
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||||
|
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||||
|
|
||||||
|
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID})
|
||||||
|
s.Require().NoError(err, "GetBatchUserUsageStats")
|
||||||
|
s.Require().Len(stats, 2)
|
||||||
|
s.Require().NotNil(stats[user1.ID])
|
||||||
|
s.Require().NotNil(stats[user2.ID])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
|
||||||
|
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Empty(stats)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetBatchApiKeyUsageStats ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batchkey@test.com"})
|
||||||
|
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
|
||||||
|
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batchkey"})
|
||||||
|
|
||||||
|
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||||
|
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||||
|
|
||||||
|
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
|
||||||
|
s.Require().NoError(err, "GetBatchApiKeyUsageStats")
|
||||||
|
s.Require().Len(stats, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||||
|
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Empty(stats)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetGlobalStats ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "global@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-global"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||||
|
|
||||||
|
stats, err := s.repo.GetGlobalStats(s.ctx, base.Add(-1*time.Hour), base.Add(2*time.Hour))
|
||||||
|
s.Require().NoError(err, "GetGlobalStats")
|
||||||
|
s.Require().Equal(int64(2), stats.TotalRequests)
|
||||||
|
s.Require().Equal(int64(25), stats.TotalInputTokens)
|
||||||
|
s.Require().Equal(int64(45), stats.TotalOutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func maxTime(a, b time.Time) time.Time {
|
||||||
|
if a.After(b) {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListByUserAndTimeRange ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "timerange@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-timerange"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||||
|
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(2 * time.Hour)
|
||||||
|
logs, _, err := s.repo.ListByUserAndTimeRange(s.ctx, user.ID, startTime, endTime)
|
||||||
|
s.Require().NoError(err, "ListByUserAndTimeRange")
|
||||||
|
s.Require().Len(logs, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListByApiKeyAndTimeRange ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytimerange@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytimerange"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(30*time.Minute))
|
||||||
|
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(2 * time.Hour)
|
||||||
|
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
|
||||||
|
s.Require().NoError(err, "ListByApiKeyAndTimeRange")
|
||||||
|
s.Require().Len(logs, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListByAccountAndTimeRange ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctimerange@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-acctimerange"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(45*time.Minute))
|
||||||
|
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(2 * time.Hour)
|
||||||
|
logs, _, err := s.repo.ListByAccountAndTimeRange(s.ctx, account.ID, startTime, endTime)
|
||||||
|
s.Require().NoError(err, "ListByAccountAndTimeRange")
|
||||||
|
s.Require().Len(logs, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListByModelAndTimeRange ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modeltimerange@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modeltimerange"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
// Create logs with different models
|
||||||
|
log1 := &model.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
ApiKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3-opus",
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: base,
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||||
|
|
||||||
|
log2 := &model.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
ApiKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3-opus",
|
||||||
|
InputTokens: 15,
|
||||||
|
OutputTokens: 25,
|
||||||
|
TotalCost: 0.6,
|
||||||
|
ActualCost: 0.6,
|
||||||
|
CreatedAt: base.Add(30 * time.Minute),
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||||
|
|
||||||
|
log3 := &model.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
ApiKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3-sonnet",
|
||||||
|
InputTokens: 20,
|
||||||
|
OutputTokens: 30,
|
||||||
|
TotalCost: 0.7,
|
||||||
|
ActualCost: 0.7,
|
||||||
|
CreatedAt: base.Add(1 * time.Hour),
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, log3))
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(2 * time.Hour)
|
||||||
|
logs, _, err := s.repo.ListByModelAndTimeRange(s.ctx, "claude-3-opus", startTime, endTime)
|
||||||
|
s.Require().NoError(err, "ListByModelAndTimeRange")
|
||||||
|
s.Require().Len(logs, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetAccountWindowStats ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "windowstats@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-windowstats"})
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
windowStart := now.Add(-10 * time.Minute)
|
||||||
|
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, now.Add(-5*time.Minute))
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, now.Add(-3*time.Minute))
|
||||||
|
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, now.Add(-30*time.Minute)) // outside window
|
||||||
|
|
||||||
|
stats, err := s.repo.GetAccountWindowStats(s.ctx, account.ID, windowStart)
|
||||||
|
s.Require().NoError(err, "GetAccountWindowStats")
|
||||||
|
s.Require().Equal(int64(2), stats.Requests)
|
||||||
|
s.Require().Equal(int64(70), stats.Tokens) // (10+20) + (15+25)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetUserUsageTrendByUserID ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrend"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||||
|
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(24*time.Hour)) // next day
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(48 * time.Hour)
|
||||||
|
trend, err := s.repo.GetUserUsageTrendByUserID(s.ctx, user.ID, startTime, endTime, "day")
|
||||||
|
s.Require().NoError(err, "GetUserUsageTrendByUserID")
|
||||||
|
s.Require().Len(trend, 2) // 2 different days
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrendhourly@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrendhourly"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||||
|
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(2*time.Hour))
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(3 * time.Hour)
|
||||||
|
trend, err := s.repo.GetUserUsageTrendByUserID(s.ctx, user.ID, startTime, endTime, "hour")
|
||||||
|
s.Require().NoError(err, "GetUserUsageTrendByUserID hourly")
|
||||||
|
s.Require().Len(trend, 3) // 3 different hours
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetUserModelStats ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelstats@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelstats"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
// Create logs with different models
|
||||||
|
log1 := &model.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
ApiKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3-opus",
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 200,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: base,
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||||
|
|
||||||
|
log2 := &model.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
ApiKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3-sonnet",
|
||||||
|
InputTokens: 50,
|
||||||
|
OutputTokens: 100,
|
||||||
|
TotalCost: 0.2,
|
||||||
|
ActualCost: 0.2,
|
||||||
|
CreatedAt: base.Add(1 * time.Hour),
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(2 * time.Hour)
|
||||||
|
stats, err := s.repo.GetUserModelStats(s.ctx, user.ID, startTime, endTime)
|
||||||
|
s.Require().NoError(err, "GetUserModelStats")
|
||||||
|
s.Require().Len(stats, 2)
|
||||||
|
|
||||||
|
// Should be ordered by total_tokens DESC
|
||||||
|
s.Require().Equal("claude-3-opus", stats[0].Model)
|
||||||
|
s.Require().Equal(int64(300), stats[0].TotalTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetUsageTrendWithFilters ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(24*time.Hour))
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(48 * time.Hour)
|
||||||
|
|
||||||
|
// Test with user filter
|
||||||
|
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0)
|
||||||
|
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
|
||||||
|
s.Require().Len(trend, 2)
|
||||||
|
|
||||||
|
// Test with apiKey filter
|
||||||
|
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID)
|
||||||
|
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
|
||||||
|
s.Require().Len(trend, 2)
|
||||||
|
|
||||||
|
// Test with both filters
|
||||||
|
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID)
|
||||||
|
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
|
||||||
|
s.Require().Len(trend, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters-h@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters-h"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(3 * time.Hour)
|
||||||
|
|
||||||
|
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0)
|
||||||
|
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
|
||||||
|
s.Require().Len(trend, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetModelStatsWithFilters ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelfilters@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelfilters"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
log1 := &model.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
ApiKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3-opus",
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 200,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.5,
|
||||||
|
CreatedAt: base,
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||||
|
|
||||||
|
log2 := &model.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
ApiKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3-sonnet",
|
||||||
|
InputTokens: 50,
|
||||||
|
OutputTokens: 100,
|
||||||
|
TotalCost: 0.2,
|
||||||
|
ActualCost: 0.2,
|
||||||
|
CreatedAt: base.Add(1 * time.Hour),
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(2 * time.Hour)
|
||||||
|
|
||||||
|
// Test with user filter
|
||||||
|
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0)
|
||||||
|
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
|
||||||
|
s.Require().Len(stats, 2)
|
||||||
|
|
||||||
|
// Test with apiKey filter
|
||||||
|
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0)
|
||||||
|
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
|
||||||
|
s.Require().Len(stats, 2)
|
||||||
|
|
||||||
|
// Test with account filter
|
||||||
|
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID)
|
||||||
|
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
|
||||||
|
s.Require().Len(stats, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetAccountUsageStats ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accstats@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-accstats"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
// Create logs on different days
|
||||||
|
log1 := &model.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
ApiKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3-opus",
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 200,
|
||||||
|
TotalCost: 0.5,
|
||||||
|
ActualCost: 0.4,
|
||||||
|
CreatedAt: base.Add(12 * time.Hour),
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, log1))
|
||||||
|
|
||||||
|
log2 := &model.UsageLog{
|
||||||
|
UserID: user.ID,
|
||||||
|
ApiKeyID: apiKey.ID,
|
||||||
|
AccountID: account.ID,
|
||||||
|
Model: "claude-3-sonnet",
|
||||||
|
InputTokens: 50,
|
||||||
|
OutputTokens: 100,
|
||||||
|
TotalCost: 0.2,
|
||||||
|
ActualCost: 0.15,
|
||||||
|
CreatedAt: base.Add(36 * time.Hour), // next day
|
||||||
|
}
|
||||||
|
s.Require().NoError(s.repo.Create(s.ctx, log2))
|
||||||
|
|
||||||
|
startTime := base
|
||||||
|
endTime := base.Add(72 * time.Hour)
|
||||||
|
|
||||||
|
resp, err := s.repo.GetAccountUsageStats(s.ctx, account.ID, startTime, endTime)
|
||||||
|
s.Require().NoError(err, "GetAccountUsageStats")
|
||||||
|
|
||||||
|
s.Require().Len(resp.History, 2, "expected 2 days of history")
|
||||||
|
s.Require().Equal(int64(2), resp.Summary.TotalRequests)
|
||||||
|
s.Require().Equal(int64(450), resp.Summary.TotalTokens)
|
||||||
|
s.Require().Len(resp.Models, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-emptystats"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||||
|
startTime := base
|
||||||
|
endTime := base.Add(72 * time.Hour)
|
||||||
|
|
||||||
|
resp, err := s.repo.GetAccountUsageStats(s.ctx, account.ID, startTime, endTime)
|
||||||
|
s.Require().NoError(err, "GetAccountUsageStats empty")
|
||||||
|
|
||||||
|
s.Require().Len(resp.History, 0)
|
||||||
|
s.Require().Equal(int64(0), resp.Summary.TotalRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetUserUsageTrend ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
|
||||||
|
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend1@test.com"})
|
||||||
|
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend2@test.com"})
|
||||||
|
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
|
||||||
|
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrends"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base)
|
||||||
|
s.createUsageLog(user2, apiKey2, account, 50, 100, 0.5, base)
|
||||||
|
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour))
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(48 * time.Hour)
|
||||||
|
|
||||||
|
trend, err := s.repo.GetUserUsageTrend(s.ctx, startTime, endTime, "day", 10)
|
||||||
|
s.Require().NoError(err, "GetUserUsageTrend")
|
||||||
|
s.Require().GreaterOrEqual(len(trend), 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetApiKeyUsageTrend ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrend@test.com"})
|
||||||
|
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
|
||||||
|
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrends"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
|
||||||
|
s.createUsageLog(user, apiKey2, account, 50, 100, 0.5, base)
|
||||||
|
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour))
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(48 * time.Hour)
|
||||||
|
|
||||||
|
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
|
||||||
|
s.Require().NoError(err, "GetApiKeyUsageTrend")
|
||||||
|
s.Require().GreaterOrEqual(len(trend), 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrendh@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrendh"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
|
||||||
|
s.createUsageLog(user, apiKey, account, 50, 100, 0.5, base.Add(1*time.Hour))
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(3 * time.Hour)
|
||||||
|
|
||||||
|
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
|
||||||
|
s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
|
||||||
|
s.Require().Len(trend, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListWithFilters (additional filter tests) ---
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterskey@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterskey"})
|
||||||
|
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||||
|
|
||||||
|
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
|
||||||
|
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
|
||||||
|
s.Require().NoError(err, "ListWithFilters apiKey")
|
||||||
|
s.Require().Len(logs, 1)
|
||||||
|
s.Require().Equal(int64(1), page.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterstime@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterstime"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||||
|
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(2 * time.Hour)
|
||||||
|
filters := usagestats.UsageLogFilters{StartTime: &startTime, EndTime: &endTime}
|
||||||
|
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
|
||||||
|
s.Require().NoError(err, "ListWithFilters time range")
|
||||||
|
s.Require().Len(logs, 2)
|
||||||
|
s.Require().Equal(int64(2), page.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterscombined@test.com"})
|
||||||
|
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
|
||||||
|
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterscombined"})
|
||||||
|
|
||||||
|
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||||
|
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
|
||||||
|
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
|
||||||
|
|
||||||
|
startTime := base.Add(-1 * time.Hour)
|
||||||
|
endTime := base.Add(2 * time.Hour)
|
||||||
|
filters := usagestats.UsageLogFilters{
|
||||||
|
UserID: user.ID,
|
||||||
|
ApiKeyID: apiKey.ID,
|
||||||
|
StartTime: &startTime,
|
||||||
|
EndTime: &endTime,
|
||||||
|
}
|
||||||
|
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
|
||||||
|
s.Require().NoError(err, "ListWithFilters combined")
|
||||||
|
s.Require().Len(logs, 2)
|
||||||
|
s.Require().Equal(int64(2), page.Total)
|
||||||
|
}
|
||||||
448
backend/internal/repository/user_repo_integration_test.go
Normal file
448
backend/internal/repository/user_repo_integration_test.go
Normal file
@@ -0,0 +1,448 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserRepoSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
db *gorm.DB
|
||||||
|
repo *UserRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.db = testTx(s.T())
|
||||||
|
s.repo = NewUserRepository(s.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserRepoSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(UserRepoSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Create / GetByID / GetByEmail / Update / Delete ---
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestCreate() {
|
||||||
|
user := &model.User{
|
||||||
|
Email: "create@test.com",
|
||||||
|
Username: "testuser",
|
||||||
|
Role: model.RoleUser,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.repo.Create(s.ctx, user)
|
||||||
|
s.Require().NoError(err, "Create")
|
||||||
|
s.Require().NotZero(user.ID, "expected ID to be set")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal("create@test.com", got.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestGetByID_NotFound() {
|
||||||
|
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||||
|
s.Require().Error(err, "expected error for non-existent ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestGetByEmail() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byemail@test.com"})
|
||||||
|
|
||||||
|
got, err := s.repo.GetByEmail(s.ctx, user.Email)
|
||||||
|
s.Require().NoError(err, "GetByEmail")
|
||||||
|
s.Require().Equal(user.ID, got.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
|
||||||
|
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
|
||||||
|
s.Require().Error(err, "expected error for non-existent email")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestUpdate() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com", Username: "original"})
|
||||||
|
|
||||||
|
user.Username = "updated"
|
||||||
|
err := s.repo.Update(s.ctx, user)
|
||||||
|
s.Require().NoError(err, "Update")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err, "GetByID after update")
|
||||||
|
s.Require().Equal("updated", got.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestDelete() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
|
||||||
|
|
||||||
|
err := s.repo.Delete(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err, "Delete")
|
||||||
|
|
||||||
|
_, err = s.repo.GetByID(s.ctx, user.ID)
|
||||||
|
s.Require().Error(err, "expected error after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- List / ListWithFilters ---
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestList() {
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{Email: "list1@test.com"})
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{Email: "list2@test.com"})
|
||||||
|
|
||||||
|
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
|
s.Require().NoError(err, "List")
|
||||||
|
s.Require().Len(users, 2)
|
||||||
|
s.Require().Equal(int64(2), page.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestListWithFilters_Status() {
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com", Status: model.StatusActive})
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{Email: "disabled@test.com", Status: model.StatusDisabled})
|
||||||
|
|
||||||
|
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, "", "")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(users, 1)
|
||||||
|
s.Require().Equal(model.StatusActive, users[0].Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestListWithFilters_Role() {
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{Email: "user@test.com", Role: model.RoleUser})
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
|
||||||
|
|
||||||
|
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.RoleAdmin, "")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(users, 1)
|
||||||
|
s.Require().Equal(model.RoleAdmin, users[0].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestListWithFilters_Search() {
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{Email: "alice@test.com", Username: "Alice"})
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{Email: "bob@test.com", Username: "Bob"})
|
||||||
|
|
||||||
|
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(users, 1)
|
||||||
|
s.Require().Contains(users[0].Email, "alice")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com", Username: "JohnDoe"})
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com", Username: "JaneSmith"})
|
||||||
|
|
||||||
|
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(users, 1)
|
||||||
|
s.Require().Equal("JohnDoe", users[0].Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{Email: "w1@test.com", Wechat: "wx_hello"})
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{Email: "w2@test.com", Wechat: "wx_world"})
|
||||||
|
|
||||||
|
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(users, 1)
|
||||||
|
s.Require().Equal("wx_hello", users[0].Wechat)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub@test.com", Status: model.StatusActive})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sub"})
|
||||||
|
|
||||||
|
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||||
|
})
|
||||||
|
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusExpired,
|
||||||
|
ExpiresAt: time.Now().Add(-1 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "sub@")
|
||||||
|
s.Require().NoError(err, "ListWithFilters")
|
||||||
|
s.Require().Len(users, 1, "expected 1 user")
|
||||||
|
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
|
||||||
|
s.Require().NotNil(users[0].Subscriptions[0].Group, "expected subscription group preload")
|
||||||
|
s.Require().Equal(group.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "a@example.com",
|
||||||
|
Username: "Alice",
|
||||||
|
Wechat: "wx_a",
|
||||||
|
Role: model.RoleUser,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
Balance: 10,
|
||||||
|
})
|
||||||
|
target := mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "b@example.com",
|
||||||
|
Username: "Bob",
|
||||||
|
Wechat: "wx_b",
|
||||||
|
Role: model.RoleAdmin,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
Balance: 1,
|
||||||
|
})
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "c@example.com",
|
||||||
|
Role: model.RoleAdmin,
|
||||||
|
Status: model.StatusDisabled,
|
||||||
|
})
|
||||||
|
|
||||||
|
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, model.RoleAdmin, "b@")
|
||||||
|
s.Require().NoError(err, "ListWithFilters")
|
||||||
|
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
||||||
|
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
||||||
|
s.Require().Equal(target.ID, users[0].ID, "ListWithFilters result mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Balance operations ---
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestUpdateBalance() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "bal@test.com", Balance: 10})
|
||||||
|
|
||||||
|
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
|
||||||
|
s.Require().NoError(err, "UpdateBalance")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(12.5, got.Balance)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "balneg@test.com", Balance: 10})
|
||||||
|
|
||||||
|
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
|
||||||
|
s.Require().NoError(err, "UpdateBalance with negative")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(7.0, got.Balance)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestDeductBalance() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deduct@test.com", Balance: 10})
|
||||||
|
|
||||||
|
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
|
||||||
|
s.Require().NoError(err, "DeductBalance")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(5.0, got.Balance)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "insuf@test.com", Balance: 5})
|
||||||
|
|
||||||
|
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
|
||||||
|
s.Require().Error(err, "expected error for insufficient balance")
|
||||||
|
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exact@test.com", Balance: 10})
|
||||||
|
|
||||||
|
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
|
||||||
|
s.Require().NoError(err, "DeductBalance exact amount")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Zero(got.Balance)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Concurrency ---
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestUpdateConcurrency() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "conc@test.com", Concurrency: 5})
|
||||||
|
|
||||||
|
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
|
||||||
|
s.Require().NoError(err, "UpdateConcurrency")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(8, got.Concurrency)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "concneg@test.com", Concurrency: 5})
|
||||||
|
|
||||||
|
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
|
||||||
|
s.Require().NoError(err, "UpdateConcurrency negative")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(3, got.Concurrency)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ExistsByEmail ---
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestExistsByEmail() {
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
|
||||||
|
|
||||||
|
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
|
||||||
|
s.Require().NoError(err, "ExistsByEmail")
|
||||||
|
s.Require().True(exists)
|
||||||
|
|
||||||
|
notExists, err := s.repo.ExistsByEmail(s.ctx, "notexists@test.com")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().False(notExists)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- RemoveGroupFromAllowedGroups ---
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
|
||||||
|
groupID := int64(42)
|
||||||
|
userA := mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "a1@example.com",
|
||||||
|
AllowedGroups: pq.Int64Array{groupID, 7},
|
||||||
|
})
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "a2@example.com",
|
||||||
|
AllowedGroups: pq.Int64Array{7},
|
||||||
|
})
|
||||||
|
|
||||||
|
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, groupID)
|
||||||
|
s.Require().NoError(err, "RemoveGroupFromAllowedGroups")
|
||||||
|
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, userA.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
for _, id := range got.AllowedGroups {
|
||||||
|
s.Require().NotEqual(groupID, id, "expected groupID to be removed from allowed_groups")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "nomatch@test.com",
|
||||||
|
AllowedGroups: pq.Int64Array{1, 2, 3},
|
||||||
|
})
|
||||||
|
|
||||||
|
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Zero(affected, "expected no affected rows")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetFirstAdmin ---
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestGetFirstAdmin() {
|
||||||
|
admin1 := mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "admin1@example.com",
|
||||||
|
Role: model.RoleAdmin,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
})
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "admin2@example.com",
|
||||||
|
Role: model.RoleAdmin,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := s.repo.GetFirstAdmin(s.ctx)
|
||||||
|
s.Require().NoError(err, "GetFirstAdmin")
|
||||||
|
s.Require().Equal(admin1.ID, got.ID, "GetFirstAdmin mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "user@example.com",
|
||||||
|
Role: model.RoleUser,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := s.repo.GetFirstAdmin(s.ctx)
|
||||||
|
s.Require().Error(err, "expected error when no admin exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
|
||||||
|
mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "disabled@example.com",
|
||||||
|
Role: model.RoleAdmin,
|
||||||
|
Status: model.StatusDisabled,
|
||||||
|
})
|
||||||
|
activeAdmin := mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "active@example.com",
|
||||||
|
Role: model.RoleAdmin,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := s.repo.GetFirstAdmin(s.ctx)
|
||||||
|
s.Require().NoError(err, "GetFirstAdmin")
|
||||||
|
s.Require().Equal(activeAdmin.ID, got.ID, "should return only active admin")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Combined original test ---
|
||||||
|
|
||||||
|
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
||||||
|
user1 := mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "a@example.com",
|
||||||
|
Username: "Alice",
|
||||||
|
Wechat: "wx_a",
|
||||||
|
Role: model.RoleUser,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
Balance: 10,
|
||||||
|
})
|
||||||
|
user2 := mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "b@example.com",
|
||||||
|
Username: "Bob",
|
||||||
|
Wechat: "wx_b",
|
||||||
|
Role: model.RoleAdmin,
|
||||||
|
Status: model.StatusActive,
|
||||||
|
Balance: 1,
|
||||||
|
})
|
||||||
|
_ = mustCreateUser(s.T(), s.db, &model.User{
|
||||||
|
Email: "c@example.com",
|
||||||
|
Role: model.RoleAdmin,
|
||||||
|
Status: model.StatusDisabled,
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal(user1.Email, got.Email, "GetByID email mismatch")
|
||||||
|
|
||||||
|
gotByEmail, err := s.repo.GetByEmail(s.ctx, user2.Email)
|
||||||
|
s.Require().NoError(err, "GetByEmail")
|
||||||
|
s.Require().Equal(user2.ID, gotByEmail.ID, "GetByEmail ID mismatch")
|
||||||
|
|
||||||
|
got.Username = "Alice2"
|
||||||
|
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
|
||||||
|
got2, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||||
|
s.Require().NoError(err, "GetByID after update")
|
||||||
|
s.Require().Equal("Alice2", got2.Username, "Update did not persist")
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.UpdateBalance(s.ctx, user1.ID, 2.5), "UpdateBalance")
|
||||||
|
got3, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||||
|
s.Require().NoError(err, "GetByID after UpdateBalance")
|
||||||
|
s.Require().Equal(12.5, got3.Balance, "UpdateBalance mismatch")
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.DeductBalance(s.ctx, user1.ID, 5), "DeductBalance")
|
||||||
|
got4, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||||
|
s.Require().NoError(err, "GetByID after DeductBalance")
|
||||||
|
s.Require().Equal(7.5, got4.Balance, "DeductBalance mismatch")
|
||||||
|
|
||||||
|
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
|
||||||
|
s.Require().Error(err, "DeductBalance expected error for insufficient balance")
|
||||||
|
s.Require().ErrorIs(err, gorm.ErrRecordNotFound, "DeductBalance unexpected error")
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
|
||||||
|
got5, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||||
|
s.Require().NoError(err, "GetByID after UpdateConcurrency")
|
||||||
|
s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch")
|
||||||
|
|
||||||
|
params := pagination.PaginationParams{Page: 1, PageSize: 10}
|
||||||
|
users, page, err := s.repo.ListWithFilters(s.ctx, params, model.StatusActive, model.RoleAdmin, "b@")
|
||||||
|
s.Require().NoError(err, "ListWithFilters")
|
||||||
|
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
||||||
|
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
||||||
|
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
|
||||||
|
}
|
||||||
@@ -0,0 +1,733 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserSubscriptionRepoSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
ctx context.Context
|
||||||
|
db *gorm.DB
|
||||||
|
repo *UserSubscriptionRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) SetupTest() {
|
||||||
|
s.ctx = context.Background()
|
||||||
|
s.db = testTx(s.T())
|
||||||
|
s.repo = NewUserSubscriptionRepository(s.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserSubscriptionRepoSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(UserSubscriptionRepoSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Create / GetByID / Update / Delete ---
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestCreate() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub-create@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-create"})
|
||||||
|
|
||||||
|
sub := &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.repo.Create(s.ctx, sub)
|
||||||
|
s.Require().NoError(err, "Create")
|
||||||
|
s.Require().NotZero(sub.ID, "expected ID to be set")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal(sub.UserID, got.UserID)
|
||||||
|
s.Require().Equal(sub.GroupID, got.GroupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "preload@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
|
||||||
|
admin := mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
|
||||||
|
|
||||||
|
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
AssignedBy: &admin.ID,
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().NotNil(got.User, "expected User preload")
|
||||||
|
s.Require().NotNil(got.Group, "expected Group preload")
|
||||||
|
s.Require().NotNil(got.AssignedByUser, "expected AssignedByUser preload")
|
||||||
|
s.Require().Equal(user.ID, got.User.ID)
|
||||||
|
s.Require().Equal(group.ID, got.Group.ID)
|
||||||
|
s.Require().Equal(admin.ID, got.AssignedByUser.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
|
||||||
|
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||||
|
s.Require().Error(err, "expected error for non-existent ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestUpdate() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-update"})
|
||||||
|
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
sub.Notes = "updated notes"
|
||||||
|
err := s.repo.Update(s.ctx, sub)
|
||||||
|
s.Require().NoError(err, "Update")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||||
|
s.Require().NoError(err, "GetByID after update")
|
||||||
|
s.Require().Equal("updated notes", got.Notes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestDelete() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delete"})
|
||||||
|
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.repo.Delete(s.ctx, sub.ID)
|
||||||
|
s.Require().NoError(err, "Delete")
|
||||||
|
|
||||||
|
_, err = s.repo.GetByID(s.ctx, sub.ID)
|
||||||
|
s.Require().Error(err, "expected error after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byuser@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-byuser"})
|
||||||
|
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := s.repo.GetByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||||
|
s.Require().NoError(err, "GetByUserIDAndGroupID")
|
||||||
|
s.Require().Equal(sub.ID, got.ID)
|
||||||
|
s.Require().NotNil(got.Group, "expected Group preload")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
|
||||||
|
_, err := s.repo.GetByUserIDAndGroupID(s.ctx, 999999, 999999)
|
||||||
|
s.Require().Error(err, "expected error for non-existent pair")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-active"})
|
||||||
|
|
||||||
|
// Create active subscription (future expiry)
|
||||||
|
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(2 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||||
|
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
|
||||||
|
s.Require().Equal(active.ID, got.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "expired@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-expired"})
|
||||||
|
|
||||||
|
// Create expired subscription (past expiry but active status)
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(-2 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||||
|
s.Require().Error(err, "expected error for expired subscription")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListByUserID / ListActiveByUserID ---
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestListByUserID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
|
||||||
|
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list1"})
|
||||||
|
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list2"})
|
||||||
|
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: g1.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: g2.ID,
|
||||||
|
Status: model.SubscriptionStatusExpired,
|
||||||
|
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
subs, err := s.repo.ListByUserID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err, "ListByUserID")
|
||||||
|
s.Require().Len(subs, 2)
|
||||||
|
for _, sub := range subs {
|
||||||
|
s.Require().NotNil(sub.Group, "expected Group preload")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listactive@test.com"})
|
||||||
|
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act1"})
|
||||||
|
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act2"})
|
||||||
|
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: g1.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: g2.ID,
|
||||||
|
Status: model.SubscriptionStatusExpired,
|
||||||
|
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
|
||||||
|
s.Require().NoError(err, "ListActiveByUserID")
|
||||||
|
s.Require().Len(subs, 1)
|
||||||
|
s.Require().Equal(model.SubscriptionStatusActive, subs[0].Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListByGroupID ---
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
|
||||||
|
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com"})
|
||||||
|
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listgrp"})
|
||||||
|
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user1.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user2.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
subs, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||||
|
s.Require().NoError(err, "ListByGroupID")
|
||||||
|
s.Require().Len(subs, 2)
|
||||||
|
s.Require().Equal(int64(2), page.Total)
|
||||||
|
for _, sub := range subs {
|
||||||
|
s.Require().NotNil(sub.User, "expected User preload")
|
||||||
|
s.Require().NotNil(sub.Group, "expected Group preload")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- List with filters ---
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "list@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
|
||||||
|
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
|
||||||
|
s.Require().NoError(err, "List")
|
||||||
|
s.Require().Len(subs, 1)
|
||||||
|
s.Require().Equal(int64(1), page.Total)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
||||||
|
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter1@test.com"})
|
||||||
|
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter2@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-filter"})
|
||||||
|
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user1.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user2.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(subs, 1)
|
||||||
|
s.Require().Equal(user1.ID, subs[0].UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grpfilter@test.com"})
|
||||||
|
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f1"})
|
||||||
|
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f2"})
|
||||||
|
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: g1.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: g2.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(subs, 1)
|
||||||
|
s.Require().Equal(g1.ID, subs[0].GroupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "statfilter@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-stat"})
|
||||||
|
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusExpired,
|
||||||
|
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, model.SubscriptionStatusExpired)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Len(subs, 1)
|
||||||
|
s.Require().Equal(model.SubscriptionStatusExpired, subs[0].Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Usage tracking ---
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usage@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-usage"})
|
||||||
|
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1.25)
|
||||||
|
s.Require().NoError(err, "IncrementUsage")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(1.25, got.DailyUsageUSD)
|
||||||
|
s.Require().Equal(1.25, got.WeeklyUsageUSD)
|
||||||
|
s.Require().Equal(1.25, got.MonthlyUsageUSD)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accum@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-accum"})
|
||||||
|
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 1.0))
|
||||||
|
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 2.5))
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(3.5, got.DailyUsageUSD)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "activate@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-activate"})
|
||||||
|
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
activateAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||||
|
err := s.repo.ActivateWindows(s.ctx, sub.ID, activateAt)
|
||||||
|
s.Require().NoError(err, "ActivateWindows")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(got.DailyWindowStart)
|
||||||
|
s.Require().NotNil(got.WeeklyWindowStart)
|
||||||
|
s.Require().NotNil(got.MonthlyWindowStart)
|
||||||
|
s.Require().True(got.DailyWindowStart.Equal(activateAt))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetd@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetd"})
|
||||||
|
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
DailyUsageUSD: 10.0,
|
||||||
|
WeeklyUsageUSD: 20.0,
|
||||||
|
})
|
||||||
|
|
||||||
|
resetAt := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)
|
||||||
|
err := s.repo.ResetDailyUsage(s.ctx, sub.ID, resetAt)
|
||||||
|
s.Require().NoError(err, "ResetDailyUsage")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Zero(got.DailyUsageUSD)
|
||||||
|
s.Require().Equal(20.0, got.WeeklyUsageUSD, "weekly should remain unchanged")
|
||||||
|
s.Require().True(got.DailyWindowStart.Equal(resetAt))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetw@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetw"})
|
||||||
|
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
WeeklyUsageUSD: 15.0,
|
||||||
|
MonthlyUsageUSD: 30.0,
|
||||||
|
})
|
||||||
|
|
||||||
|
resetAt := time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)
|
||||||
|
err := s.repo.ResetWeeklyUsage(s.ctx, sub.ID, resetAt)
|
||||||
|
s.Require().NoError(err, "ResetWeeklyUsage")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Zero(got.WeeklyUsageUSD)
|
||||||
|
s.Require().Equal(30.0, got.MonthlyUsageUSD, "monthly should remain unchanged")
|
||||||
|
s.Require().True(got.WeeklyWindowStart.Equal(resetAt))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetm@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetm"})
|
||||||
|
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
MonthlyUsageUSD: 100.0,
|
||||||
|
})
|
||||||
|
|
||||||
|
resetAt := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
err := s.repo.ResetMonthlyUsage(s.ctx, sub.ID, resetAt)
|
||||||
|
s.Require().NoError(err, "ResetMonthlyUsage")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Zero(got.MonthlyUsageUSD)
|
||||||
|
s.Require().True(got.MonthlyWindowStart.Equal(resetAt))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- UpdateStatus / ExtendExpiry / UpdateNotes ---
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "status@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-status"})
|
||||||
|
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.repo.UpdateStatus(s.ctx, sub.ID, model.SubscriptionStatusExpired)
|
||||||
|
s.Require().NoError(err, "UpdateStatus")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal(model.SubscriptionStatusExpired, got.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "extend@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-extend"})
|
||||||
|
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
newExpiry := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
err := s.repo.ExtendExpiry(s.ctx, sub.ID, newExpiry)
|
||||||
|
s.Require().NoError(err, "ExtendExpiry")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().True(got.ExpiresAt.Equal(newExpiry))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "notes@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-notes"})
|
||||||
|
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.repo.UpdateNotes(s.ctx, sub.ID, "VIP user")
|
||||||
|
s.Require().NoError(err, "UpdateNotes")
|
||||||
|
|
||||||
|
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().Equal("VIP user", got.Notes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ListExpired / BatchUpdateExpiredStatus ---
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestListExpired() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listexp@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listexp"})
|
||||||
|
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
expired, err := s.repo.ListExpired(s.ctx)
|
||||||
|
s.Require().NoError(err, "ListExpired")
|
||||||
|
s.Require().Len(expired, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batch@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-batch"})
|
||||||
|
|
||||||
|
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
|
||||||
|
s.Require().NoError(err, "BatchUpdateExpiredStatus")
|
||||||
|
s.Require().Equal(int64(1), affected)
|
||||||
|
|
||||||
|
gotActive, _ := s.repo.GetByID(s.ctx, active.ID)
|
||||||
|
s.Require().Equal(model.SubscriptionStatusActive, gotActive.Status)
|
||||||
|
|
||||||
|
gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID)
|
||||||
|
s.Require().Equal(model.SubscriptionStatusExpired, gotExpired.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ExistsByUserIDAndGroupID ---
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-exists"})
|
||||||
|
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
exists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||||
|
s.Require().NoError(err, "ExistsByUserIDAndGroupID")
|
||||||
|
s.Require().True(exists)
|
||||||
|
|
||||||
|
notExists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, 999999)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().False(notExists)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- CountByGroupID / CountActiveByGroupID ---
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
|
||||||
|
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt1@test.com"})
|
||||||
|
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt2@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
|
||||||
|
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user1.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user2.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusExpired,
|
||||||
|
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "CountByGroupID")
|
||||||
|
s.Require().Equal(int64(2), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
|
||||||
|
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact1@test.com"})
|
||||||
|
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact2@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-cntact"})
|
||||||
|
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user1.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user2.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time
|
||||||
|
})
|
||||||
|
|
||||||
|
count, err := s.repo.CountActiveByGroupID(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "CountActiveByGroupID")
|
||||||
|
s.Require().Equal(int64(1), count, "only future expiry counts as active")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- DeleteByGroupID ---
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delgrp@test.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delgrp"})
|
||||||
|
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||||
|
})
|
||||||
|
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusExpired,
|
||||||
|
ExpiresAt: time.Now().Add(-24 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
affected, err := s.repo.DeleteByGroupID(s.ctx, group.ID)
|
||||||
|
s.Require().NoError(err, "DeleteByGroupID")
|
||||||
|
s.Require().Equal(int64(2), affected)
|
||||||
|
|
||||||
|
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||||
|
s.Require().Zero(count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Combined original test ---
|
||||||
|
|
||||||
|
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
|
||||||
|
user := mustCreateUser(s.T(), s.db, &model.User{Email: "subr@example.com"})
|
||||||
|
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-subr"})
|
||||||
|
|
||||||
|
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(2 * time.Hour),
|
||||||
|
})
|
||||||
|
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
Status: model.SubscriptionStatusActive,
|
||||||
|
ExpiresAt: time.Now().Add(-2 * time.Hour),
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||||
|
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
|
||||||
|
s.Require().Equal(active.ID, got.ID, "expected active subscription")
|
||||||
|
|
||||||
|
activateAt := time.Now().Add(-25 * time.Hour)
|
||||||
|
s.Require().NoError(s.repo.ActivateWindows(s.ctx, active.ID, activateAt), "ActivateWindows")
|
||||||
|
s.Require().NoError(s.repo.IncrementUsage(s.ctx, active.ID, 1.25), "IncrementUsage")
|
||||||
|
|
||||||
|
after, err := s.repo.GetByID(s.ctx, active.ID)
|
||||||
|
s.Require().NoError(err, "GetByID")
|
||||||
|
s.Require().Equal(1.25, after.DailyUsageUSD, "DailyUsageUSD mismatch")
|
||||||
|
s.Require().Equal(1.25, after.WeeklyUsageUSD, "WeeklyUsageUSD mismatch")
|
||||||
|
s.Require().Equal(1.25, after.MonthlyUsageUSD, "MonthlyUsageUSD mismatch")
|
||||||
|
s.Require().NotNil(after.DailyWindowStart, "expected DailyWindowStart activated")
|
||||||
|
s.Require().NotNil(after.WeeklyWindowStart, "expected WeeklyWindowStart activated")
|
||||||
|
s.Require().NotNil(after.MonthlyWindowStart, "expected MonthlyWindowStart activated")
|
||||||
|
|
||||||
|
resetAt := time.Now().Truncate(time.Microsecond) // truncate to microsecond for DB precision
|
||||||
|
s.Require().NoError(s.repo.ResetDailyUsage(s.ctx, active.ID, resetAt), "ResetDailyUsage")
|
||||||
|
afterReset, err := s.repo.GetByID(s.ctx, active.ID)
|
||||||
|
s.Require().NoError(err, "GetByID after reset")
|
||||||
|
s.Require().Equal(0.0, afterReset.DailyUsageUSD, "expected daily usage reset to 0")
|
||||||
|
s.Require().NotNil(afterReset.DailyWindowStart, "expected DailyWindowStart not nil")
|
||||||
|
s.Require().True(afterReset.DailyWindowStart.Equal(resetAt), "expected daily window start updated")
|
||||||
|
|
||||||
|
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
|
||||||
|
s.Require().NoError(err, "BatchUpdateExpiredStatus")
|
||||||
|
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||||
|
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
|
||||||
|
s.Require().NoError(err, "GetByID expired")
|
||||||
|
s.Require().Equal(model.SubscriptionStatusExpired, updated.Status, "expected status expired")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user