diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index bee7db76..5e77f46e 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -17,9 +17,12 @@ jobs: go-version-file: backend/go.mod check-latest: true cache: true - - name: Run tests + - name: Unit tests working-directory: backend - run: go test ./... + run: make test-unit + - name: Integration tests + working-directory: backend + run: make test-integration golangci-lint: runs-on: ubuntu-latest diff --git a/backend/Makefile b/backend/Makefile index e59acc78..291e2fe9 100644 --- a/backend/Makefile +++ b/backend/Makefile @@ -1,4 +1,4 @@ -.PHONY: wire build build-embed +.PHONY: wire build build-embed test-unit test-integration test-cover-integration clean-coverage wire: @echo "生成 Wire 代码..." @@ -13,4 +13,21 @@ build: build-embed: @echo "构建后端(嵌入前端)..." @go build -tags embed -o bin/server ./cmd/server - @echo "构建完成: bin/server (with embedded frontend)" \ No newline at end of file + @echo "构建完成: bin/server (with embedded frontend)" + +test-unit: + @go test ./... $(TEST_ARGS) + +test-integration: + @go test -tags integration ./internal/repository -count=1 -race -parallel=8 + +test-cover-integration: + @echo "运行集成测试并生成覆盖率报告..." + @go test -tags=integration -cover -coverprofile=coverage.out -count=1 -race -parallel=8 ./internal/repository/... + @go tool cover -func=coverage.out | tail -1 + @go tool cover -html=coverage.out -o coverage.html + @echo "覆盖率报告已生成: coverage.html" + +clean-coverage: + @rm -f coverage.out coverage.html + @echo "覆盖率文件已清理" \ No newline at end of file diff --git a/backend/go.mod b/backend/go.mod index 1eee0a24..faf196b7 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -11,8 +11,11 @@ require ( github.com/google/wire v0.7.0 github.com/imroc/req/v3 v3.56.0 github.com/lib/pq v1.10.9 - github.com/redis/go-redis/v9 v9.3.0 + github.com/redis/go-redis/v9 v9.7.3 github.com/spf13/viper v1.18.2 + github.com/stretchr/testify v1.11.1 + github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 + github.com/testcontainers/testcontainers-go/modules/redis v0.40.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 golang.org/x/crypto v0.44.0 @@ -24,52 +27,99 @@ require ( ) require ( + dario.cat/mergo v1.0.2 // indirect + github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/bytedance/sonic v1.9.1 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/containerd/errdefs v1.0.0 // indirect + github.com/containerd/errdefs/pkg v0.3.0 // indirect + github.com/containerd/log v0.1.0 // indirect + github.com/containerd/platforms v0.2.1 // indirect + github.com/cpuguy83/dockercfg v0.3.2 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/distribution/reference v0.6.0 // indirect + github.com/docker/docker v28.5.1+incompatible // indirect + github.com/docker/go-connections v0.6.0 // indirect + github.com/docker/go-units v0.5.0 // indirect + github.com/ebitengine/purego v0.8.4 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/subcommands v1.2.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/icholy/digest v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.4.3 // indirect + github.com/jackc/pgx/v5 v5.5.4 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.1 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect - github.com/magiconair/properties v1.8.7 // indirect + github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect + github.com/magiconair/properties v1.8.10 // indirect github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mdelapenya/tlscert v0.2.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/moby/go-archive v0.1.0 // indirect + github.com/moby/patternmatcher v0.6.0 // indirect + github.com/moby/sys/sequential v0.6.0 // indirect + github.com/moby/sys/user v0.4.0 // indirect + github.com/moby/sys/userns v0.1.0 // indirect + github.com/moby/term v0.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/morikuni/aec v1.0.0 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.1.1 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/quic-go v0.56.0 // indirect github.com/refraction-networking/utls v1.8.1 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/shirou/gopsutil/v4 v4.25.6 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/testcontainers/testcontainers-go v0.40.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect + github.com/tklauser/go-sysconf v0.3.12 // indirect + github.com/tklauser/numcpus v0.6.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect + github.com/yusufpapurcu/wmi v1.2.4 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect + go.opentelemetry.io/otel v1.37.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/sdk v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.3.0 // indirect @@ -79,6 +129,7 @@ require ( golang.org/x/sys v0.38.0 // indirect golang.org/x/text v0.31.0 // indirect golang.org/x/tools v0.38.0 // indirect - google.golang.org/protobuf v1.31.0 // indirect + google.golang.org/grpc v1.75.1 // indirect + google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index efbb96d0..ac083b5d 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -1,3 +1,11 @@ +dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= +dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= +github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= +github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= +github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= +github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -7,17 +15,43 @@ github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0 github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= +github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= +github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= +github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk= +github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= +github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= +github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A= +github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw= +github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA= +github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc= +github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= +github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= +github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM= +github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= +github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw= +github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= @@ -28,6 +62,13 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -40,9 +81,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= @@ -54,6 +94,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= @@ -64,8 +106,10 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= -github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8= +github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= @@ -85,36 +129,70 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= -github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= +github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= +github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= +github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= +github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= +github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ= +github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo= +github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk= +github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc= +github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw= +github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs= +github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU= +github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko= +github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs= +github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs= +github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g= +github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28= +github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= +github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= +github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= +github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= github.com/quic-go/quic-go v0.56.0 h1:q/TW+OLismmXAehgFLczhCDTYB3bFmua4D9lsNBWxvY= github.com/quic-go/quic-go v0.56.0/go.mod h1:9gx5KsFQtw2oZ6GZTyh+7YEvOxWCL9WZAepnHxgAo6c= -github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0= -github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= +github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= +github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo= github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= -github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs= +github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= @@ -128,6 +206,8 @@ github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMV github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -135,10 +215,16 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU= +github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY= +github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk= +github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0/go.mod h1:h+u/2KoREGTnTl9UwrQ/g+XhasAT8E6dClclAADeXoQ= +github.com/testcontainers/testcontainers-go/modules/redis v0.40.0 h1:OG4qwcxp2O0re7V7M9lY9w0v6wWgWf7j7rtkpAnGMd0= +github.com/testcontainers/testcontainers-go/modules/redis v0.40.0/go.mod h1:Bc+EDhKMo5zI5V5zdBkHiMVzeAXbtI4n5isS/nzf6zw= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -148,12 +234,36 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= +github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= +github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= +github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= +github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= +go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= +go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= @@ -173,8 +283,14 @@ golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= @@ -186,9 +302,15 @@ golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= +google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= +google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 h1:i8QOKZfYg6AbGVZzUAY3LrNWCKF8O6zFisU9Wl9RER4= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4/go.mod h1:HSkG/KdJWusxU1F6CNrwNDjBMgisKxGnc5dAZfT0mjQ= +google.golang.org/grpc v1.75.1 h1:/ODCNEuf9VghjgO3rqLcfg8fiOP0nSluljWFlDxELLI= +google.golang.org/grpc v1.75.1/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= @@ -201,6 +323,6 @@ gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= -gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= +gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= +gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go new file mode 100644 index 00000000..ca6decb8 --- /dev/null +++ b/backend/internal/repository/account_repo_integration_test.go @@ -0,0 +1,580 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service/ports" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +type AccountRepoSuite struct { + suite.Suite + ctx context.Context + db *gorm.DB + repo *AccountRepository +} + +func (s *AccountRepoSuite) SetupTest() { + s.ctx = context.Background() + s.db = testTx(s.T()) + s.repo = NewAccountRepository(s.db) +} + +func TestAccountRepoSuite(t *testing.T) { + suite.Run(t, new(AccountRepoSuite)) +} + +// --- Create / GetByID / Update / Delete --- + +func (s *AccountRepoSuite) TestCreate() { + account := &model.Account{ + Name: "test-create", + Platform: model.PlatformAnthropic, + Type: model.AccountTypeOAuth, + Status: model.StatusActive, + } + + err := s.repo.Create(s.ctx, account) + s.Require().NoError(err, "Create") + s.Require().NotZero(account.ID, "expected ID to be set") + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal("test-create", got.Name) +} + +func (s *AccountRepoSuite) TestGetByID_NotFound() { + _, err := s.repo.GetByID(s.ctx, 999999) + s.Require().Error(err, "expected error for non-existent ID") +} + +func (s *AccountRepoSuite) TestUpdate() { + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "original"}) + + account.Name = "updated" + err := s.repo.Update(s.ctx, account) + s.Require().NoError(err, "Update") + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err, "GetByID after update") + s.Require().Equal("updated", got.Name) +} + +func (s *AccountRepoSuite) TestDelete() { + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "to-delete"}) + + err := s.repo.Delete(s.ctx, account.ID) + s.Require().NoError(err, "Delete") + + _, err = s.repo.GetByID(s.ctx, account.ID) + s.Require().Error(err, "expected error after delete") +} + +func (s *AccountRepoSuite) TestDelete_WithGroupBindings() { + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"}) + mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1) + + err := s.repo.Delete(s.ctx, account.ID) + s.Require().NoError(err, "Delete should cascade remove bindings") + + var count int64 + s.db.Model(&model.AccountGroup{}).Where("account_id = ?", account.ID).Count(&count) + s.Require().Zero(count, "expected bindings to be removed") +} + +// --- List / ListWithFilters --- + +func (s *AccountRepoSuite) TestList() { + mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc2"}) + + accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "List") + s.Require().Len(accounts, 2) + s.Require().Equal(int64(2), page.Total) +} + +func (s *AccountRepoSuite) TestListWithFilters() { + tests := []struct { + name string + setup func(db *gorm.DB) + platform string + accType string + status string + search string + wantCount int + validate func(accounts []model.Account) + }{ + { + name: "filter_by_platform", + setup: func(db *gorm.DB) { + mustCreateAccount(s.T(), db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic}) + mustCreateAccount(s.T(), db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI}) + }, + platform: model.PlatformOpenAI, + wantCount: 1, + validate: func(accounts []model.Account) { + s.Require().Equal(model.PlatformOpenAI, accounts[0].Platform) + }, + }, + { + name: "filter_by_type", + setup: func(db *gorm.DB) { + mustCreateAccount(s.T(), db, &model.Account{Name: "t1", Type: model.AccountTypeOAuth}) + mustCreateAccount(s.T(), db, &model.Account{Name: "t2", Type: model.AccountTypeApiKey}) + }, + accType: model.AccountTypeApiKey, + wantCount: 1, + validate: func(accounts []model.Account) { + s.Require().Equal(model.AccountTypeApiKey, accounts[0].Type) + }, + }, + { + name: "filter_by_status", + setup: func(db *gorm.DB) { + mustCreateAccount(s.T(), db, &model.Account{Name: "s1", Status: model.StatusActive}) + mustCreateAccount(s.T(), db, &model.Account{Name: "s2", Status: model.StatusDisabled}) + }, + status: model.StatusDisabled, + wantCount: 1, + validate: func(accounts []model.Account) { + s.Require().Equal(model.StatusDisabled, accounts[0].Status) + }, + }, + { + name: "filter_by_search", + setup: func(db *gorm.DB) { + mustCreateAccount(s.T(), db, &model.Account{Name: "alpha-account"}) + mustCreateAccount(s.T(), db, &model.Account{Name: "beta-account"}) + }, + search: "alpha", + wantCount: 1, + validate: func(accounts []model.Account) { + s.Require().Contains(accounts[0].Name, "alpha") + }, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + // 每个 case 重新获取隔离资源 + db := testTx(s.T()) + repo := NewAccountRepository(db) + ctx := context.Background() + + tt.setup(db) + + accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search) + s.Require().NoError(err) + s.Require().Len(accounts, tt.wantCount) + if tt.validate != nil { + tt.validate(accounts) + } + }) + } +} + +// --- ListByGroup / ListActive / ListByPlatform --- + +func (s *AccountRepoSuite) TestListByGroup() { + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) + acc1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Status: model.StatusActive}) + acc2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Status: model.StatusActive}) + mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2) + mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1) + + accounts, err := s.repo.ListByGroup(s.ctx, group.ID) + s.Require().NoError(err, "ListByGroup") + s.Require().Len(accounts, 2) + // Should be ordered by priority + s.Require().Equal(acc2.ID, accounts[0].ID, "expected acc2 first (priority=1)") +} + +func (s *AccountRepoSuite) TestListActive() { + mustCreateAccount(s.T(), s.db, &model.Account{Name: "active1", Status: model.StatusActive}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "inactive1", Status: model.StatusDisabled}) + + accounts, err := s.repo.ListActive(s.ctx) + s.Require().NoError(err, "ListActive") + s.Require().Len(accounts, 1) + s.Require().Equal("active1", accounts[0].Name) +} + +func (s *AccountRepoSuite) TestListByPlatform() { + mustCreateAccount(s.T(), s.db, &model.Account{Name: "p1", Platform: model.PlatformAnthropic, Status: model.StatusActive}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "p2", Platform: model.PlatformOpenAI, Status: model.StatusActive}) + + accounts, err := s.repo.ListByPlatform(s.ctx, model.PlatformAnthropic) + s.Require().NoError(err, "ListByPlatform") + s.Require().Len(accounts, 1) + s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform) +} + +// --- Preload and VirtualFields --- + +func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { + proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"}) + + account := mustCreateAccount(s.T(), s.db, &model.Account{ + Name: "acc1", + ProxyID: &proxy.ID, + }) + mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err, "GetByID") + s.Require().NotNil(got.Proxy, "expected Proxy preload") + s.Require().Equal(proxy.ID, got.Proxy.ID) + s.Require().Len(got.GroupIDs, 1, "expected GroupIDs to be populated") + s.Require().Equal(group.ID, got.GroupIDs[0]) + s.Require().Len(got.Groups, 1, "expected Groups to be populated") + s.Require().Equal(group.ID, got.Groups[0].ID) + + accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc") + s.Require().NoError(err, "ListWithFilters") + s.Require().Equal(int64(1), page.Total) + s.Require().Len(accounts, 1) + s.Require().NotNil(accounts[0].Proxy, "expected Proxy preload in list") + s.Require().Equal(proxy.ID, accounts[0].Proxy.ID) + s.Require().Len(accounts[0].GroupIDs, 1, "expected GroupIDs in list") + s.Require().Equal(group.ID, accounts[0].GroupIDs[0]) +} + +// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups --- + +func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() { + g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"}) + g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc"}) + + s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup") + groups, err := s.repo.GetGroups(s.ctx, account.ID) + s.Require().NoError(err, "GetGroups") + s.Require().Len(groups, 1, "expected 1 group") + s.Require().Equal(g1.ID, groups[0].ID) + + s.Require().NoError(s.repo.RemoveFromGroup(s.ctx, account.ID, g1.ID), "RemoveFromGroup") + groups, err = s.repo.GetGroups(s.ctx, account.ID) + s.Require().NoError(err, "GetGroups after remove") + s.Require().Empty(groups, "expected 0 groups after remove") + + s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{g1.ID, g2.ID}), "BindGroups") + groups, err = s.repo.GetGroups(s.ctx, account.ID) + s.Require().NoError(err, "GetGroups after bind") + s.Require().Len(groups, 2, "expected 2 groups after bind") +} + +func (s *AccountRepoSuite) TestBindGroups_EmptyList() { + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-empty"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"}) + mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1) + + s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty") + + groups, err := s.repo.GetGroups(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().Empty(groups, "expected 0 groups after binding empty list") +} + +// --- Schedulable --- + +func (s *AccountRepoSuite) TestListSchedulable() { + now := time.Now() + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"}) + + okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true}) + mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1) + + future := now.Add(10 * time.Minute) + overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future}) + mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1) + + sched, err := s.repo.ListSchedulable(s.ctx) + s.Require().NoError(err, "ListSchedulable") + ids := idsOfAccounts(sched) + s.Require().Contains(ids, okAcc.ID) + s.Require().NotContains(ids, overloaded.ID) +} + +func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() { + now := time.Now() + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"}) + + okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true}) + mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1) + + future := now.Add(10 * time.Minute) + overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future}) + mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1) + + rateLimited := mustCreateAccount(s.T(), s.db, &model.Account{Name: "rl", Schedulable: true}) + mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1) + s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited") + + s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError") + + sched, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID) + s.Require().NoError(err, "ListSchedulableByGroupID") + s.Require().Len(sched, 1, "expected only ok account schedulable") + s.Require().Equal(okAcc.ID, sched[0].ID) + + s.Require().NoError(s.repo.ClearRateLimit(s.ctx, rateLimited.ID), "ClearRateLimit") + sched2, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID) + s.Require().NoError(err, "ListSchedulableByGroupID after ClearRateLimit") + s.Require().Len(sched2, 2, "expected 2 schedulable accounts after ClearRateLimit") +} + +func (s *AccountRepoSuite) TestListSchedulableByPlatform() { + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true}) + + accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, model.PlatformAnthropic) + s.Require().NoError(err) + s.Require().Len(accounts, 1) + s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform) +} + +func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() { + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sp"}) + a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true}) + a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true}) + mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1) + mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2) + + accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, model.PlatformAnthropic) + s.Require().NoError(err) + s.Require().Len(accounts, 1) + s.Require().Equal(a1.ID, accounts[0].ID) +} + +func (s *AccountRepoSuite) TestSetSchedulable() { + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-sched", Schedulable: true}) + + s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false)) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().False(got.Schedulable) +} + +// --- SetOverloaded / SetRateLimited / ClearRateLimit --- + +func (s *AccountRepoSuite) TestSetOverloaded() { + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-over"}) + until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC) + + s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until)) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().NotNil(got.OverloadUntil) + s.Require().WithinDuration(until, *got.OverloadUntil, time.Second) +} + +func (s *AccountRepoSuite) TestSetRateLimited() { + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-rl"}) + resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC) + + s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt)) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().NotNil(got.RateLimitedAt) + s.Require().NotNil(got.RateLimitResetAt) + s.Require().WithinDuration(resetAt, *got.RateLimitResetAt, time.Second) +} + +func (s *AccountRepoSuite) TestClearRateLimit() { + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-clear"}) + until := time.Now().Add(1 * time.Hour) + s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until)) + s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until)) + + s.Require().NoError(s.repo.ClearRateLimit(s.ctx, account.ID)) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().Nil(got.RateLimitedAt) + s.Require().Nil(got.RateLimitResetAt) + s.Require().Nil(got.OverloadUntil) +} + +// --- UpdateLastUsed --- + +func (s *AccountRepoSuite) TestUpdateLastUsed() { + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-used"}) + s.Require().Nil(account.LastUsedAt) + + s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID)) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().NotNil(got.LastUsedAt) +} + +// --- SetError --- + +func (s *AccountRepoSuite) TestSetError() { + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-err", Status: model.StatusActive}) + + s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong")) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().Equal(model.StatusError, got.Status) + s.Require().Equal("something went wrong", got.ErrorMessage) +} + +// --- UpdateSessionWindow --- + +func (s *AccountRepoSuite) TestUpdateSessionWindow() { + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-win"}) + start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC) + end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC) + + s.Require().NoError(s.repo.UpdateSessionWindow(s.ctx, account.ID, &start, &end, "active")) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().NotNil(got.SessionWindowStart) + s.Require().NotNil(got.SessionWindowEnd) + s.Require().Equal("active", got.SessionWindowStatus) +} + +// --- UpdateExtra --- + +func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() { + account := mustCreateAccount(s.T(), s.db, &model.Account{ + Name: "acc-extra", + Extra: model.JSONB{"a": "1"}, + }) + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra") + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal("1", got.Extra["a"]) + s.Require().Equal("2", got.Extra["b"]) +} + +func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() { + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-extra-empty"}) + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{})) +} + +func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-nil-extra", Extra: nil}) + s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"})) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().Equal("val", got.Extra["key"]) +} + +// --- GetByCRSAccountID --- + +func (s *AccountRepoSuite) TestGetByCRSAccountID() { + crsID := "crs-12345" + mustCreateAccount(s.T(), s.db, &model.Account{ + Name: "acc-crs", + Extra: model.JSONB{"crs_account_id": crsID}, + }) + + got, err := s.repo.GetByCRSAccountID(s.ctx, crsID) + s.Require().NoError(err) + s.Require().NotNil(got) + s.Require().Equal("acc-crs", got.Name) +} + +func (s *AccountRepoSuite) TestGetByCRSAccountID_NotFound() { + got, err := s.repo.GetByCRSAccountID(s.ctx, "non-existent") + s.Require().NoError(err) + s.Require().Nil(got) +} + +func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() { + got, err := s.repo.GetByCRSAccountID(s.ctx, "") + s.Require().NoError(err) + s.Require().Nil(got) +} + +// --- BulkUpdate --- + +func (s *AccountRepoSuite) TestBulkUpdate() { + a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk1", Priority: 1}) + a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1}) + + newPriority := 99 + affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, ports.AccountBulkUpdate{ + Priority: &newPriority, + }) + s.Require().NoError(err) + s.Require().GreaterOrEqual(affected, int64(1), "expected at least one affected row") + + got1, _ := s.repo.GetByID(s.ctx, a1.ID) + got2, _ := s.repo.GetByID(s.ctx, a2.ID) + s.Require().Equal(99, got1.Priority) + s.Require().Equal(99, got2.Priority) +} + +func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() { + a1 := mustCreateAccount(s.T(), s.db, &model.Account{ + Name: "bulk-cred", + Credentials: model.JSONB{"existing": "value"}, + }) + + _, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{ + Credentials: model.JSONB{"new_key": "new_value"}, + }) + s.Require().NoError(err) + + got, _ := s.repo.GetByID(s.ctx, a1.ID) + s.Require().Equal("value", got.Credentials["existing"]) + s.Require().Equal("new_value", got.Credentials["new_key"]) +} + +func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() { + a1 := mustCreateAccount(s.T(), s.db, &model.Account{ + Name: "bulk-extra", + Extra: model.JSONB{"existing": "val"}, + }) + + _, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{ + Extra: model.JSONB{"new_key": "new_val"}, + }) + s.Require().NoError(err) + + got, _ := s.repo.GetByID(s.ctx, a1.ID) + s.Require().Equal("val", got.Extra["existing"]) + s.Require().Equal("new_val", got.Extra["new_key"]) +} + +func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() { + affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, ports.AccountBulkUpdate{}) + s.Require().NoError(err) + s.Require().Zero(affected) +} + +func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() { + a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"}) + + affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{}) + s.Require().NoError(err) + s.Require().Zero(affected) +} + +func idsOfAccounts(accounts []model.Account) []int64 { + out := make([]int64, 0, len(accounts)) + for i := range accounts { + out = append(out, accounts[i].ID) + } + return out +} diff --git a/backend/internal/repository/api_key_cache_integration_test.go b/backend/internal/repository/api_key_cache_integration_test.go new file mode 100644 index 00000000..6fcd0dfd --- /dev/null +++ b/backend/internal/repository/api_key_cache_integration_test.go @@ -0,0 +1,125 @@ +//go:build integration + +package repository + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ApiKeyCacheSuite struct { + IntegrationRedisSuite +} + +func (s *ApiKeyCacheSuite) TestCreateAttemptCount() { + tests := []struct { + name string + fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) + }{ + { + name: "missing_key_returns_redis_nil", + fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) { + userID := int64(1) + + _, err := cache.GetCreateAttemptCount(ctx, userID) + + require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing key") + }, + }, + { + name: "increment_increases_count_and_sets_ttl", + fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) { + userID := int64(1) + key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) + + require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount") + require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount 2") + + count, err := cache.GetCreateAttemptCount(ctx, userID) + require.NoError(s.T(), err, "GetCreateAttemptCount") + require.Equal(s.T(), 2, count, "count mismatch") + + ttl, err := rdb.TTL(ctx, key).Result() + require.NoError(s.T(), err, "TTL") + s.AssertTTLWithin(ttl, 1*time.Second, apiKeyRateLimitDuration) + }, + }, + { + name: "delete_removes_key", + fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) { + userID := int64(1) + + require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID)) + require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount") + + _, err := cache.GetCreateAttemptCount(ctx, userID) + require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after delete") + }, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + // 每个 case 重新获取隔离资源 + rdb := testRedis(s.T()) + cache := &apiKeyCache{rdb: rdb} + ctx := context.Background() + + tt.fn(ctx, rdb, cache) + }) + } +} + +func (s *ApiKeyCacheSuite) TestDailyUsage() { + tests := []struct { + name string + fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) + }{ + { + name: "increment_increases_count", + fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) { + dailyKey := "daily:sk-test" + + require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage") + require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage 2") + + n, err := rdb.Get(ctx, dailyKey).Int() + require.NoError(s.T(), err, "Get dailyKey") + require.Equal(s.T(), 2, n, "expected daily usage=2") + }, + }, + { + name: "set_expiry_sets_ttl", + fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) { + dailyKey := "daily:sk-test-expiry" + + require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey)) + require.NoError(s.T(), cache.SetDailyUsageExpiry(ctx, dailyKey, 1*time.Hour), "SetDailyUsageExpiry") + + ttl, err := rdb.TTL(ctx, dailyKey).Result() + require.NoError(s.T(), err, "TTL dailyKey") + require.Greater(s.T(), ttl, time.Duration(0), "expected ttl > 0") + }, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + rdb := testRedis(s.T()) + cache := &apiKeyCache{rdb: rdb} + ctx := context.Background() + + tt.fn(ctx, rdb, cache) + }) + } +} + +func TestApiKeyCacheSuite(t *testing.T) { + suite.Run(t, new(ApiKeyCacheSuite)) +} diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go new file mode 100644 index 00000000..00b332f9 --- /dev/null +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -0,0 +1,355 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +type ApiKeyRepoSuite struct { + suite.Suite + ctx context.Context + db *gorm.DB + repo *ApiKeyRepository +} + +func (s *ApiKeyRepoSuite) SetupTest() { + s.ctx = context.Background() + s.db = testTx(s.T()) + s.repo = NewApiKeyRepository(s.db) +} + +func TestApiKeyRepoSuite(t *testing.T) { + suite.Run(t, new(ApiKeyRepoSuite)) +} + +// --- Create / GetByID / GetByKey --- + +func (s *ApiKeyRepoSuite) TestCreate() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"}) + + key := &model.ApiKey{ + UserID: user.ID, + Key: "sk-create-test", + Name: "Test Key", + Status: model.StatusActive, + } + + err := s.repo.Create(s.ctx, key) + s.Require().NoError(err, "Create") + s.Require().NotZero(key.ID, "expected ID to be set") + + got, err := s.repo.GetByID(s.ctx, key.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal("sk-create-test", got.Key) +} + +func (s *ApiKeyRepoSuite) TestGetByID_NotFound() { + _, err := s.repo.GetByID(s.ctx, 999999) + s.Require().Error(err, "expected error for non-existent ID") +} + +func (s *ApiKeyRepoSuite) TestGetByKey() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbykey@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-key"}) + + key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + UserID: user.ID, + Key: "sk-getbykey", + Name: "My Key", + GroupID: &group.ID, + Status: model.StatusActive, + }) + + got, err := s.repo.GetByKey(s.ctx, key.Key) + s.Require().NoError(err, "GetByKey") + s.Require().Equal(key.ID, got.ID) + s.Require().NotNil(got.User, "expected User preload") + s.Require().Equal(user.ID, got.User.ID) + s.Require().NotNil(got.Group, "expected Group preload") + s.Require().Equal(group.ID, got.Group.ID) +} + +func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() { + _, err := s.repo.GetByKey(s.ctx, "non-existent-key") + s.Require().Error(err, "expected error for non-existent key") +} + +// --- Update --- + +func (s *ApiKeyRepoSuite) TestUpdate() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"}) + key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + UserID: user.ID, + Key: "sk-update", + Name: "Original", + Status: model.StatusActive, + }) + + key.Name = "Renamed" + key.Status = model.StatusDisabled + err := s.repo.Update(s.ctx, key) + s.Require().NoError(err, "Update") + + got, err := s.repo.GetByID(s.ctx, key.ID) + s.Require().NoError(err, "GetByID after update") + s.Require().Equal("sk-update", got.Key, "Update should not change key") + s.Require().Equal(user.ID, got.UserID, "Update should not change user_id") + s.Require().Equal("Renamed", got.Name) + s.Require().Equal(model.StatusDisabled, got.Status) +} + +func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargroup@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear"}) + key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + UserID: user.ID, + Key: "sk-clear-group", + Name: "Group Key", + GroupID: &group.ID, + }) + + key.GroupID = nil + err := s.repo.Update(s.ctx, key) + s.Require().NoError(err, "Update") + + got, err := s.repo.GetByID(s.ctx, key.ID) + s.Require().NoError(err) + s.Require().Nil(got.GroupID, "expected GroupID to be cleared") +} + +// --- Delete --- + +func (s *ApiKeyRepoSuite) TestDelete() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) + key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + UserID: user.ID, + Key: "sk-delete", + Name: "Delete Me", + }) + + err := s.repo.Delete(s.ctx, key.ID) + s.Require().NoError(err, "Delete") + + _, err = s.repo.GetByID(s.ctx, key.ID) + s.Require().Error(err, "expected error after delete") +} + +// --- ListByUserID / CountByUserID --- + +func (s *ApiKeyRepoSuite) TestListByUserID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"}) + + keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "ListByUserID") + s.Require().Len(keys, 2) + s.Require().Equal(int64(2), page.Total) +} + +func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "paging@test.com"}) + for i := 0; i < 5; i++ { + mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + UserID: user.ID, + Key: "sk-page-" + string(rune('a'+i)), + Name: "Key", + }) + } + + keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}) + s.Require().NoError(err) + s.Require().Len(keys, 2) + s.Require().Equal(int64(5), page.Total) + s.Require().Equal(3, page.Pages) +} + +func (s *ApiKeyRepoSuite) TestCountByUserID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "count@test.com"}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-1", Name: "K1"}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-2", Name: "K2"}) + + count, err := s.repo.CountByUserID(s.ctx, user.ID) + s.Require().NoError(err, "CountByUserID") + s.Require().Equal(int64(2), count) +} + +// --- ListByGroupID / CountByGroupID --- + +func (s *ApiKeyRepoSuite) TestListByGroupID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbygroup@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) + + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group + + keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "ListByGroupID") + s.Require().Len(keys, 2) + s.Require().Equal(int64(2), page.Total) + // User preloaded + s.Require().NotNil(keys[0].User) +} + +func (s *ApiKeyRepoSuite) TestCountByGroupID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "countgroup@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) + + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID}) + + count, err := s.repo.CountByGroupID(s.ctx, group.ID) + s.Require().NoError(err, "CountByGroupID") + s.Require().Equal(int64(1), count) +} + +// --- ExistsByKey --- + +func (s *ApiKeyRepoSuite) TestExistsByKey() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-exists", Name: "K"}) + + exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists") + s.Require().NoError(err, "ExistsByKey") + s.Require().True(exists) + + notExists, err := s.repo.ExistsByKey(s.ctx, "sk-not-exists") + s.Require().NoError(err) + s.Require().False(notExists) +} + +// --- SearchApiKeys --- + +func (s *ApiKeyRepoSuite) TestSearchApiKeys() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "search@test.com"}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"}) + + found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10) + s.Require().NoError(err, "SearchApiKeys") + s.Require().Len(found, 1) + s.Require().Contains(found[0].Name, "Production") +} + +func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnokw@test.com"}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-1", Name: "K1"}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-2", Name: "K2"}) + + found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10) + s.Require().NoError(err) + s.Require().Len(found, 2) +} + +func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnouid@test.com"}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"}) + + found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10) + s.Require().NoError(err) + s.Require().Len(found, 1) +} + +// --- ClearGroupIDByGroupID --- + +func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargrp@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear-bulk"}) + + k1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID}) + k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group + + affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID) + s.Require().NoError(err, "ClearGroupIDByGroupID") + s.Require().Equal(int64(2), affected) + + got1, _ := s.repo.GetByID(s.ctx, k1.ID) + got2, _ := s.repo.GetByID(s.ctx, k2.ID) + s.Require().Nil(got1.GroupID) + s.Require().Nil(got2.GroupID) + + count, _ := s.repo.CountByGroupID(s.ctx, group.ID) + s.Require().Zero(count) +} + +// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) --- + +func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "k@example.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-k"}) + + key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + UserID: user.ID, + Key: "sk-test-1", + Name: "My Key", + GroupID: &group.ID, + Status: model.StatusActive, + }) + + got, err := s.repo.GetByKey(s.ctx, key.Key) + s.Require().NoError(err, "GetByKey") + s.Require().Equal(key.ID, got.ID) + s.Require().NotNil(got.User) + s.Require().Equal(user.ID, got.User.ID) + s.Require().NotNil(got.Group) + s.Require().Equal(group.ID, got.Group.ID) + + key.Name = "Renamed" + key.Status = model.StatusDisabled + key.GroupID = nil + s.Require().NoError(s.repo.Update(s.ctx, key), "Update") + + got2, err := s.repo.GetByID(s.ctx, key.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal("sk-test-1", got2.Key, "Update should not change key") + s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id") + s.Require().Equal("Renamed", got2.Name) + s.Require().Equal(model.StatusDisabled, got2.Status) + s.Require().Nil(got2.GroupID) + + keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "ListByUserID") + s.Require().Equal(int64(1), page.Total) + s.Require().Len(keys, 1) + + exists, err := s.repo.ExistsByKey(s.ctx, "sk-test-1") + s.Require().NoError(err, "ExistsByKey") + s.Require().True(exists, "expected key to exist") + + found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10) + s.Require().NoError(err, "SearchApiKeys") + s.Require().Len(found, 1) + s.Require().Equal(key.ID, found[0].ID) + + // ClearGroupIDByGroupID + k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + UserID: user.ID, + Key: "sk-test-2", + Name: "Group Key", + GroupID: &group.ID, + }) + + countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID) + s.Require().NoError(err, "CountByGroupID") + s.Require().Equal(int64(1), countBefore, "expected 1 key in group before clear") + + affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID) + s.Require().NoError(err, "ClearGroupIDByGroupID") + s.Require().Equal(int64(1), affected, "expected 1 affected row") + + got3, err := s.repo.GetByID(s.ctx, k2.ID) + s.Require().NoError(err, "GetByID") + s.Require().Nil(got3.GroupID, "expected GroupID cleared") + + countAfter, err := s.repo.CountByGroupID(s.ctx, group.ID) + s.Require().NoError(err, "CountByGroupID after clear") + s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear") +} diff --git a/backend/internal/repository/billing_cache_integration_test.go b/backend/internal/repository/billing_cache_integration_test.go new file mode 100644 index 00000000..893ae8d7 --- /dev/null +++ b/backend/internal/repository/billing_cache_integration_test.go @@ -0,0 +1,283 @@ +//go:build integration + +package repository + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service/ports" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type BillingCacheSuite struct { + IntegrationRedisSuite +} + +func (s *BillingCacheSuite) TestUserBalance() { + tests := []struct { + name string + fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) + }{ + { + name: "missing_key_returns_redis_nil", + fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { + _, err := cache.GetUserBalance(ctx, 1) + require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key") + }, + }, + { + name: "deduct_on_nonexistent_is_noop", + fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { + userID := int64(1) + balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) + + require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error") + + _, err := rdb.Get(ctx, balanceKey).Result() + require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent") + }, + }, + { + name: "set_and_get_with_ttl", + fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { + userID := int64(2) + balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) + + require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance") + + got, err := cache.GetUserBalance(ctx, userID) + require.NoError(s.T(), err, "GetUserBalance") + require.Equal(s.T(), 10.5, got, "balance mismatch") + + ttl, err := rdb.TTL(ctx, balanceKey).Result() + require.NoError(s.T(), err, "TTL") + s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL) + }, + }, + { + name: "deduct_reduces_balance", + fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { + userID := int64(3) + + require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance") + require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance") + + got, err := cache.GetUserBalance(ctx, userID) + require.NoError(s.T(), err, "GetUserBalance after deduct") + require.Equal(s.T(), 8.25, got, "deduct mismatch") + }, + }, + { + name: "invalidate_removes_key", + fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { + userID := int64(100) + balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) + + require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance") + + exists, err := rdb.Exists(ctx, balanceKey).Result() + require.NoError(s.T(), err, "Exists") + require.Equal(s.T(), int64(1), exists, "expected balance key to exist") + + require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance") + + exists, err = rdb.Exists(ctx, balanceKey).Result() + require.NoError(s.T(), err, "Exists after invalidate") + require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate") + + _, err = cache.GetUserBalance(ctx, userID) + require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate") + }, + }, + { + name: "deduct_refreshes_ttl", + fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { + userID := int64(103) + balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) + + require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance") + + ttl1, err := rdb.TTL(ctx, balanceKey).Result() + require.NoError(s.T(), err, "TTL before deduct") + s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL) + + require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance") + + balance, err := cache.GetUserBalance(ctx, userID) + require.NoError(s.T(), err, "GetUserBalance") + require.Equal(s.T(), 75.0, balance, "expected balance 75.0") + + ttl2, err := rdb.TTL(ctx, balanceKey).Result() + require.NoError(s.T(), err, "TTL after deduct") + s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL) + }, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + + tt.fn(ctx, rdb, cache) + }) + } +} + +func (s *BillingCacheSuite) TestSubscriptionCache() { + tests := []struct { + name string + fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) + }{ + { + name: "missing_key_returns_redis_nil", + fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { + userID := int64(10) + groupID := int64(20) + + _, err := cache.GetSubscriptionCache(ctx, userID, groupID) + require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key") + }, + }, + { + name: "update_usage_on_nonexistent_is_noop", + fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { + userID := int64(11) + groupID := int64(21) + subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) + + require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error") + + exists, err := rdb.Exists(ctx, subKey).Result() + require.NoError(s.T(), err, "Exists") + require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent") + }, + }, + { + name: "set_and_get_with_ttl", + fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { + userID := int64(12) + groupID := int64(22) + subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) + + data := &ports.SubscriptionCacheData{ + Status: "active", + ExpiresAt: time.Now().Add(1 * time.Hour), + DailyUsage: 1.0, + WeeklyUsage: 2.0, + MonthlyUsage: 3.0, + Version: 7, + } + require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache") + + gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID) + require.NoError(s.T(), err, "GetSubscriptionCache") + require.Equal(s.T(), "active", gotSub.Status) + require.Equal(s.T(), int64(7), gotSub.Version) + require.Equal(s.T(), 1.0, gotSub.DailyUsage) + + ttl, err := rdb.TTL(ctx, subKey).Result() + require.NoError(s.T(), err, "TTL subKey") + s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL) + }, + }, + { + name: "update_usage_increments_all_fields", + fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { + userID := int64(13) + groupID := int64(23) + + data := &ports.SubscriptionCacheData{ + Status: "active", + ExpiresAt: time.Now().Add(1 * time.Hour), + DailyUsage: 1.0, + WeeklyUsage: 2.0, + MonthlyUsage: 3.0, + Version: 1, + } + require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache") + + require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage") + + gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID) + require.NoError(s.T(), err, "GetSubscriptionCache after update") + require.Equal(s.T(), 1.5, gotSub.DailyUsage) + require.Equal(s.T(), 2.5, gotSub.WeeklyUsage) + require.Equal(s.T(), 3.5, gotSub.MonthlyUsage) + }, + }, + { + name: "invalidate_removes_key", + fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { + userID := int64(101) + groupID := int64(10) + subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) + + data := &ports.SubscriptionCacheData{ + Status: "active", + ExpiresAt: time.Now().Add(1 * time.Hour), + DailyUsage: 1.0, + WeeklyUsage: 2.0, + MonthlyUsage: 3.0, + Version: 1, + } + require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache") + + exists, err := rdb.Exists(ctx, subKey).Result() + require.NoError(s.T(), err, "Exists") + require.Equal(s.T(), int64(1), exists, "expected subscription key to exist") + + require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache") + + exists, err = rdb.Exists(ctx, subKey).Result() + require.NoError(s.T(), err, "Exists after invalidate") + require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate") + + _, err = cache.GetSubscriptionCache(ctx, userID, groupID) + require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate") + }, + }, + { + name: "missing_status_returns_parsing_error", + fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) { + userID := int64(102) + groupID := int64(11) + subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) + + fields := map[string]any{ + "expires_at": time.Now().Add(1 * time.Hour).Unix(), + "daily_usage": 1.0, + "weekly_usage": 2.0, + "monthly_usage": 3.0, + "version": 1, + } + require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet") + + _, err := cache.GetSubscriptionCache(ctx, userID, groupID) + require.Error(s.T(), err, "expected error for missing status field") + require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil") + require.Equal(s.T(), "invalid cache: missing status", err.Error()) + }, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + + tt.fn(ctx, rdb, cache) + }) + } +} + +func TestBillingCacheSuite(t *testing.T) { + suite.Run(t, new(BillingCacheSuite)) +} diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index 23dd3661..005b1679 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -16,20 +16,28 @@ import ( "github.com/imroc/req/v3" ) -type claudeOAuthService struct{} - func NewClaudeOAuthClient() service.ClaudeOAuthClient { - return &claudeOAuthService{} + return &claudeOAuthService{ + baseURL: "https://claude.ai", + tokenURL: oauth.TokenURL, + clientFactory: createReqClient, + } +} + +type claudeOAuthService struct { + baseURL string + tokenURL string + clientFactory func(proxyURL string) *req.Client } func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { - client := createReqClient(proxyURL) + client := s.clientFactory(proxyURL) var orgs []struct { UUID string `json:"uuid"` } - targetURL := "https://claude.ai/api/organizations" + targetURL := s.baseURL + "/api/organizations" log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL) resp, err := client.R(). @@ -61,9 +69,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey } func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { - client := createReqClient(proxyURL) + client := s.clientFactory(proxyURL) - authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID) + authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID) reqBody := map[string]any{ "response_type": "code", @@ -133,12 +141,12 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe fullCode = authCode + "#" + responseState } - log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20]) + log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20)) return fullCode, nil } func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) { - client := createReqClient(proxyURL) + client := s.clientFactory(proxyURL) // Parse code which may contain state in format "authCode#state" authCode := code @@ -161,7 +169,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod } reqBodyJSON, _ := json.Marshal(reqBody) - log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL) + log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) var tokenResp oauth.TokenResponse @@ -171,7 +179,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod SetHeader("Content-Type", "application/json"). SetBody(reqBody). SetSuccessResult(&tokenResp). - Post(oauth.TokenURL) + Post(s.tokenURL) if err != nil { log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err) @@ -189,7 +197,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod } func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { - client := createReqClient(proxyURL) + client := s.clientFactory(proxyURL) formData := url.Values{} formData.Set("grant_type", "refresh_token") @@ -202,7 +210,7 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro SetContext(ctx). SetFormDataFromValues(formData). SetSuccessResult(&tokenResp). - Post(oauth.TokenURL) + Post(s.tokenURL) if err != nil { return nil, fmt.Errorf("request failed: %w", err) @@ -226,3 +234,13 @@ func createReqClient(proxyURL string) *req.Client { return client } + +func prefix(s string, n int) string { + if n <= 0 { + return "" + } + if len(s) <= n { + return s + } + return s[:n] +} diff --git a/backend/internal/repository/claude_oauth_service_test.go b/backend/internal/repository/claude_oauth_service_test.go new file mode 100644 index 00000000..1d466f48 --- /dev/null +++ b/backend/internal/repository/claude_oauth_service_test.go @@ -0,0 +1,343 @@ +package repository + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ClaudeOAuthServiceSuite struct { + suite.Suite + srv *httptest.Server + client *claudeOAuthService +} + +func (s *ClaudeOAuthServiceSuite) TearDownTest() { + if s.srv != nil { + s.srv.Close() + s.srv = nil + } +} + +// requestCapture holds captured request data for assertions in the main goroutine. +type requestCapture struct { + path string + method string + cookies []*http.Cookie + body []byte + formValues url.Values + bodyJSON map[string]any + contentType string +} + +func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() { + tests := []struct { + name string + handler http.HandlerFunc + wantErr bool + errContain string + wantUUID string + validate func(captured requestCapture) + }{ + { + name: "success", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`[{"uuid":"org-1"}]`)) + }, + wantUUID: "org-1", + validate: func(captured requestCapture) { + require.Equal(s.T(), "/api/organizations", captured.path, "unexpected path") + require.Len(s.T(), captured.cookies, 1, "expected 1 cookie") + require.Equal(s.T(), "sessionKey", captured.cookies[0].Name) + require.Equal(s.T(), "sess", captured.cookies[0].Value) + }, + }, + { + name: "non_200_returns_error", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte("unauthorized")) + }, + wantErr: true, + errContain: "401", + }, + { + name: "invalid_json_returns_error", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte("not-json")) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + var captured requestCapture + + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured.path = r.URL.Path + captured.cookies = r.Cookies() + tt.handler(w, r) + })) + defer s.srv.Close() + + client, ok := NewClaudeOAuthClient().(*claudeOAuthService) + require.True(s.T(), ok, "type assertion failed") + s.client = client + s.client.baseURL = s.srv.URL + + got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "") + + if tt.wantErr { + require.Error(s.T(), err) + if tt.errContain != "" { + require.ErrorContains(s.T(), err, tt.errContain) + } + return + } + + require.NoError(s.T(), err) + require.Equal(s.T(), tt.wantUUID, got) + if tt.validate != nil { + tt.validate(captured) + } + }) + } +} + +func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() { + tests := []struct { + name string + handler http.HandlerFunc + wantErr bool + wantCode string + validate func(captured requestCapture) + }{ + { + name: "parses_redirect_uri", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "redirect_uri": oauth.RedirectURI + "?code=AUTH&state=STATE", + }) + }, + wantCode: "AUTH#STATE", + validate: func(captured requestCapture) { + require.True(s.T(), strings.HasPrefix(captured.path, "/v1/oauth/") && strings.HasSuffix(captured.path, "/authorize"), "unexpected path: %s", captured.path) + require.Equal(s.T(), http.MethodPost, captured.method, "expected POST") + require.Len(s.T(), captured.cookies, 1, "expected 1 cookie") + require.Equal(s.T(), "sess", captured.cookies[0].Value) + require.Equal(s.T(), "org-1", captured.bodyJSON["organization_uuid"]) + require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"]) + require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"]) + require.Equal(s.T(), "st", captured.bodyJSON["state"]) + }, + }, + { + name: "missing_code_returns_error", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "redirect_uri": oauth.RedirectURI + "?state=STATE", // no code + }) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + var captured requestCapture + + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured.path = r.URL.Path + captured.method = r.Method + captured.cookies = r.Cookies() + captured.body, _ = io.ReadAll(r.Body) + _ = json.Unmarshal(captured.body, &captured.bodyJSON) + tt.handler(w, r) + })) + defer s.srv.Close() + + client, ok := NewClaudeOAuthClient().(*claudeOAuthService) + require.True(s.T(), ok, "type assertion failed") + s.client = client + s.client.baseURL = s.srv.URL + + code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "") + + if tt.wantErr { + require.Error(s.T(), err) + return + } + + require.NoError(s.T(), err) + require.Equal(s.T(), tt.wantCode, code) + if tt.validate != nil { + tt.validate(captured) + } + }) + } +} + +func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { + tests := []struct { + name string + handler http.HandlerFunc + code string + wantErr bool + wantResp *oauth.TokenResponse + validate func(captured requestCapture) + }{ + { + name: "sends_state_when_embedded", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(oauth.TokenResponse{ + AccessToken: "at", + TokenType: "bearer", + ExpiresIn: 3600, + RefreshToken: "rt", + Scope: "s", + }) + }, + code: "AUTH#STATE2", + wantResp: &oauth.TokenResponse{ + AccessToken: "at", + RefreshToken: "rt", + }, + validate: func(captured requestCapture) { + require.Equal(s.T(), http.MethodPost, captured.method, "expected POST") + require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), "unexpected content-type") + require.Equal(s.T(), "AUTH", captured.bodyJSON["code"]) + require.Equal(s.T(), "STATE2", captured.bodyJSON["state"]) + require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"]) + require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"]) + require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"]) + }, + }, + { + name: "non_200_returns_error", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("bad request")) + }, + code: "AUTH", + wantErr: true, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + var captured requestCapture + + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured.method = r.Method + captured.contentType = r.Header.Get("Content-Type") + captured.body, _ = io.ReadAll(r.Body) + _ = json.Unmarshal(captured.body, &captured.bodyJSON) + tt.handler(w, r) + })) + defer s.srv.Close() + + client, ok := NewClaudeOAuthClient().(*claudeOAuthService) + require.True(s.T(), ok, "type assertion failed") + s.client = client + s.client.tokenURL = s.srv.URL + + resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "") + + if tt.wantErr { + require.Error(s.T(), err) + return + } + + require.NoError(s.T(), err) + require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken) + require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken) + if tt.validate != nil { + tt.validate(captured) + } + }) + } +} + +func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { + tests := []struct { + name string + handler http.HandlerFunc + wantErr bool + wantResp *oauth.TokenResponse + validate func(captured requestCapture) + }{ + { + name: "sends_form", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(oauth.TokenResponse{AccessToken: "at2", TokenType: "bearer", ExpiresIn: 3600}) + }, + wantResp: &oauth.TokenResponse{AccessToken: "at2"}, + validate: func(captured requestCapture) { + require.Equal(s.T(), http.MethodPost, captured.method, "expected POST") + require.Equal(s.T(), "refresh_token", captured.formValues.Get("grant_type")) + require.Equal(s.T(), "rt", captured.formValues.Get("refresh_token")) + require.Equal(s.T(), oauth.ClientID, captured.formValues.Get("client_id")) + }, + }, + { + name: "non_200_returns_error", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte("unauthorized")) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + var captured requestCapture + + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured.method = r.Method + captured.body, _ = io.ReadAll(r.Body) + captured.formValues, _ = url.ParseQuery(string(captured.body)) + tt.handler(w, r) + })) + defer s.srv.Close() + + client, ok := NewClaudeOAuthClient().(*claudeOAuthService) + require.True(s.T(), ok, "type assertion failed") + s.client = client + s.client.tokenURL = s.srv.URL + + resp, err := s.client.RefreshToken(context.Background(), "rt", "") + + if tt.wantErr { + require.Error(s.T(), err) + return + } + + require.NoError(s.T(), err) + require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken) + if tt.validate != nil { + tt.validate(captured) + } + }) + } +} + +func TestClaudeOAuthServiceSuite(t *testing.T) { + suite.Run(t, new(ClaudeOAuthServiceSuite)) +} diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index 9d7963bd..7ccbeafc 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -12,10 +12,14 @@ import ( "github.com/Wei-Shaw/sub2api/internal/service" ) -type claudeUsageService struct{} +const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage" + +type claudeUsageService struct { + usageURL string +} func NewClaudeUsageFetcher() service.ClaudeUsageFetcher { - return &claudeUsageService{} + return &claudeUsageService{usageURL: defaultClaudeUsageURL} } func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) { @@ -35,7 +39,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU Timeout: 30 * time.Second, } - req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil) + req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil) if err != nil { return nil, fmt.Errorf("create request failed: %w", err) } diff --git a/backend/internal/repository/claude_usage_service_test.go b/backend/internal/repository/claude_usage_service_test.go new file mode 100644 index 00000000..11097b67 --- /dev/null +++ b/backend/internal/repository/claude_usage_service_test.go @@ -0,0 +1,105 @@ +package repository + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ClaudeUsageServiceSuite struct { + suite.Suite + srv *httptest.Server + fetcher *claudeUsageService +} + +func (s *ClaudeUsageServiceSuite) TearDownTest() { + if s.srv != nil { + s.srv.Close() + s.srv = nil + } +} + +// usageRequestCapture holds captured request data for assertions in the main goroutine. +type usageRequestCapture struct { + authorization string + anthropicBeta string +} + +func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { + var captured usageRequestCapture + + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured.authorization = r.Header.Get("Authorization") + captured.anthropicBeta = r.Header.Get("anthropic-beta") + + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{ + "five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"}, + "seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"}, + "seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"} +}`) + })) + + s.fetcher = &claudeUsageService{usageURL: s.srv.URL} + + resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") + require.NoError(s.T(), err, "FetchUsage") + require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch") + require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch") + require.Equal(s.T(), 56.0, resp.SevenDaySonnet.Utilization, "SevenDaySonnet utilization mismatch") + + // Assertions on captured request data + require.Equal(s.T(), "Bearer at", captured.authorization, "Authorization header mismatch") + require.Equal(s.T(), "oauth-2025-04-20", captured.anthropicBeta, "anthropic-beta header mismatch") +} + +func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = io.WriteString(w, "nope") + })) + + s.fetcher = &claudeUsageService{usageURL: s.srv.URL} + + _, err := s.fetcher.FetchUsage(context.Background(), "at", "") + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "status 401") + require.ErrorContains(s.T(), err, "nope") +} + +func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, "not-json") + })) + + s.fetcher = &claudeUsageService{usageURL: s.srv.URL} + + _, err := s.fetcher.FetchUsage(context.Background(), "at", "") + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "decode response failed") +} + +func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Never respond - simulate slow server + <-r.Context().Done() + })) + + s.fetcher = &claudeUsageService{usageURL: s.srv.URL} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + _, err := s.fetcher.FetchUsage(ctx, "at", "") + require.Error(s.T(), err, "expected error for cancelled context") +} + +func TestClaudeUsageServiceSuite(t *testing.T) { + suite.Run(t, new(ClaudeUsageServiceSuite)) +} diff --git a/backend/internal/repository/concurrency_cache_integration_test.go b/backend/internal/repository/concurrency_cache_integration_test.go new file mode 100644 index 00000000..dc27dc9c --- /dev/null +++ b/backend/internal/repository/concurrency_cache_integration_test.go @@ -0,0 +1,231 @@ +//go:build integration + +package repository + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service/ports" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ConcurrencyCacheSuite struct { + IntegrationRedisSuite + cache ports.ConcurrencyCache +} + +func (s *ConcurrencyCacheSuite) SetupTest() { + s.IntegrationRedisSuite.SetupTest() + s.cache = NewConcurrencyCache(s.rdb) +} + +func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() { + accountID := int64(10) + reqID1, reqID2, reqID3 := "req1", "req2", "req3" + + ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1) + require.NoError(s.T(), err, "AcquireAccountSlot 1") + require.True(s.T(), ok) + + ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2) + require.NoError(s.T(), err, "AcquireAccountSlot 2") + require.True(s.T(), ok) + + ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3) + require.NoError(s.T(), err, "AcquireAccountSlot 3") + require.False(s.T(), ok, "expected third acquire to fail") + + cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID) + require.NoError(s.T(), err, "GetAccountConcurrency") + require.Equal(s.T(), 2, cur, "concurrency mismatch") + + require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot") + + cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID) + require.NoError(s.T(), err, "GetAccountConcurrency after release") + require.Equal(s.T(), 1, cur, "expected 1 after release") +} + +func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() { + accountID := int64(11) + reqID := "req_ttl_test" + slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID) + + ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID) + require.NoError(s.T(), err, "AcquireAccountSlot") + require.True(s.T(), ok) + + ttl, err := s.rdb.TTL(s.ctx, slotKey).Result() + require.NoError(s.T(), err, "TTL") + s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) +} + +func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() { + accountID := int64(12) + reqID := "dup-req" + + ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID) + require.NoError(s.T(), err) + require.True(s.T(), ok) + + // Acquiring with same reqID should be idempotent + ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID) + require.NoError(s.T(), err) + require.True(s.T(), ok) + + cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID) + require.NoError(s.T(), err) + require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)") +} + +func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() { + accountID := int64(13) + reqID := "release-test" + + ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID) + require.NoError(s.T(), err) + require.True(s.T(), ok) + + require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot") + // Releasing again should not error + require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again") + // Releasing non-existent should not error + require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent") + + cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID) + require.NoError(s.T(), err) + require.Equal(s.T(), 0, cur) +} + +func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() { + accountID := int64(14) + reqID := "max-zero-test" + + ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID) + require.NoError(s.T(), err) + require.False(s.T(), ok, "expected acquire to fail with max=0") +} + +func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() { + userID := int64(42) + reqID1, reqID2 := "req1", "req2" + + ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1) + require.NoError(s.T(), err, "AcquireUserSlot") + require.True(s.T(), ok) + + ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2) + require.NoError(s.T(), err, "AcquireUserSlot 2") + require.False(s.T(), ok, "expected second acquire to fail at max=1") + + cur, err := s.cache.GetUserConcurrency(s.ctx, userID) + require.NoError(s.T(), err, "GetUserConcurrency") + require.Equal(s.T(), 1, cur, "expected concurrency=1") + + require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot") + // Releasing a non-existent slot should not error + require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent") + + cur, err = s.cache.GetUserConcurrency(s.ctx, userID) + require.NoError(s.T(), err, "GetUserConcurrency after release") + require.Equal(s.T(), 0, cur, "expected concurrency=0 after release") +} + +func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() { + userID := int64(200) + reqID := "req_ttl_test" + slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID) + + ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID) + require.NoError(s.T(), err, "AcquireUserSlot") + require.True(s.T(), ok) + + ttl, err := s.rdb.TTL(s.ctx, slotKey).Result() + require.NoError(s.T(), err, "TTL") + s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) +} + +func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() { + userID := int64(20) + waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) + + ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2) + require.NoError(s.T(), err, "IncrementWaitCount 1") + require.True(s.T(), ok) + + ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2) + require.NoError(s.T(), err, "IncrementWaitCount 2") + require.True(s.T(), ok) + + ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2) + require.NoError(s.T(), err, "IncrementWaitCount 3") + require.False(s.T(), ok, "expected wait increment over max to fail") + + ttl, err := s.rdb.TTL(s.ctx, waitKey).Result() + require.NoError(s.T(), err, "TTL waitKey") + s.AssertTTLWithin(ttl, 1*time.Second, slotTTL) + + require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount") + + val, err := s.rdb.Get(s.ctx, waitKey).Int() + if !errors.Is(err, redis.Nil) { + require.NoError(s.T(), err, "Get waitKey") + } + require.Equal(s.T(), 1, val, "expected wait count 1") +} + +func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() { + userID := int64(300) + waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) + + // Test decrement on non-existent key - should not error and should not create negative value + require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key") + + // Verify no key was created or it's not negative + val, err := s.rdb.Get(s.ctx, waitKey).Int() + if !errors.Is(err, redis.Nil) { + require.NoError(s.T(), err, "Get waitKey") + } + require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty") + + // Set count to 1, then decrement twice + ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5) + require.NoError(s.T(), err, "IncrementWaitCount") + require.True(s.T(), ok) + + // Decrement once (1 -> 0) + require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount") + + // Decrement again on 0 - should not go negative + require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero") + + // Verify count is 0, not negative + val, err = s.rdb.Get(s.ctx, waitKey).Int() + if !errors.Is(err, redis.Nil) { + require.NoError(s.T(), err, "Get waitKey after double decrement") + } + require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count") +} + +func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() { + // When no slots exist, GetAccountConcurrency should return 0 + cur, err := s.cache.GetAccountConcurrency(s.ctx, 999) + require.NoError(s.T(), err) + require.Equal(s.T(), 0, cur) +} + +func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() { + // When no slots exist, GetUserConcurrency should return 0 + cur, err := s.cache.GetUserConcurrency(s.ctx, 999) + require.NoError(s.T(), err) + require.Equal(s.T(), 0, cur) +} + +func TestConcurrencyCacheSuite(t *testing.T) { + suite.Run(t, new(ConcurrencyCacheSuite)) +} diff --git a/backend/internal/repository/email_cache_integration_test.go b/backend/internal/repository/email_cache_integration_test.go new file mode 100644 index 00000000..22ce3f5e --- /dev/null +++ b/backend/internal/repository/email_cache_integration_test.go @@ -0,0 +1,92 @@ +//go:build integration + +package repository + +import ( + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service/ports" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type EmailCacheSuite struct { + IntegrationRedisSuite + cache ports.EmailCache +} + +func (s *EmailCacheSuite) SetupTest() { + s.IntegrationRedisSuite.SetupTest() + s.cache = NewEmailCache(s.rdb) +} + +func (s *EmailCacheSuite) TestGetVerificationCode_Missing() { + _, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com") + require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code") +} + +func (s *EmailCacheSuite) TestSetAndGetVerificationCode() { + email := "a@example.com" + emailTTL := 2 * time.Minute + data := &ports.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()} + + require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode") + + got, err := s.cache.GetVerificationCode(s.ctx, email) + require.NoError(s.T(), err, "GetVerificationCode") + require.Equal(s.T(), "123456", got.Code) + require.Equal(s.T(), 1, got.Attempts) +} + +func (s *EmailCacheSuite) TestVerificationCode_TTL() { + email := "ttl@example.com" + emailTTL := 2 * time.Minute + data := &ports.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()} + + require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode") + + emailKey := verifyCodeKeyPrefix + email + ttl, err := s.rdb.TTL(s.ctx, emailKey).Result() + require.NoError(s.T(), err, "TTL emailKey") + s.AssertTTLWithin(ttl, 1*time.Second, emailTTL) +} + +func (s *EmailCacheSuite) TestDeleteVerificationCode() { + email := "delete@example.com" + data := &ports.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()} + + require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode") + + // Verify it exists + _, err := s.cache.GetVerificationCode(s.ctx, email) + require.NoError(s.T(), err, "GetVerificationCode before delete") + + // Delete + require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode") + + // Verify it's gone + _, err = s.cache.GetVerificationCode(s.ctx, email) + require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete") +} + +func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() { + // Deleting a non-existent key should not error + require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent") +} + +func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() { + emailKey := verifyCodeKeyPrefix + "corrupted@example.com" + + require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON") + + _, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com") + require.Error(s.T(), err, "expected error for corrupted JSON") + require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil") +} + +func TestEmailCacheSuite(t *testing.T) { + suite.Run(t, new(EmailCacheSuite)) +} diff --git a/backend/internal/repository/fixtures_integration_test.go b/backend/internal/repository/fixtures_integration_test.go new file mode 100644 index 00000000..adeb8ac6 --- /dev/null +++ b/backend/internal/repository/fixtures_integration_test.go @@ -0,0 +1,172 @@ +//go:build integration + +package repository + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User { + t.Helper() + if u.PasswordHash == "" { + u.PasswordHash = "test-password-hash" + } + if u.Role == "" { + u.Role = model.RoleUser + } + if u.Status == "" { + u.Status = model.StatusActive + } + if u.CreatedAt.IsZero() { + u.CreatedAt = time.Now() + } + if u.UpdatedAt.IsZero() { + u.UpdatedAt = u.CreatedAt + } + require.NoError(t, db.Create(u).Error, "create user") + return u +} + +func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group { + t.Helper() + if g.Platform == "" { + g.Platform = model.PlatformAnthropic + } + if g.Status == "" { + g.Status = model.StatusActive + } + if g.SubscriptionType == "" { + g.SubscriptionType = model.SubscriptionTypeStandard + } + if g.CreatedAt.IsZero() { + g.CreatedAt = time.Now() + } + if g.UpdatedAt.IsZero() { + g.UpdatedAt = g.CreatedAt + } + require.NoError(t, db.Create(g).Error, "create group") + return g +} + +func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy { + t.Helper() + if p.Protocol == "" { + p.Protocol = "http" + } + if p.Host == "" { + p.Host = "127.0.0.1" + } + if p.Port == 0 { + p.Port = 8080 + } + if p.Status == "" { + p.Status = model.StatusActive + } + if p.CreatedAt.IsZero() { + p.CreatedAt = time.Now() + } + if p.UpdatedAt.IsZero() { + p.UpdatedAt = p.CreatedAt + } + require.NoError(t, db.Create(p).Error, "create proxy") + return p +} + +func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Account { + t.Helper() + if a.Platform == "" { + a.Platform = model.PlatformAnthropic + } + if a.Type == "" { + a.Type = model.AccountTypeOAuth + } + if a.Status == "" { + a.Status = model.StatusActive + } + if !a.Schedulable { + a.Schedulable = true + } + if a.Credentials == nil { + a.Credentials = model.JSONB{} + } + if a.Extra == nil { + a.Extra = model.JSONB{} + } + if a.CreatedAt.IsZero() { + a.CreatedAt = time.Now() + } + if a.UpdatedAt.IsZero() { + a.UpdatedAt = a.CreatedAt + } + require.NoError(t, db.Create(a).Error, "create account") + return a +} + +func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey { + t.Helper() + if k.Status == "" { + k.Status = model.StatusActive + } + if k.CreatedAt.IsZero() { + k.CreatedAt = time.Now() + } + if k.UpdatedAt.IsZero() { + k.UpdatedAt = k.CreatedAt + } + require.NoError(t, db.Create(k).Error, "create api key") + return k +} + +func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model.RedeemCode { + t.Helper() + if c.Status == "" { + c.Status = model.StatusUnused + } + if c.Type == "" { + c.Type = model.RedeemTypeBalance + } + if c.CreatedAt.IsZero() { + c.CreatedAt = time.Now() + } + require.NoError(t, db.Create(c).Error, "create redeem code") + return c +} + +func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription) *model.UserSubscription { + t.Helper() + if s.Status == "" { + s.Status = model.SubscriptionStatusActive + } + now := time.Now() + if s.StartsAt.IsZero() { + s.StartsAt = now.Add(-1 * time.Hour) + } + if s.ExpiresAt.IsZero() { + s.ExpiresAt = now.Add(24 * time.Hour) + } + if s.AssignedAt.IsZero() { + s.AssignedAt = now + } + if s.CreatedAt.IsZero() { + s.CreatedAt = now + } + if s.UpdatedAt.IsZero() { + s.UpdatedAt = now + } + require.NoError(t, db.Create(s).Error, "create user subscription") + return s +} + +func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) { + t.Helper() + require.NoError(t, db.Create(&model.AccountGroup{ + AccountID: accountID, + GroupID: groupID, + Priority: priority, + }).Error, "create account_group") +} diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go new file mode 100644 index 00000000..5afe30fa --- /dev/null +++ b/backend/internal/repository/gateway_cache_integration_test.go @@ -0,0 +1,92 @@ +//go:build integration + +package repository + +import ( + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service/ports" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type GatewayCacheSuite struct { + IntegrationRedisSuite + cache ports.GatewayCache +} + +func (s *GatewayCacheSuite) SetupTest() { + s.IntegrationRedisSuite.SetupTest() + s.cache = NewGatewayCache(s.rdb) +} + +func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() { + _, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent") + require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session") +} + +func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() { + sessionID := "s1" + accountID := int64(99) + sessionTTL := 1 * time.Minute + + require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID") + + sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID) + require.NoError(s.T(), err, "GetSessionAccountID") + require.Equal(s.T(), accountID, sid, "session id mismatch") +} + +func (s *GatewayCacheSuite) TestSessionAccountID_TTL() { + sessionID := "s2" + accountID := int64(100) + sessionTTL := 1 * time.Minute + + require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID") + + sessionKey := stickySessionPrefix + sessionID + ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result() + require.NoError(s.T(), err, "TTL sessionKey after Set") + s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL) +} + +func (s *GatewayCacheSuite) TestRefreshSessionTTL() { + sessionID := "s3" + accountID := int64(101) + initialTTL := 1 * time.Minute + refreshTTL := 3 * time.Minute + + require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID") + + require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL") + + sessionKey := stickySessionPrefix + sessionID + ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result() + require.NoError(s.T(), err, "TTL after Refresh") + s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL) +} + +func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() { + // RefreshSessionTTL on a missing key should not error (no-op) + err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute) + require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error") +} + +func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { + sessionID := "corrupted" + sessionKey := stickySessionPrefix + sessionID + + // Set a non-integer value + require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value") + + _, err := s.cache.GetSessionAccountID(s.ctx, sessionID) + require.Error(s.T(), err, "expected error for corrupted value") + require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") +} + +func TestGatewayCacheSuite(t *testing.T) { + suite.Run(t, new(GatewayCacheSuite)) +} diff --git a/backend/internal/repository/github_release_service_test.go b/backend/internal/repository/github_release_service_test.go new file mode 100644 index 00000000..bf2efd8d --- /dev/null +++ b/backend/internal/repository/github_release_service_test.go @@ -0,0 +1,328 @@ +package repository + +import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type GitHubReleaseServiceSuite struct { + suite.Suite + srv *httptest.Server + client *githubReleaseClient + tempDir string +} + +// testTransport redirects requests to the test server +type testTransport struct { + testServerURL string +} + +func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Rewrite the URL to point to our test server + testURL := t.testServerURL + req.URL.Path + newReq, err := http.NewRequestWithContext(req.Context(), req.Method, testURL, req.Body) + if err != nil { + return nil, err + } + newReq.Header = req.Header + return http.DefaultTransport.RoundTrip(newReq) +} + +func (s *GitHubReleaseServiceSuite) SetupTest() { + s.tempDir = s.T().TempDir() +} + +func (s *GitHubReleaseServiceSuite) TearDownTest() { + if s.srv != nil { + s.srv.Close() + s.srv = nil + } +} + +func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", "100") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(bytes.Repeat([]byte("a"), 100)) + })) + + client, ok := NewGitHubReleaseClient().(*githubReleaseClient) + require.True(s.T(), ok, "type assertion failed") + s.client = client + + dest := filepath.Join(s.tempDir, "file1.bin") + err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10) + require.Error(s.T(), err, "expected error for oversized download with Content-Length") + + _, statErr := os.Stat(dest) + require.Error(s.T(), statErr, "expected file to not exist for rejected download") +} + +func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Force chunked encoding (unknown Content-Length) by flushing headers before writing. + w.WriteHeader(http.StatusOK) + if fl, ok := w.(http.Flusher); ok { + fl.Flush() + } + for i := 0; i < 10; i++ { + _, _ = w.Write(bytes.Repeat([]byte("b"), 10)) + if fl, ok := w.(http.Flusher); ok { + fl.Flush() + } + } + })) + + client, ok := NewGitHubReleaseClient().(*githubReleaseClient) + require.True(s.T(), ok, "type assertion failed") + s.client = client + + dest := filepath.Join(s.tempDir, "file2.bin") + err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10) + require.Error(s.T(), err, "expected error for oversized chunked download") + + _, statErr := os.Stat(dest) + require.Error(s.T(), statErr, "expected file to be cleaned up for oversized chunked download") +} + +func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if fl, ok := w.(http.Flusher); ok { + fl.Flush() + } + for i := 0; i < 10; i++ { + _, _ = w.Write(bytes.Repeat([]byte("b"), 10)) + if fl, ok := w.(http.Flusher); ok { + fl.Flush() + } + } + })) + + client, ok := NewGitHubReleaseClient().(*githubReleaseClient) + require.True(s.T(), ok, "type assertion failed") + s.client = client + + dest := filepath.Join(s.tempDir, "file3.bin") + err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200) + require.NoError(s.T(), err, "expected success") + + b, err := os.ReadFile(dest) + require.NoError(s.T(), err, "read") + require.True(s.T(), strings.HasPrefix(string(b), "b"), "downloaded content should start with 'b'") + require.Len(s.T(), b, 100, "downloaded content length mismatch") +} + +func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + + client, ok := NewGitHubReleaseClient().(*githubReleaseClient) + require.True(s.T(), ok, "type assertion failed") + s.client = client + + dest := filepath.Join(s.tempDir, "notfound.bin") + err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100) + require.Error(s.T(), err, "expected error for 404") + + _, statErr := os.Stat(dest) + require.Error(s.T(), statErr, "expected file to not exist for 404") +} + +func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("sum")) + })) + + client, ok := NewGitHubReleaseClient().(*githubReleaseClient) + require.True(s.T(), ok, "type assertion failed") + s.client = client + + body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL) + require.NoError(s.T(), err, "FetchChecksumFile") + require.Equal(s.T(), "sum", string(body), "checksum body mismatch") +} + +func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + + client, ok := NewGitHubReleaseClient().(*githubReleaseClient) + require.True(s.T(), ok, "type assertion failed") + s.client = client + + _, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL) + require.Error(s.T(), err, "expected error for non-200") +} + +func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-r.Context().Done() + })) + + client, ok := NewGitHubReleaseClient().(*githubReleaseClient) + require.True(s.T(), ok, "type assertion failed") + s.client = client + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + dest := filepath.Join(s.tempDir, "cancelled.bin") + err := s.client.DownloadFile(ctx, s.srv.URL, dest, 100) + require.Error(s.T(), err, "expected error for cancelled context") +} + +func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() { + client, ok := NewGitHubReleaseClient().(*githubReleaseClient) + require.True(s.T(), ok, "type assertion failed") + s.client = client + + dest := filepath.Join(s.tempDir, "invalid.bin") + err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100) + require.Error(s.T(), err, "expected error for invalid URL") +} + +func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("content")) + })) + + client, ok := NewGitHubReleaseClient().(*githubReleaseClient) + require.True(s.T(), ok, "type assertion failed") + s.client = client + + // Use a path that cannot be created (directory doesn't exist) + dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin") + err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100) + require.Error(s.T(), err, "expected error for invalid destination path") +} + +func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() { + client, ok := NewGitHubReleaseClient().(*githubReleaseClient) + require.True(s.T(), ok, "type assertion failed") + s.client = client + + _, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url") + require.Error(s.T(), err, "expected error for invalid URL") +} + +func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() { + releaseJSON := `{ + "tag_name": "v1.0.0", + "name": "Release 1.0.0", + "body": "Release notes", + "html_url": "https://github.com/test/repo/releases/v1.0.0", + "assets": [ + { + "name": "app-linux-amd64.tar.gz", + "browser_download_url": "https://github.com/test/repo/releases/download/v1.0.0/app-linux-amd64.tar.gz" + } + ] + }` + + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path) + require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept")) + require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent")) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(releaseJSON)) + })) + + // Use custom transport to redirect requests to test server + s.client = &githubReleaseClient{ + httpClient: &http.Client{ + Transport: &testTransport{testServerURL: s.srv.URL}, + }, + } + + release, err := s.client.FetchLatestRelease(context.Background(), "test/repo") + require.NoError(s.T(), err) + require.Equal(s.T(), "v1.0.0", release.TagName) + require.Equal(s.T(), "Release 1.0.0", release.Name) + require.Len(s.T(), release.Assets, 1) + require.Equal(s.T(), "app-linux-amd64.tar.gz", release.Assets[0].Name) +} + +func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + + s.client = &githubReleaseClient{ + httpClient: &http.Client{ + Transport: &testTransport{testServerURL: s.srv.URL}, + }, + } + + _, err := s.client.FetchLatestRelease(context.Background(), "test/repo") + require.Error(s.T(), err) + require.Contains(s.T(), err.Error(), "404") +} + +func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("not valid json")) + })) + + s.client = &githubReleaseClient{ + httpClient: &http.Client{ + Transport: &testTransport{testServerURL: s.srv.URL}, + }, + } + + _, err := s.client.FetchLatestRelease(context.Background(), "test/repo") + require.Error(s.T(), err) +} + +func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-r.Context().Done() + })) + + s.client = &githubReleaseClient{ + httpClient: &http.Client{ + Transport: &testTransport{testServerURL: s.srv.URL}, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := s.client.FetchLatestRelease(ctx, "test/repo") + require.Error(s.T(), err) +} + +func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() { + s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-r.Context().Done() + })) + + client, ok := NewGitHubReleaseClient().(*githubReleaseClient) + require.True(s.T(), ok, "type assertion failed") + s.client = client + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := s.client.FetchChecksumFile(ctx, s.srv.URL) + require.Error(s.T(), err) +} + +func TestGitHubReleaseServiceSuite(t *testing.T) { + suite.Run(t, new(GitHubReleaseServiceSuite)) +} diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go new file mode 100644 index 00000000..e4464657 --- /dev/null +++ b/backend/internal/repository/group_repo_integration_test.go @@ -0,0 +1,244 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +type GroupRepoSuite struct { + suite.Suite + ctx context.Context + db *gorm.DB + repo *GroupRepository +} + +func (s *GroupRepoSuite) SetupTest() { + s.ctx = context.Background() + s.db = testTx(s.T()) + s.repo = NewGroupRepository(s.db) +} + +func TestGroupRepoSuite(t *testing.T) { + suite.Run(t, new(GroupRepoSuite)) +} + +// --- Create / GetByID / Update / Delete --- + +func (s *GroupRepoSuite) TestCreate() { + group := &model.Group{ + Name: "test-create", + Platform: model.PlatformAnthropic, + Status: model.StatusActive, + } + + err := s.repo.Create(s.ctx, group) + s.Require().NoError(err, "Create") + s.Require().NotZero(group.ID, "expected ID to be set") + + got, err := s.repo.GetByID(s.ctx, group.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal("test-create", got.Name) +} + +func (s *GroupRepoSuite) TestGetByID_NotFound() { + _, err := s.repo.GetByID(s.ctx, 999999) + s.Require().Error(err, "expected error for non-existent ID") +} + +func (s *GroupRepoSuite) TestUpdate() { + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "original"}) + + group.Name = "updated" + err := s.repo.Update(s.ctx, group) + s.Require().NoError(err, "Update") + + got, err := s.repo.GetByID(s.ctx, group.ID) + s.Require().NoError(err, "GetByID after update") + s.Require().Equal("updated", got.Name) +} + +func (s *GroupRepoSuite) TestDelete() { + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "to-delete"}) + + err := s.repo.Delete(s.ctx, group.ID) + s.Require().NoError(err, "Delete") + + _, err = s.repo.GetByID(s.ctx, group.ID) + s.Require().Error(err, "expected error after delete") +} + +// --- List / ListWithFilters --- + +func (s *GroupRepoSuite) TestList() { + mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"}) + mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"}) + + groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "List") + s.Require().Len(groups, 2) + s.Require().Equal(int64(2), page.Total) +} + +func (s *GroupRepoSuite) TestListWithFilters_Platform() { + mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic}) + mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI}) + + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformOpenAI, "", nil) + s.Require().NoError(err) + s.Require().Len(groups, 1) + s.Require().Equal(model.PlatformOpenAI, groups[0].Platform) +} + +func (s *GroupRepoSuite) TestListWithFilters_Status() { + mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Status: model.StatusActive}) + mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Status: model.StatusDisabled}) + + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, nil) + s.Require().NoError(err) + s.Require().Len(groups, 1) + s.Require().Equal(model.StatusDisabled, groups[0].Status) +} + +func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { + mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", IsExclusive: false}) + mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", IsExclusive: true}) + + isExclusive := true + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive) + s.Require().NoError(err) + s.Require().Len(groups, 1) + s.Require().True(groups[0].IsExclusive) +} + +func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { + g1 := mustCreateGroup(s.T(), s.db, &model.Group{ + Name: "g1", + Platform: model.PlatformAnthropic, + Status: model.StatusActive, + }) + g2 := mustCreateGroup(s.T(), s.db, &model.Group{ + Name: "g2", + Platform: model.PlatformAnthropic, + Status: model.StatusActive, + IsExclusive: true, + }) + + a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"}) + mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1) + mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1) + + isExclusive := true + groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformAnthropic, model.StatusActive, &isExclusive) + s.Require().NoError(err, "ListWithFilters") + s.Require().Equal(int64(1), page.Total) + s.Require().Len(groups, 1) + s.Require().Equal(g2.ID, groups[0].ID, "ListWithFilters returned wrong group") + s.Require().Equal(int64(1), groups[0].AccountCount, "AccountCount mismatch") +} + +// --- ListActive / ListActiveByPlatform --- + +func (s *GroupRepoSuite) TestListActive() { + mustCreateGroup(s.T(), s.db, &model.Group{Name: "active1", Status: model.StatusActive}) + mustCreateGroup(s.T(), s.db, &model.Group{Name: "inactive1", Status: model.StatusDisabled}) + + groups, err := s.repo.ListActive(s.ctx) + s.Require().NoError(err, "ListActive") + s.Require().Len(groups, 1) + s.Require().Equal("active1", groups[0].Name) +} + +func (s *GroupRepoSuite) TestListActiveByPlatform() { + mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic, Status: model.StatusActive}) + mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI, Status: model.StatusActive}) + mustCreateGroup(s.T(), s.db, &model.Group{Name: "g3", Platform: model.PlatformAnthropic, Status: model.StatusDisabled}) + + groups, err := s.repo.ListActiveByPlatform(s.ctx, model.PlatformAnthropic) + s.Require().NoError(err, "ListActiveByPlatform") + s.Require().Len(groups, 1) + s.Require().Equal("g1", groups[0].Name) +} + +// --- ExistsByName --- + +func (s *GroupRepoSuite) TestExistsByName() { + mustCreateGroup(s.T(), s.db, &model.Group{Name: "existing-group"}) + + exists, err := s.repo.ExistsByName(s.ctx, "existing-group") + s.Require().NoError(err, "ExistsByName") + s.Require().True(exists) + + notExists, err := s.repo.ExistsByName(s.ctx, "non-existing") + s.Require().NoError(err) + s.Require().False(notExists) +} + +// --- GetAccountCount --- + +func (s *GroupRepoSuite) TestGetAccountCount() { + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) + a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"}) + a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"}) + mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1) + mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2) + + count, err := s.repo.GetAccountCount(s.ctx, group.ID) + s.Require().NoError(err, "GetAccountCount") + s.Require().Equal(int64(2), count) +} + +func (s *GroupRepoSuite) TestGetAccountCount_Empty() { + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"}) + + count, err := s.repo.GetAccountCount(s.ctx, group.ID) + s.Require().NoError(err) + s.Require().Zero(count) +} + +// --- DeleteAccountGroupsByGroupID --- + +func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { + g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"}) + a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"}) + mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1) + + affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID) + s.Require().NoError(err, "DeleteAccountGroupsByGroupID") + s.Require().Equal(int64(1), affected, "expected 1 affected row") + + count, err := s.repo.GetAccountCount(s.ctx, g.ID) + s.Require().NoError(err, "GetAccountCount") + s.Require().Equal(int64(0), count, "expected 0 account groups") +} + +func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { + g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-multi"}) + a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"}) + a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"}) + a3 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) + mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1) + mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2) + mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3) + + affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID) + s.Require().NoError(err) + s.Require().Equal(int64(3), affected) + + count, _ := s.repo.GetAccountCount(s.ctx, g.ID) + s.Require().Zero(count) +} + +// --- DB --- + +func (s *GroupRepoSuite) TestDB() { + db := s.repo.DB() + s.Require().NotNil(db, "DB should return non-nil") + s.Require().Equal(s.db, db, "DB should return the underlying gorm.DB") +} diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go new file mode 100644 index 00000000..9bc38dae --- /dev/null +++ b/backend/internal/repository/http_upstream_test.go @@ -0,0 +1,115 @@ +package repository + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type HTTPUpstreamSuite struct { + suite.Suite + cfg *config.Config +} + +func (s *HTTPUpstreamSuite) SetupTest() { + s.cfg = &config.Config{} +} + +func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { + up := NewHTTPUpstream(s.cfg) + svc, ok := up.(*httpUpstreamService) + require.True(s.T(), ok, "expected *httpUpstreamService") + transport, ok := svc.defaultClient.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") +} + +func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() { + s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7} + up := NewHTTPUpstream(s.cfg) + svc, ok := up.(*httpUpstreamService) + require.True(s.T(), ok, "expected *httpUpstreamService") + transport, ok := svc.defaultClient.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") +} + +func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() { + s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5} + up := NewHTTPUpstream(s.cfg) + svc, ok := up.(*httpUpstreamService) + require.True(s.T(), ok, "expected *httpUpstreamService") + + got := svc.createProxyClient("://bad-proxy-url") + require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback") +} + +func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "direct") + })) + s.T().Cleanup(upstream.Close) + + up := NewHTTPUpstream(s.cfg) + + req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil) + require.NoError(s.T(), err, "NewRequest") + resp, err := up.Do(req, "") + require.NoError(s.T(), err, "Do") + defer func() { _ = resp.Body.Close() }() + b, _ := io.ReadAll(resp.Body) + require.Equal(s.T(), "direct", string(b), "unexpected body") +} + +func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() { + seen := make(chan string, 1) + proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seen <- r.RequestURI + _, _ = io.WriteString(w, "proxied") + })) + s.T().Cleanup(proxySrv.Close) + + s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1} + up := NewHTTPUpstream(s.cfg) + + req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) + require.NoError(s.T(), err, "NewRequest") + resp, err := up.Do(req, proxySrv.URL) + require.NoError(s.T(), err, "Do") + defer func() { _ = resp.Body.Close() }() + b, _ := io.ReadAll(resp.Body) + require.Equal(s.T(), "proxied", string(b), "unexpected body") + + select { + case uri := <-seen: + require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI") + default: + require.Fail(s.T(), "expected proxy to receive request") + } +} + +func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "direct-empty") + })) + s.T().Cleanup(upstream.Close) + + up := NewHTTPUpstream(s.cfg) + req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil) + require.NoError(s.T(), err, "NewRequest") + resp, err := up.Do(req, "") + require.NoError(s.T(), err, "Do with empty proxy") + defer func() { _ = resp.Body.Close() }() + b, _ := io.ReadAll(resp.Body) + require.Equal(s.T(), "direct-empty", string(b)) +} + +func TestHTTPUpstreamSuite(t *testing.T) { + suite.Run(t, new(HTTPUpstreamSuite)) +} diff --git a/backend/internal/repository/identity_cache_integration_test.go b/backend/internal/repository/identity_cache_integration_test.go new file mode 100644 index 00000000..9452cb48 --- /dev/null +++ b/backend/internal/repository/identity_cache_integration_test.go @@ -0,0 +1,67 @@ +//go:build integration + +package repository + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service/ports" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type IdentityCacheSuite struct { + IntegrationRedisSuite + cache *identityCache +} + +func (s *IdentityCacheSuite) SetupTest() { + s.IntegrationRedisSuite.SetupTest() + s.cache = NewIdentityCache(s.rdb).(*identityCache) +} + +func (s *IdentityCacheSuite) TestGetFingerprint_Missing() { + _, err := s.cache.GetFingerprint(s.ctx, 1) + require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing fingerprint") +} + +func (s *IdentityCacheSuite) TestSetAndGetFingerprint() { + fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"} + require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint") + gotFP, err := s.cache.GetFingerprint(s.ctx, 1) + require.NoError(s.T(), err, "GetFingerprint") + require.Equal(s.T(), "c1", gotFP.ClientID) + require.Equal(s.T(), "ua", gotFP.UserAgent) +} + +func (s *IdentityCacheSuite) TestFingerprint_TTL() { + fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"} + require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp)) + + fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2) + ttl, err := s.rdb.TTL(s.ctx, fpKey).Result() + require.NoError(s.T(), err, "TTL fpKey") + s.AssertTTLWithin(ttl, 1*time.Second, fingerprintTTL) +} + +func (s *IdentityCacheSuite) TestGetFingerprint_JSONCorruption() { + fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 999) + require.NoError(s.T(), s.rdb.Set(s.ctx, fpKey, "invalid-json-data", 1*time.Minute).Err(), "Set invalid JSON") + + _, err := s.cache.GetFingerprint(s.ctx, 999) + require.Error(s.T(), err, "expected error for corrupted JSON") + require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil") +} + +func (s *IdentityCacheSuite) TestSetFingerprint_Nil() { + err := s.cache.SetFingerprint(s.ctx, 100, nil) + require.NoError(s.T(), err, "SetFingerprint(nil) should succeed") +} + +func TestIdentityCacheSuite(t *testing.T) { + suite.Run(t, new(IdentityCacheSuite)) +} diff --git a/backend/internal/repository/integration_harness_test.go b/backend/internal/repository/integration_harness_test.go new file mode 100644 index 00000000..1588e078 --- /dev/null +++ b/backend/internal/repository/integration_harness_test.go @@ -0,0 +1,369 @@ +//go:build integration + +package repository + +import ( + "context" + "database/sql" + "fmt" + "log" + "os" + "os/exec" + "strconv" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + redisclient "github.com/redis/go-redis/v9" + tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres" + tcredis "github.com/testcontainers/testcontainers-go/modules/redis" + gormpostgres "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +const ( + redisImageTag = "redis:8.4-alpine" + postgresImageTag = "postgres:18.1-alpine3.23" +) + +var ( + integrationDB *gorm.DB + integrationRedis *redisclient.Client + + redisNamespaceSeq uint64 +) + +func TestMain(m *testing.M) { + ctx := context.Background() + + if err := timezone.Init("UTC"); err != nil { + log.Printf("failed to init timezone: %v", err) + os.Exit(1) + } + + if !dockerIsAvailable(ctx) { + // In CI we expect Docker to be available so integration tests should fail loudly. + if os.Getenv("CI") != "" { + log.Printf("docker is not available (CI=true); failing integration tests") + os.Exit(1) + } + log.Printf("docker is not available; skipping integration tests (start Docker to enable)") + os.Exit(0) + } + + postgresImage := selectDockerImage(ctx, postgresImageTag) + pgContainer, err := tcpostgres.Run( + ctx, + postgresImage, + tcpostgres.WithDatabase("sub2api_test"), + tcpostgres.WithUsername("postgres"), + tcpostgres.WithPassword("postgres"), + tcpostgres.BasicWaitStrategies(), + ) + if err != nil { + log.Printf("failed to start postgres container: %v", err) + os.Exit(1) + } + defer func() { _ = pgContainer.Terminate(ctx) }() + + redisContainer, err := tcredis.Run( + ctx, + redisImageTag, + ) + if err != nil { + log.Printf("failed to start redis container: %v", err) + os.Exit(1) + } + defer func() { _ = redisContainer.Terminate(ctx) }() + + dsn, err := pgContainer.ConnectionString(ctx, "sslmode=disable", "TimeZone=UTC") + if err != nil { + log.Printf("failed to get postgres dsn: %v", err) + os.Exit(1) + } + + integrationDB, err = openGormWithRetry(ctx, dsn, 30*time.Second) + if err != nil { + log.Printf("failed to open gorm db: %v", err) + os.Exit(1) + } + if err := model.AutoMigrate(integrationDB); err != nil { + log.Printf("failed to automigrate db: %v", err) + os.Exit(1) + } + + redisHost, err := redisContainer.Host(ctx) + if err != nil { + log.Printf("failed to get redis host: %v", err) + os.Exit(1) + } + redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp") + if err != nil { + log.Printf("failed to get redis port: %v", err) + os.Exit(1) + } + + integrationRedis = redisclient.NewClient(&redisclient.Options{ + Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()), + DB: 0, + }) + if err := integrationRedis.Ping(ctx).Err(); err != nil { + log.Printf("failed to ping redis: %v", err) + os.Exit(1) + } + + code := m.Run() + + _ = integrationRedis.Close() + + os.Exit(code) +} + +func dockerIsAvailable(ctx context.Context) bool { + cmd := exec.CommandContext(ctx, "docker", "info") + cmd.Env = os.Environ() + return cmd.Run() == nil +} + +func selectDockerImage(ctx context.Context, preferred string) string { + if dockerImageExists(ctx, preferred) { + return preferred + } + + return preferred +} + +func dockerImageExists(ctx context.Context, image string) bool { + cmd := exec.CommandContext(ctx, "docker", "image", "inspect", image) + cmd.Env = os.Environ() + cmd.Stdout = nil + cmd.Stderr = nil + return cmd.Run() == nil +} + +func openGormWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*gorm.DB, error) { + deadline := time.Now().Add(timeout) + var lastErr error + + for time.Now().Before(deadline) { + db, err := gorm.Open(gormpostgres.Open(dsn), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + lastErr = err + time.Sleep(250 * time.Millisecond) + continue + } + + sqlDB, err := db.DB() + if err != nil { + lastErr = err + time.Sleep(250 * time.Millisecond) + continue + } + + if err := pingWithTimeout(ctx, sqlDB, 2*time.Second); err != nil { + lastErr = err + time.Sleep(250 * time.Millisecond) + continue + } + + return db, nil + } + + return nil, fmt.Errorf("db not ready after %s: %w", timeout, lastErr) +} + +func pingWithTimeout(ctx context.Context, db *sql.DB, timeout time.Duration) error { + pingCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return db.PingContext(pingCtx) +} + +func testTx(t *testing.T) *gorm.DB { + t.Helper() + + tx := integrationDB.Begin() + require.NoError(t, tx.Error, "begin tx") + t.Cleanup(func() { + _ = tx.Rollback().Error + }) + return tx +} + +func testRedis(t *testing.T) *redisclient.Client { + t.Helper() + + prefix := fmt.Sprintf( + "it:%s:%d:%d:", + sanitizeRedisNamespace(t.Name()), + time.Now().UnixNano(), + atomic.AddUint64(&redisNamespaceSeq, 1), + ) + + opts := *integrationRedis.Options() + rdb := redisclient.NewClient(&opts) + rdb.AddHook(prefixHook{prefix: prefix}) + + t.Cleanup(func() { + ctx := context.Background() + + var cursor uint64 + for { + keys, nextCursor, err := integrationRedis.Scan(ctx, cursor, prefix+"*", 500).Result() + require.NoError(t, err, "scan redis keys for cleanup") + if len(keys) > 0 { + require.NoError(t, integrationRedis.Unlink(ctx, keys...).Err(), "unlink redis keys for cleanup") + } + + cursor = nextCursor + if cursor == 0 { + break + } + } + + _ = rdb.Close() + }) + + return rdb +} + +func assertTTLWithin(t *testing.T, ttl time.Duration, min, max time.Duration) { + t.Helper() + require.GreaterOrEqual(t, ttl, min, "ttl should be >= min") + require.LessOrEqual(t, ttl, max, "ttl should be <= max") +} + +func sanitizeRedisNamespace(name string) string { + name = strings.ReplaceAll(name, "/", "_") + name = strings.ReplaceAll(name, " ", "_") + return name +} + +type prefixHook struct { + prefix string +} + +func (h prefixHook) DialHook(next redisclient.DialHook) redisclient.DialHook { return next } + +func (h prefixHook) ProcessHook(next redisclient.ProcessHook) redisclient.ProcessHook { + return func(ctx context.Context, cmd redisclient.Cmder) error { + h.prefixCmd(cmd) + return next(ctx, cmd) + } +} + +func (h prefixHook) ProcessPipelineHook(next redisclient.ProcessPipelineHook) redisclient.ProcessPipelineHook { + return func(ctx context.Context, cmds []redisclient.Cmder) error { + for _, cmd := range cmds { + h.prefixCmd(cmd) + } + return next(ctx, cmds) + } +} + +func (h prefixHook) prefixCmd(cmd redisclient.Cmder) { + args := cmd.Args() + if len(args) < 2 { + return + } + + prefixOne := func(i int) { + if i < 0 || i >= len(args) { + return + } + + switch v := args[i].(type) { + case string: + if v != "" && !strings.HasPrefix(v, h.prefix) { + args[i] = h.prefix + v + } + case []byte: + s := string(v) + if s != "" && !strings.HasPrefix(s, h.prefix) { + args[i] = []byte(h.prefix + s) + } + } + } + + switch strings.ToLower(cmd.Name()) { + case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl", + "hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists": + prefixOne(1) + case "del", "unlink": + for i := 1; i < len(args); i++ { + prefixOne(i) + } + case "eval", "evalsha", "eval_ro", "evalsha_ro": + if len(args) < 3 { + return + } + numKeys, err := strconv.Atoi(fmt.Sprint(args[2])) + if err != nil || numKeys <= 0 { + return + } + for i := 0; i < numKeys && 3+i < len(args); i++ { + prefixOne(3 + i) + } + case "scan": + for i := 2; i+1 < len(args); i++ { + if strings.EqualFold(fmt.Sprint(args[i]), "match") { + prefixOne(i + 1) + break + } + } + } +} + +// IntegrationRedisSuite provides a base suite for Redis integration tests. +// Embedding suites should call SetupTest to initialize ctx and rdb. +type IntegrationRedisSuite struct { + suite.Suite + ctx context.Context + rdb *redisclient.Client +} + +// SetupTest initializes ctx and rdb for each test method. +func (s *IntegrationRedisSuite) SetupTest() { + s.ctx = context.Background() + s.rdb = testRedis(s.T()) +} + +// RequireNoError is a convenience method wrapping require.NoError with s.T(). +func (s *IntegrationRedisSuite) RequireNoError(err error, msgAndArgs ...any) { + s.T().Helper() + require.NoError(s.T(), err, msgAndArgs...) +} + +// AssertTTLWithin asserts that ttl is within [min, max]. +func (s *IntegrationRedisSuite) AssertTTLWithin(ttl, min, max time.Duration) { + s.T().Helper() + assertTTLWithin(s.T(), ttl, min, max) +} + +// IntegrationDBSuite provides a base suite for DB (Gorm) integration tests. +// Embedding suites should call SetupTest to initialize ctx and db. +type IntegrationDBSuite struct { + suite.Suite + ctx context.Context + db *gorm.DB +} + +// SetupTest initializes ctx and db for each test method. +func (s *IntegrationDBSuite) SetupTest() { + s.ctx = context.Background() + s.db = testTx(s.T()) +} + +// RequireNoError is a convenience method wrapping require.NoError with s.T(). +func (s *IntegrationDBSuite) RequireNoError(err error, msgAndArgs ...any) { + s.T().Helper() + require.NoError(s.T(), err, msgAndArgs...) +} diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index 63e84af7..846b7b93 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -12,11 +12,13 @@ import ( "github.com/imroc/req/v3" ) -type openaiOAuthService struct{} - // NewOpenAIOAuthClient creates a new OpenAI OAuth client func NewOpenAIOAuthClient() ports.OpenAIOAuthClient { - return &openaiOAuthService{} + return &openaiOAuthService{tokenURL: openai.TokenURL} +} + +type openaiOAuthService struct { + tokenURL string } func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { @@ -39,7 +41,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie SetContext(ctx). SetFormDataFromValues(formData). SetSuccessResult(&tokenResp). - Post(openai.TokenURL) + Post(s.tokenURL) if err != nil { return nil, fmt.Errorf("request failed: %w", err) @@ -67,7 +69,7 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro SetContext(ctx). SetFormDataFromValues(formData). SetSuccessResult(&tokenResp). - Post(openai.TokenURL) + Post(s.tokenURL) if err != nil { return nil, fmt.Errorf("request failed: %w", err) diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go new file mode 100644 index 00000000..0a5322d7 --- /dev/null +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -0,0 +1,249 @@ +package repository + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type OpenAIOAuthServiceSuite struct { + suite.Suite + ctx context.Context + srv *httptest.Server + svc *openaiOAuthService + received chan url.Values +} + +func (s *OpenAIOAuthServiceSuite) SetupTest() { + s.ctx = context.Background() + s.received = make(chan url.Values, 1) +} + +func (s *OpenAIOAuthServiceSuite) TearDownTest() { + if s.srv != nil { + s.srv.Close() + s.srv = nil + } +} + +func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) { + s.srv = httptest.NewServer(handler) + s.svc = &openaiOAuthService{tokenURL: s.srv.URL} +} + +func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() { + errCh := make(chan string, 1) + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + errCh <- "method mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + if err := r.ParseForm(); err != nil { + errCh <- "ParseForm failed" + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.PostForm.Get("grant_type"); got != "authorization_code" { + errCh <- "grant_type mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.PostForm.Get("client_id"); got != openai.ClientID { + errCh <- "client_id mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.PostForm.Get("code"); got != "code" { + errCh <- "code mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.PostForm.Get("redirect_uri"); got != openai.DefaultRedirectURI { + errCh <- "redirect_uri mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.PostForm.Get("code_verifier"); got != "ver" { + errCh <- "code_verifier mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`) + })) + + resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "") + require.NoError(s.T(), err, "ExchangeCode") + select { + case msg := <-errCh: + require.Fail(s.T(), msg) + default: + } + require.Equal(s.T(), "at", resp.AccessToken) + require.Equal(s.T(), "rt", resp.RefreshToken) +} + +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() { + errCh := make(chan string, 1) + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + errCh <- "ParseForm failed" + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.PostForm.Get("grant_type"); got != "refresh_token" { + errCh <- "grant_type mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.PostForm.Get("refresh_token"); got != "rt" { + errCh <- "refresh_token mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.PostForm.Get("client_id"); got != openai.ClientID { + errCh <- "client_id mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + if got := r.PostForm.Get("scope"); got != openai.RefreshScopes { + errCh <- "scope mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at2","refresh_token":"rt2","token_type":"bearer","expires_in":3600}`) + })) + + resp, err := s.svc.RefreshToken(s.ctx, "rt", "") + require.NoError(s.T(), err, "RefreshToken") + select { + case msg := <-errCh: + require.Fail(s.T(), msg) + default: + } + require.Equal(s.T(), "at2", resp.AccessToken) + require.Equal(s.T(), "rt2", resp.RefreshToken) +} + +func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, "bad") + })) + + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "status 400") + require.ErrorContains(s.T(), err, "bad") +} + +func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + s.srv.Close() + + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "request failed") +} + +func (s *OpenAIOAuthServiceSuite) TestContextCancel() { + started := make(chan struct{}) + block := make(chan struct{}) + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(started) + <-block + })) + + ctx, cancel := context.WithCancel(s.ctx) + + done := make(chan error, 1) + go func() { + _, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "") + done <- err + }() + + <-started + cancel() + close(block) + + err := <-done + require.Error(s.T(), err) +} + +func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() { + want := "http://localhost:9999/cb" + errCh := make(chan string, 1) + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + if got := r.PostForm.Get("redirect_uri"); got != want { + errCh <- "redirect_uri mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`) + })) + + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "") + require.NoError(s.T(), err, "ExchangeCode") + select { + case msg := <-errCh: + require.Fail(s.T(), msg) + default: + } +} + +func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + s.received <- r.PostForm + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`) + })) + s.svc.tokenURL = s.srv.URL + "?x=1" + + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + require.NoError(s.T(), err, "ExchangeCode") + select { + case <-s.received: + default: + require.Fail(s.T(), "expected server to receive request") + } +} + +func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "not-valid-json") + })) + + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + require.Error(s.T(), err, "expected error for invalid JSON response") +} + +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = io.WriteString(w, "unauthorized") + })) + + _, err := s.svc.RefreshToken(s.ctx, "rt", "") + require.Error(s.T(), err, "expected error for non-2xx status") + require.ErrorContains(s.T(), err, "status 401") +} + +func TestOpenAIOAuthServiceSuite(t *testing.T) { + suite.Run(t, new(OpenAIOAuthServiceSuite)) +} diff --git a/backend/internal/repository/pricing_service_test.go b/backend/internal/repository/pricing_service_test.go new file mode 100644 index 00000000..8cfc8222 --- /dev/null +++ b/backend/internal/repository/pricing_service_test.go @@ -0,0 +1,147 @@ +package repository + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type PricingServiceSuite struct { + suite.Suite + ctx context.Context + srv *httptest.Server + client *pricingRemoteClient +} + +func (s *PricingServiceSuite) SetupTest() { + s.ctx = context.Background() + client, ok := NewPricingRemoteClient().(*pricingRemoteClient) + require.True(s.T(), ok, "type assertion failed") + s.client = client +} + +func (s *PricingServiceSuite) TearDownTest() { + if s.srv != nil { + s.srv.Close() + s.srv = nil + } +} + +func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) { + s.srv = httptest.NewServer(handler) +} + +func (s *PricingServiceSuite) TestFetchPricingJSON_Success() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/ok" { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + return + } + w.WriteHeader(http.StatusInternalServerError) + })) + + body, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/ok") + require.NoError(s.T(), err, "FetchPricingJSON") + require.Equal(s.T(), `{"ok":true}`, string(body), "body mismatch") +} + +func (s *PricingServiceSuite) TestFetchPricingJSON_NonOKStatus() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + + _, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/err") + require.Error(s.T(), err, "expected error for non-200 status") +} + +func (s *PricingServiceSuite) TestFetchHashText_ParsesFields() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/hashfile": + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("abc123 model_prices.json\n")) + case "/hashonly": + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("def456\n")) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + + hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashfile") + require.NoError(s.T(), err, "FetchHashText") + require.Equal(s.T(), "abc123", hash, "hash mismatch") + + hash2, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashonly") + require.NoError(s.T(), err, "FetchHashText") + require.Equal(s.T(), "def456", hash2, "hash mismatch") +} + +func (s *PricingServiceSuite) TestFetchHashText_NonOKStatus() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + + _, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/nope") + require.Error(s.T(), err, "expected error for non-200 status") +} + +func (s *PricingServiceSuite) TestFetchPricingJSON_InvalidURL() { + _, err := s.client.FetchPricingJSON(s.ctx, "://invalid-url") + require.Error(s.T(), err, "expected error for invalid URL") +} + +func (s *PricingServiceSuite) TestFetchHashText_EmptyBody() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + // empty body + })) + + hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/empty") + require.NoError(s.T(), err, "FetchHashText empty body should not error") + require.Equal(s.T(), "", hash, "expected empty hash") +} + +func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(" \n")) + })) + + hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/ws") + require.NoError(s.T(), err, "FetchHashText whitespace body should not error") + require.Equal(s.T(), "", hash, "expected empty hash after trimming") +} + +func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() { + started := make(chan struct{}) + block := make(chan struct{}) + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(started) + <-block + })) + + ctx, cancel := context.WithCancel(s.ctx) + + done := make(chan error, 1) + go func() { + _, err := s.client.FetchPricingJSON(ctx, s.srv.URL+"/block") + done <- err + }() + + <-started + cancel() + close(block) + + err := <-done + require.Error(s.T(), err) +} + +func TestPricingServiceSuite(t *testing.T) { + suite.Run(t, new(PricingServiceSuite)) +} diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 0e3a0934..9331859c 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -16,10 +16,14 @@ import ( "golang.org/x/net/proxy" ) -type proxyProbeService struct{} - func NewProxyExitInfoProber() service.ProxyExitInfoProber { - return &proxyProbeService{} + return &proxyProbeService{ipInfoURL: defaultIPInfoURL} +} + +const defaultIPInfoURL = "https://ipinfo.io/json" + +type proxyProbeService struct { + ipInfoURL string } func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { @@ -34,7 +38,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s } startTime := time.Now() - req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil) + req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil) if err != nil { return nil, 0, fmt.Errorf("failed to create request: %w", err) } diff --git a/backend/internal/repository/proxy_probe_service_test.go b/backend/internal/repository/proxy_probe_service_test.go new file mode 100644 index 00000000..25ab0f9c --- /dev/null +++ b/backend/internal/repository/proxy_probe_service_test.go @@ -0,0 +1,121 @@ +package repository + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ProxyProbeServiceSuite struct { + suite.Suite + ctx context.Context + proxySrv *httptest.Server + prober *proxyProbeService +} + +func (s *ProxyProbeServiceSuite) SetupTest() { + s.ctx = context.Background() + s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"} +} + +func (s *ProxyProbeServiceSuite) TearDownTest() { + if s.proxySrv != nil { + s.proxySrv.Close() + s.proxySrv = nil + } +} + +func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) { + s.proxySrv = httptest.NewServer(handler) +} + +func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() { + _, err := createProxyTransport("://bad") + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "invalid proxy URL") +} + +func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() { + _, err := createProxyTransport("ftp://127.0.0.1:1") + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "unsupported proxy protocol") +} + +func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() { + tr, err := createProxyTransport("socks5://127.0.0.1:1080") + require.NoError(s.T(), err, "createProxyTransport") + require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5") +} + +func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() { + seen := make(chan string, 1) + s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seen <- r.RequestURI + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`) + })) + + info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) + require.NoError(s.T(), err, "ProbeProxy") + require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency") + require.Equal(s.T(), "1.2.3.4", info.IP) + require.Equal(s.T(), "c", info.City) + require.Equal(s.T(), "r", info.Region) + require.Equal(s.T(), "cc", info.Country) + + // Verify proxy received the request + select { + case uri := <-seen: + require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy") + default: + require.Fail(s.T(), "expected proxy to receive request") + } +} + +func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() { + s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + + _, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "status: 503") +} + +func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() { + s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, "not-json") + })) + + _, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "failed to parse response") +} + +func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() { + s.prober.ipInfoURL = "://invalid-url" + s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + _, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) + require.Error(s.T(), err, "expected error for invalid ipInfoURL") +} + +func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() { + s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + s.proxySrv.Close() + + _, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) + require.Error(s.T(), err, "expected error when proxy server is closed") +} + +func TestProxyProbeServiceSuite(t *testing.T) { + suite.Run(t, new(ProxyProbeServiceSuite)) +} diff --git a/backend/internal/repository/proxy_repo_integration_test.go b/backend/internal/repository/proxy_repo_integration_test.go new file mode 100644 index 00000000..67c1825f --- /dev/null +++ b/backend/internal/repository/proxy_repo_integration_test.go @@ -0,0 +1,302 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +type ProxyRepoSuite struct { + suite.Suite + ctx context.Context + db *gorm.DB + repo *ProxyRepository +} + +func (s *ProxyRepoSuite) SetupTest() { + s.ctx = context.Background() + s.db = testTx(s.T()) + s.repo = NewProxyRepository(s.db) +} + +func TestProxyRepoSuite(t *testing.T) { + suite.Run(t, new(ProxyRepoSuite)) +} + +// --- Create / GetByID / Update / Delete --- + +func (s *ProxyRepoSuite) TestCreate() { + proxy := &model.Proxy{ + Name: "test-create", + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + Status: model.StatusActive, + } + + err := s.repo.Create(s.ctx, proxy) + s.Require().NoError(err, "Create") + s.Require().NotZero(proxy.ID, "expected ID to be set") + + got, err := s.repo.GetByID(s.ctx, proxy.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal("test-create", got.Name) +} + +func (s *ProxyRepoSuite) TestGetByID_NotFound() { + _, err := s.repo.GetByID(s.ctx, 999999) + s.Require().Error(err, "expected error for non-existent ID") +} + +func (s *ProxyRepoSuite) TestUpdate() { + proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"}) + + proxy.Name = "updated" + err := s.repo.Update(s.ctx, proxy) + s.Require().NoError(err, "Update") + + got, err := s.repo.GetByID(s.ctx, proxy.ID) + s.Require().NoError(err, "GetByID after update") + s.Require().Equal("updated", got.Name) +} + +func (s *ProxyRepoSuite) TestDelete() { + proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"}) + + err := s.repo.Delete(s.ctx, proxy.ID) + s.Require().NoError(err, "Delete") + + _, err = s.repo.GetByID(s.ctx, proxy.ID) + s.Require().Error(err, "expected error after delete") +} + +// --- List / ListWithFilters --- + +func (s *ProxyRepoSuite) TestList() { + mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) + mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"}) + + proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "List") + s.Require().Len(proxies, 2) + s.Require().Equal(int64(2), page.Total) +} + +func (s *ProxyRepoSuite) TestListWithFilters_Protocol() { + mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"}) + mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"}) + + proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "") + s.Require().NoError(err) + s.Require().Len(proxies, 1) + s.Require().Equal("socks5", proxies[0].Protocol) +} + +func (s *ProxyRepoSuite) TestListWithFilters_Status() { + mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive}) + mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled}) + + proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "") + s.Require().NoError(err) + s.Require().Len(proxies, 1) + s.Require().Equal(model.StatusDisabled, proxies[0].Status) +} + +func (s *ProxyRepoSuite) TestListWithFilters_Search() { + mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"}) + mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"}) + + proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod") + s.Require().NoError(err) + s.Require().Len(proxies, 1) + s.Require().Contains(proxies[0].Name, "production") +} + +// --- ListActive --- + +func (s *ProxyRepoSuite) TestListActive() { + mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive}) + mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled}) + + proxies, err := s.repo.ListActive(s.ctx) + s.Require().NoError(err, "ListActive") + s.Require().Len(proxies, 1) + s.Require().Equal("active1", proxies[0].Name) +} + +// --- ExistsByHostPortAuth --- + +func (s *ProxyRepoSuite) TestExistsByHostPortAuth() { + mustCreateProxy(s.T(), s.db, &model.Proxy{ + Name: "p1", + Protocol: "http", + Host: "1.2.3.4", + Port: 8080, + Username: "user", + Password: "pass", + }) + + exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "user", "pass") + s.Require().NoError(err, "ExistsByHostPortAuth") + s.Require().True(exists) + + notExists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "wrong", "creds") + s.Require().NoError(err) + s.Require().False(notExists) +} + +func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() { + mustCreateProxy(s.T(), s.db, &model.Proxy{ + Name: "p-noauth", + Protocol: "http", + Host: "5.6.7.8", + Port: 8081, + Username: "", + Password: "", + }) + + exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "5.6.7.8", 8081, "", "") + s.Require().NoError(err) + s.Require().True(exists) +} + +// --- CountAccountsByProxyID --- + +func (s *ProxyRepoSuite) TestCountAccountsByProxyID() { + proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy + + count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID) + s.Require().NoError(err, "CountAccountsByProxyID") + s.Require().Equal(int64(2), count) +} + +func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() { + proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"}) + + count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID) + s.Require().NoError(err) + s.Require().Zero(count) +} + +// --- GetAccountCountsForProxies --- + +func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() { + p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) + p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"}) + + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) + + counts, err := s.repo.GetAccountCountsForProxies(s.ctx) + s.Require().NoError(err, "GetAccountCountsForProxies") + s.Require().Equal(int64(2), counts[p1.ID]) + s.Require().Equal(int64(1), counts[p2.ID]) +} + +func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() { + counts, err := s.repo.GetAccountCountsForProxies(s.ctx) + s.Require().NoError(err) + s.Require().Empty(counts) +} + +// --- ListActiveWithAccountCount --- + +func (s *ProxyRepoSuite) TestListActiveWithAccountCount() { + base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + + p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{ + Name: "p1", + Status: model.StatusActive, + CreatedAt: base.Add(-1 * time.Hour), + }) + p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{ + Name: "p2", + Status: model.StatusActive, + CreatedAt: base, + }) + mustCreateProxy(s.T(), s.db, &model.Proxy{ + Name: "p3-inactive", + Status: model.StatusDisabled, + }) + + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) + + withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx) + s.Require().NoError(err, "ListActiveWithAccountCount") + s.Require().Len(withCounts, 2, "expected 2 active proxies") + + // Sorted by created_at DESC, so p2 first + s.Require().Equal(p2.ID, withCounts[0].ID) + s.Require().Equal(int64(1), withCounts[0].AccountCount) + s.Require().Equal(p1.ID, withCounts[1].ID) + s.Require().Equal(int64(2), withCounts[1].AccountCount) +} + +// --- Combined original test --- + +func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() { + p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{ + Name: "p1", + Protocol: "http", + Host: "1.2.3.4", + Port: 8080, + Username: "u", + Password: "p", + CreatedAt: time.Now().Add(-1 * time.Hour), + UpdatedAt: time.Now().Add(-1 * time.Hour), + }) + p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{ + Name: "p2", + Protocol: "http", + Host: "5.6.7.8", + Port: 8081, + Username: "", + Password: "", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }) + + exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "u", "p") + s.Require().NoError(err, "ExistsByHostPortAuth") + s.Require().True(exists, "expected proxy to exist") + + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) + + count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID) + s.Require().NoError(err, "CountAccountsByProxyID") + s.Require().Equal(int64(2), count1, "expected 2 accounts for p1") + + counts, err := s.repo.GetAccountCountsForProxies(s.ctx) + s.Require().NoError(err, "GetAccountCountsForProxies") + s.Require().Equal(int64(2), counts[p1.ID]) + s.Require().Equal(int64(1), counts[p2.ID]) + + withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx) + s.Require().NoError(err, "ListActiveWithAccountCount") + s.Require().Len(withCounts, 2, "expected 2 proxies") + for _, pc := range withCounts { + switch pc.ID { + case p1.ID: + s.Require().Equal(int64(2), pc.AccountCount, "p1 count mismatch") + case p2.ID: + s.Require().Equal(int64(1), pc.AccountCount, "p2 count mismatch") + default: + s.Require().Fail("unexpected proxy id", pc.ID) + } + } +} diff --git a/backend/internal/repository/redeem_cache_integration_test.go b/backend/internal/repository/redeem_cache_integration_test.go new file mode 100644 index 00000000..a7aa05d9 --- /dev/null +++ b/backend/internal/repository/redeem_cache_integration_test.go @@ -0,0 +1,105 @@ +//go:build integration + +package repository + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type RedeemCacheSuite struct { + IntegrationRedisSuite + cache *redeemCache +} + +func (s *RedeemCacheSuite) SetupTest() { + s.IntegrationRedisSuite.SetupTest() + s.cache = NewRedeemCache(s.rdb).(*redeemCache) +} + +func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() { + missingUserID := int64(99999) + _, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID) + require.Error(s.T(), err, "expected redis.Nil for missing rate-limit key") + require.True(s.T(), errors.Is(err, redis.Nil)) +} + +func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() { + userID := int64(1) + key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) + + require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID), "IncrementRedeemAttemptCount") + count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID) + require.NoError(s.T(), err, "GetRedeemAttemptCount") + require.Equal(s.T(), 1, count, "count mismatch") + + ttl, err := s.rdb.TTL(s.ctx, key).Result() + require.NoError(s.T(), err, "TTL") + s.AssertTTLWithin(ttl, 1*time.Second, redeemRateLimitDuration) +} + +func (s *RedeemCacheSuite) TestMultipleIncrements() { + userID := int64(2) + + require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID)) + require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID)) + require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID)) + + count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID) + require.NoError(s.T(), err) + require.Equal(s.T(), 3, count, "count after 3 increments") +} + +func (s *RedeemCacheSuite) TestAcquireAndReleaseRedeemLock() { + ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second) + require.NoError(s.T(), err, "AcquireRedeemLock") + require.True(s.T(), ok) + + // Second acquire should fail + ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second) + require.NoError(s.T(), err, "AcquireRedeemLock 2") + require.False(s.T(), ok, "expected lock to be held") + + // Release + require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "CODE"), "ReleaseRedeemLock") + + // Now acquire should succeed + ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second) + require.NoError(s.T(), err, "AcquireRedeemLock after release") + require.True(s.T(), ok) +} + +func (s *RedeemCacheSuite) TestAcquireRedeemLock_TTL() { + lockKey := redeemLockKeyPrefix + "CODE2" + lockTTL := 15 * time.Second + + ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE2", lockTTL) + require.NoError(s.T(), err, "AcquireRedeemLock CODE2") + require.True(s.T(), ok) + + ttl, err := s.rdb.TTL(s.ctx, lockKey).Result() + require.NoError(s.T(), err, "TTL lock key") + s.AssertTTLWithin(ttl, 1*time.Second, lockTTL) +} + +func (s *RedeemCacheSuite) TestReleaseRedeemLock_Idempotent() { + // Release a lock that doesn't exist should not error + require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "NONEXISTENT")) + + // Acquire, release, release again + ok, err := s.cache.AcquireRedeemLock(s.ctx, "IDEMPOTENT", 10*time.Second) + require.NoError(s.T(), err) + require.True(s.T(), ok) + require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT")) + require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"), "second release should be idempotent") +} + +func TestRedeemCacheSuite(t *testing.T) { + suite.Run(t, new(RedeemCacheSuite)) +} diff --git a/backend/internal/repository/redeem_code_repo_integration_test.go b/backend/internal/repository/redeem_code_repo_integration_test.go new file mode 100644 index 00000000..f39d6a51 --- /dev/null +++ b/backend/internal/repository/redeem_code_repo_integration_test.go @@ -0,0 +1,315 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +type RedeemCodeRepoSuite struct { + suite.Suite + ctx context.Context + db *gorm.DB + repo *RedeemCodeRepository +} + +func (s *RedeemCodeRepoSuite) SetupTest() { + s.ctx = context.Background() + s.db = testTx(s.T()) + s.repo = NewRedeemCodeRepository(s.db) +} + +func TestRedeemCodeRepoSuite(t *testing.T) { + suite.Run(t, new(RedeemCodeRepoSuite)) +} + +// --- Create / CreateBatch / GetByID / GetByCode --- + +func (s *RedeemCodeRepoSuite) TestCreate() { + code := &model.RedeemCode{ + Code: "TEST-CREATE", + Type: model.RedeemTypeBalance, + Value: 100, + Status: model.StatusUnused, + } + + err := s.repo.Create(s.ctx, code) + s.Require().NoError(err, "Create") + s.Require().NotZero(code.ID, "expected ID to be set") + + got, err := s.repo.GetByID(s.ctx, code.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal("TEST-CREATE", got.Code) +} + +func (s *RedeemCodeRepoSuite) TestCreateBatch() { + codes := []model.RedeemCode{ + {Code: "BATCH-1", Type: model.RedeemTypeBalance, Value: 10, Status: model.StatusUnused}, + {Code: "BATCH-2", Type: model.RedeemTypeBalance, Value: 20, Status: model.StatusUnused}, + } + + err := s.repo.CreateBatch(s.ctx, codes) + s.Require().NoError(err, "CreateBatch") + + got1, err := s.repo.GetByCode(s.ctx, "BATCH-1") + s.Require().NoError(err) + s.Require().Equal(float64(10), got1.Value) + + got2, err := s.repo.GetByCode(s.ctx, "BATCH-2") + s.Require().NoError(err) + s.Require().Equal(float64(20), got2.Value) +} + +func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() { + _, err := s.repo.GetByID(s.ctx, 999999) + s.Require().Error(err, "expected error for non-existent ID") +} + +func (s *RedeemCodeRepoSuite) TestGetByCode() { + mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "GET-BY-CODE", Type: model.RedeemTypeBalance}) + + got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE") + s.Require().NoError(err, "GetByCode") + s.Require().Equal("GET-BY-CODE", got.Code) +} + +func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() { + _, err := s.repo.GetByCode(s.ctx, "NON-EXISTENT") + s.Require().Error(err, "expected error for non-existent code") +} + +// --- Delete --- + +func (s *RedeemCodeRepoSuite) TestDelete() { + code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TO-DELETE", Type: model.RedeemTypeBalance}) + + err := s.repo.Delete(s.ctx, code.ID) + s.Require().NoError(err, "Delete") + + _, err = s.repo.GetByID(s.ctx, code.ID) + s.Require().Error(err, "expected error after delete") +} + +// --- List / ListWithFilters --- + +func (s *RedeemCodeRepoSuite) TestList() { + mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-1", Type: model.RedeemTypeBalance}) + mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-2", Type: model.RedeemTypeBalance}) + + codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "List") + s.Require().Len(codes, 2) + s.Require().Equal(int64(2), page.Total) +} + +func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() { + mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-BAL", Type: model.RedeemTypeBalance}) + mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-SUB", Type: model.RedeemTypeSubscription}) + + codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, "", "") + s.Require().NoError(err) + s.Require().Len(codes, 1) + s.Require().Equal(model.RedeemTypeSubscription, codes[0].Type) +} + +func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() { + mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-UNUSED", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) + mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed}) + + codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusUsed, "") + s.Require().NoError(err) + s.Require().Len(codes, 1) + s.Require().Equal(model.StatusUsed, codes[0].Status) +} + +func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() { + mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALPHA-CODE", Type: model.RedeemTypeBalance}) + mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "BETA-CODE", Type: model.RedeemTypeBalance}) + + codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha") + s.Require().NoError(err) + s.Require().Len(codes, 1) + s.Require().Contains(codes[0].Code, "ALPHA") +} + +func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() { + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"}) + mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ + Code: "WITH-GROUP", + Type: model.RedeemTypeSubscription, + GroupID: &group.ID, + }) + + codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "") + s.Require().NoError(err) + s.Require().Len(codes, 1) + s.Require().NotNil(codes[0].Group, "expected Group preload") + s.Require().Equal(group.ID, codes[0].Group.ID) +} + +// --- Update --- + +func (s *RedeemCodeRepoSuite) TestUpdate() { + code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "UPDATE-ME", Type: model.RedeemTypeBalance, Value: 10}) + + code.Value = 50 + err := s.repo.Update(s.ctx, code) + s.Require().NoError(err, "Update") + + got, err := s.repo.GetByID(s.ctx, code.ID) + s.Require().NoError(err) + s.Require().Equal(float64(50), got.Value) +} + +// --- Use --- + +func (s *RedeemCodeRepoSuite) TestUse() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "use@test.com"}) + code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "USE-ME", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) + + err := s.repo.Use(s.ctx, code.ID, user.ID) + s.Require().NoError(err, "Use") + + got, err := s.repo.GetByID(s.ctx, code.ID) + s.Require().NoError(err) + s.Require().Equal(model.StatusUsed, got.Status) + s.Require().NotNil(got.UsedBy) + s.Require().Equal(user.ID, *got.UsedBy) + s.Require().NotNil(got.UsedAt) +} + +func (s *RedeemCodeRepoSuite) TestUse_Idempotency() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "idem@test.com"}) + code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "IDEM-CODE", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) + + err := s.repo.Use(s.ctx, code.ID, user.ID) + s.Require().NoError(err, "Use first time") + + // Second use should fail + err = s.repo.Use(s.ctx, code.ID, user.ID) + s.Require().Error(err, "Use expected error on second call") + s.Require().ErrorIs(err, gorm.ErrRecordNotFound) +} + +func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "already@test.com"}) + code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALREADY-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed}) + + err := s.repo.Use(s.ctx, code.ID, user.ID) + s.Require().Error(err, "expected error for already used code") + s.Require().ErrorIs(err, gorm.ErrRecordNotFound) +} + +// --- ListByUser --- + +func (s *RedeemCodeRepoSuite) TestListByUser() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"}) + base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + + // Create codes with explicit used_at for ordering + c1 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ + Code: "USER-1", + Type: model.RedeemTypeBalance, + Status: model.StatusUsed, + UsedBy: &user.ID, + }) + s.db.Model(c1).Update("used_at", base) + + c2 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ + Code: "USER-2", + Type: model.RedeemTypeBalance, + Status: model.StatusUsed, + UsedBy: &user.ID, + }) + s.db.Model(c2).Update("used_at", base.Add(1*time.Hour)) + + codes, err := s.repo.ListByUser(s.ctx, user.ID, 10) + s.Require().NoError(err, "ListByUser") + s.Require().Len(codes, 2) + // Ordered by used_at DESC, so USER-2 first + s.Require().Equal("USER-2", codes[0].Code) + s.Require().Equal("USER-1", codes[1].Code) +} + +func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "grp@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listby"}) + + c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ + Code: "WITH-GRP", + Type: model.RedeemTypeSubscription, + Status: model.StatusUsed, + UsedBy: &user.ID, + GroupID: &group.ID, + }) + s.db.Model(c).Update("used_at", time.Now()) + + codes, err := s.repo.ListByUser(s.ctx, user.ID, 10) + s.Require().NoError(err) + s.Require().Len(codes, 1) + s.Require().NotNil(codes[0].Group) + s.Require().Equal(group.ID, codes[0].Group.ID) +} + +func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "deflimit@test.com"}) + c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ + Code: "DEF-LIM", + Type: model.RedeemTypeBalance, + Status: model.StatusUsed, + UsedBy: &user.ID, + }) + s.db.Model(c).Update("used_at", time.Now()) + + // limit <= 0 should default to 10 + codes, err := s.repo.ListByUser(s.ctx, user.ID, 0) + s.Require().NoError(err) + s.Require().Len(codes, 1) +} + +// --- Combined original test --- + +func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "rc@example.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-rc"}) + + codes := []model.RedeemCode{ + {Code: "CODEA", Type: model.RedeemTypeBalance, Value: 1, Status: model.StatusUnused, CreatedAt: time.Now()}, + {Code: "CODEB", Type: model.RedeemTypeSubscription, Value: 0, Status: model.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()}, + } + s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch") + + list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, model.StatusUnused, "code") + s.Require().NoError(err, "ListWithFilters") + s.Require().Equal(int64(1), page.Total) + s.Require().Len(list, 1) + s.Require().NotNil(list[0].Group, "expected Group preload") + s.Require().Equal(group.ID, list[0].Group.ID) + + codeB, err := s.repo.GetByCode(s.ctx, "CODEB") + s.Require().NoError(err, "GetByCode") + s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use") + err = s.repo.Use(s.ctx, codeB.ID, user.ID) + s.Require().Error(err, "Use expected error on second call") + s.Require().ErrorIs(err, gorm.ErrRecordNotFound) + + codeA, err := s.repo.GetByCode(s.ctx, "CODEA") + s.Require().NoError(err, "GetByCode") + + // Use fixed time instead of time.Sleep for deterministic ordering + s.db.Model(&model.RedeemCode{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)) + s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA") + s.db.Model(&model.RedeemCode{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC)) + + used, err := s.repo.ListByUser(s.ctx, user.ID, 10) + s.Require().NoError(err, "ListByUser") + s.Require().Len(used, 2, "expected 2 used codes") + s.Require().Equal("CODEA", used[0].Code, "expected newest used code first") +} diff --git a/backend/internal/repository/setting_repo_integration_test.go b/backend/internal/repository/setting_repo_integration_test.go new file mode 100644 index 00000000..b42cacd7 --- /dev/null +++ b/backend/internal/repository/setting_repo_integration_test.go @@ -0,0 +1,108 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +type SettingRepoSuite struct { + suite.Suite + ctx context.Context + db *gorm.DB + repo *SettingRepository +} + +func (s *SettingRepoSuite) SetupTest() { + s.ctx = context.Background() + s.db = testTx(s.T()) + s.repo = NewSettingRepository(s.db) +} + +func TestSettingRepoSuite(t *testing.T) { + suite.Run(t, new(SettingRepoSuite)) +} + +func (s *SettingRepoSuite) TestSetAndGetValue() { + s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set") + got, err := s.repo.GetValue(s.ctx, "k1") + s.Require().NoError(err, "GetValue") + s.Require().Equal("v1", got, "GetValue mismatch") +} + +func (s *SettingRepoSuite) TestSet_Upsert() { + s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set") + s.Require().NoError(s.repo.Set(s.ctx, "k1", "v2"), "Set upsert") + got, err := s.repo.GetValue(s.ctx, "k1") + s.Require().NoError(err, "GetValue after upsert") + s.Require().Equal("v2", got, "upsert mismatch") +} + +func (s *SettingRepoSuite) TestGetValue_Missing() { + _, err := s.repo.GetValue(s.ctx, "nonexistent") + s.Require().Error(err, "expected error for missing key") + s.Require().ErrorIs(err, gorm.ErrRecordNotFound) +} + +func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() { + s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"k2": "v2", "k3": "v3"}), "SetMultiple") + m, err := s.repo.GetMultiple(s.ctx, []string{"k2", "k3"}) + s.Require().NoError(err, "GetMultiple") + s.Require().Equal("v2", m["k2"]) + s.Require().Equal("v3", m["k3"]) +} + +func (s *SettingRepoSuite) TestGetMultiple_EmptyKeys() { + m, err := s.repo.GetMultiple(s.ctx, []string{}) + s.Require().NoError(err, "GetMultiple with empty keys") + s.Require().Empty(m, "expected empty map") +} + +func (s *SettingRepoSuite) TestGetMultiple_Subset() { + s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"a": "1", "b": "2", "c": "3"})) + m, err := s.repo.GetMultiple(s.ctx, []string{"a", "c", "nonexistent"}) + s.Require().NoError(err, "GetMultiple subset") + s.Require().Equal("1", m["a"]) + s.Require().Equal("3", m["c"]) + _, exists := m["nonexistent"] + s.Require().False(exists, "nonexistent key should not be in map") +} + +func (s *SettingRepoSuite) TestGetAll() { + s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"x": "1", "y": "2"})) + all, err := s.repo.GetAll(s.ctx) + s.Require().NoError(err, "GetAll") + s.Require().GreaterOrEqual(len(all), 2, "expected at least 2 settings") + s.Require().Equal("1", all["x"]) + s.Require().Equal("2", all["y"]) +} + +func (s *SettingRepoSuite) TestDelete() { + s.Require().NoError(s.repo.Set(s.ctx, "todelete", "val")) + s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete") + _, err := s.repo.GetValue(s.ctx, "todelete") + s.Require().Error(err, "expected missing key error after Delete") + s.Require().ErrorIs(err, gorm.ErrRecordNotFound) +} + +func (s *SettingRepoSuite) TestDelete_Idempotent() { + // Delete a key that doesn't exist should not error + s.Require().NoError(s.repo.Delete(s.ctx, "nonexistent_delete"), "Delete nonexistent should be idempotent") +} + +func (s *SettingRepoSuite) TestSetMultiple_Upsert() { + s.Require().NoError(s.repo.Set(s.ctx, "upsert_key", "old_value")) + s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"upsert_key": "new_value", "new_key": "new_val"})) + + got, err := s.repo.GetValue(s.ctx, "upsert_key") + s.Require().NoError(err) + s.Require().Equal("new_value", got, "SetMultiple should upsert existing key") + + got2, err := s.repo.GetValue(s.ctx, "new_key") + s.Require().NoError(err) + s.Require().Equal("new_val", got2) +} diff --git a/backend/internal/repository/turnstile_service.go b/backend/internal/repository/turnstile_service.go index 77f33bba..c3755011 100644 --- a/backend/internal/repository/turnstile_service.go +++ b/backend/internal/repository/turnstile_service.go @@ -16,6 +16,7 @@ const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/sitev type turnstileVerifier struct { httpClient *http.Client + verifyURL string } func NewTurnstileVerifier() service.TurnstileVerifier { @@ -23,6 +24,7 @@ func NewTurnstileVerifier() service.TurnstileVerifier { httpClient: &http.Client{ Timeout: 10 * time.Second, }, + verifyURL: turnstileVerifyURL, } } @@ -34,7 +36,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r formData.Set("remoteip", remoteIP) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode())) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, v.verifyURL, strings.NewReader(formData.Encode())) if err != nil { return nil, fmt.Errorf("create request: %w", err) } diff --git a/backend/internal/repository/turnstile_service_test.go b/backend/internal/repository/turnstile_service_test.go new file mode 100644 index 00000000..3876a007 --- /dev/null +++ b/backend/internal/repository/turnstile_service_test.go @@ -0,0 +1,143 @@ +package repository + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type TurnstileServiceSuite struct { + suite.Suite + ctx context.Context + srv *httptest.Server + verifier *turnstileVerifier + received chan url.Values +} + +func (s *TurnstileServiceSuite) SetupTest() { + s.ctx = context.Background() + s.received = make(chan url.Values, 1) + verifier, ok := NewTurnstileVerifier().(*turnstileVerifier) + require.True(s.T(), ok, "type assertion failed") + s.verifier = verifier +} + +func (s *TurnstileServiceSuite) TearDownTest() { + if s.srv != nil { + s.srv.Close() + s.srv = nil + } +} + +func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) { + s.srv = httptest.NewServer(handler) + s.verifier.verifyURL = s.srv.URL +} + +func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Capture form data in main goroutine context later + body, _ := io.ReadAll(r.Body) + values, _ := url.ParseQuery(string(body)) + s.received <- values + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true}) + })) + + resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1") + require.NoError(s.T(), err, "VerifyToken") + require.NotNil(s.T(), resp) + require.True(s.T(), resp.Success, "expected success response") + + // Assert form fields in main goroutine + select { + case values := <-s.received: + require.Equal(s.T(), "sk", values.Get("secret")) + require.Equal(s.T(), "token", values.Get("response")) + require.Equal(s.T(), "1.1.1.1", values.Get("remoteip")) + default: + require.Fail(s.T(), "expected server to receive request") + } +} + +func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() { + var contentType string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + contentType = r.Header.Get("Content-Type") + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true}) + })) + + _, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1") + require.NoError(s.T(), err) + require.True(s.T(), strings.HasPrefix(contentType, "application/x-www-form-urlencoded"), "unexpected content-type: %s", contentType) +} + +func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + values, _ := url.ParseQuery(string(body)) + s.received <- values + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true}) + })) + + _, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "") + require.NoError(s.T(), err) + + select { + case values := <-s.received: + require.Equal(s.T(), "", values.Get("remoteip"), "remoteip should be empty or not sent") + default: + require.Fail(s.T(), "expected server to receive request") + } +} + +func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + s.srv.Close() + + _, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1") + require.Error(s.T(), err, "expected error when server is closed") +} + +func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, "not-valid-json") + })) + + _, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1") + require.Error(s.T(), err, "expected error for invalid JSON response") +} + +func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{ + Success: false, + ErrorCodes: []string{"invalid-input-response"}, + }) + })) + + resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1") + require.NoError(s.T(), err, "VerifyToken should not error on success=false") + require.NotNil(s.T(), resp) + require.False(s.T(), resp.Success) + require.Contains(s.T(), resp.ErrorCodes, "invalid-input-response") +} + +func TestTurnstileServiceSuite(t *testing.T) { + suite.Run(t, new(TurnstileServiceSuite)) +} diff --git a/backend/internal/repository/update_cache_integration_test.go b/backend/internal/repository/update_cache_integration_test.go new file mode 100644 index 00000000..792f1b17 --- /dev/null +++ b/backend/internal/repository/update_cache_integration_test.go @@ -0,0 +1,73 @@ +//go:build integration + +package repository + +import ( + "errors" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type UpdateCacheSuite struct { + IntegrationRedisSuite + cache *updateCache +} + +func (s *UpdateCacheSuite) SetupTest() { + s.IntegrationRedisSuite.SetupTest() + s.cache = NewUpdateCache(s.rdb).(*updateCache) +} + +func (s *UpdateCacheSuite) TestGetUpdateInfo_Missing() { + _, err := s.cache.GetUpdateInfo(s.ctx) + require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing update info") +} + +func (s *UpdateCacheSuite) TestSetAndGetUpdateInfo() { + updateTTL := 5 * time.Minute + require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL), "SetUpdateInfo") + + info, err := s.cache.GetUpdateInfo(s.ctx) + require.NoError(s.T(), err, "GetUpdateInfo") + require.Equal(s.T(), "v1.2.3", info, "update info mismatch") +} + +func (s *UpdateCacheSuite) TestSetUpdateInfo_TTL() { + updateTTL := 5 * time.Minute + require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL)) + + ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result() + require.NoError(s.T(), err, "TTL updateCacheKey") + s.AssertTTLWithin(ttl, 1*time.Second, updateTTL) +} + +func (s *UpdateCacheSuite) TestSetUpdateInfo_Overwrite() { + require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.0.0", 5*time.Minute)) + require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v2.0.0", 5*time.Minute)) + + info, err := s.cache.GetUpdateInfo(s.ctx) + require.NoError(s.T(), err) + require.Equal(s.T(), "v2.0.0", info, "expected overwritten value") +} + +func (s *UpdateCacheSuite) TestSetUpdateInfo_ZeroTTL() { + // TTL=0 means persist forever (no expiry) in Redis SET command + require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v0.0.0", 0)) + + info, err := s.cache.GetUpdateInfo(s.ctx) + require.NoError(s.T(), err) + require.Equal(s.T(), "v0.0.0", info) + + ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result() + require.NoError(s.T(), err) + // TTL=-1 means no expiry, TTL=-2 means key doesn't exist + require.Equal(s.T(), time.Duration(-1), ttl, "expected TTL=-1 for key with no expiry") +} + +func TestUpdateCacheSuite(t *testing.T) { + suite.Run(t, new(UpdateCacheSuite)) +} diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go new file mode 100644 index 00000000..76265d31 --- /dev/null +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -0,0 +1,890 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +type UsageLogRepoSuite struct { + suite.Suite + ctx context.Context + db *gorm.DB + repo *UsageLogRepository +} + +func (s *UsageLogRepoSuite) SetupTest() { + s.ctx = context.Background() + s.db = testTx(s.T()) + s.repo = NewUsageLogRepository(s.db) +} + +func TestUsageLogRepoSuite(t *testing.T) { + suite.Run(t, new(UsageLogRepoSuite)) +} + +func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKey, account *model.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *model.UsageLog { + log := &model.UsageLog{ + UserID: user.ID, + ApiKeyID: apiKey.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: inputTokens, + OutputTokens: outputTokens, + TotalCost: cost, + ActualCost: cost, + CreatedAt: createdAt, + } + s.Require().NoError(s.repo.Create(s.ctx, log)) + return log +} + +// --- Create / GetByID --- + +func (s *UsageLogRepoSuite) TestCreate() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-create"}) + + log := &model.UsageLog{ + UserID: user.ID, + ApiKeyID: apiKey.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.4, + } + + err := s.repo.Create(s.ctx, log) + s.Require().NoError(err, "Create") + s.Require().NotZero(log.ID) +} + +func (s *UsageLogRepoSuite) TestGetByID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbyid@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-getbyid"}) + + log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) + + got, err := s.repo.GetByID(s.ctx, log.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal(log.ID, got.ID) + s.Require().Equal(10, got.InputTokens) +} + +func (s *UsageLogRepoSuite) TestGetByID_NotFound() { + _, err := s.repo.GetByID(s.ctx, 999999) + s.Require().Error(err, "expected error for non-existent ID") +} + +// --- Delete --- + +func (s *UsageLogRepoSuite) TestDelete() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-delete"}) + + log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) + + err := s.repo.Delete(s.ctx, log.ID) + s.Require().NoError(err, "Delete") + + _, err = s.repo.GetByID(s.ctx, log.ID) + s.Require().Error(err, "expected error after delete") +} + +// --- ListByUser --- + +func (s *UsageLogRepoSuite) TestListByUser() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyuser"}) + + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now()) + + logs, page, err := s.repo.ListByUser(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "ListByUser") + s.Require().Len(logs, 2) + s.Require().Equal(int64(2), page.Total) +} + +// --- ListByApiKey --- + +func (s *UsageLogRepoSuite) TestListByApiKey() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyapikey@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyapikey"}) + + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now()) + + logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "ListByApiKey") + s.Require().Len(logs, 2) + s.Require().Equal(int64(2), page.Total) +} + +// --- ListByAccount --- + +func (s *UsageLogRepoSuite) TestListByAccount() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyaccount@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyaccount"}) + + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) + + logs, page, err := s.repo.ListByAccount(s.ctx, account.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "ListByAccount") + s.Require().Len(logs, 1) + s.Require().Equal(int64(1), page.Total) +} + +// --- GetUserStats --- + +func (s *UsageLogRepoSuite) TestGetUserStats() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "userstats@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userstats"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour)) + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(2 * time.Hour) + stats, err := s.repo.GetUserStats(s.ctx, user.ID, startTime, endTime) + s.Require().NoError(err, "GetUserStats") + s.Require().Equal(int64(2), stats.TotalRequests) + s.Require().Equal(int64(25), stats.InputTokens) + s.Require().Equal(int64(45), stats.OutputTokens) +} + +// --- ListWithFilters --- + +func (s *UsageLogRepoSuite) TestListWithFilters() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "filters@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filters"}) + + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) + + filters := usagestats.UsageLogFilters{UserID: user.ID} + logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters) + s.Require().NoError(err, "ListWithFilters") + s.Require().Len(logs, 1) + s.Require().Equal(int64(1), page.Total) +} + +// --- GetDashboardStats --- + +func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { + now := time.Now() + todayStart := timezone.Today() + + userToday := mustCreateUser(s.T(), s.db, &model.User{ + Email: "today@example.com", + CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)), + UpdatedAt: now, + }) + userOld := mustCreateUser(s.T(), s.db, &model.User{ + Email: "old@example.com", + CreatedAt: todayStart.Add(-24 * time.Hour), + UpdatedAt: todayStart.Add(-24 * time.Hour), + }) + + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-ul"}) + apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"}) + mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: model.StatusDisabled}) + + resetAt := now.Add(10 * time.Minute) + accNormal := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-normal", Schedulable: true}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-error", Status: model.StatusError, Schedulable: true}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true}) + mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true}) + + d1, d2, d3 := 100, 200, 300 + logToday := &model.UsageLog{ + UserID: userToday.ID, + ApiKeyID: apiKey1.ID, + AccountID: accNormal.ID, + Model: "claude-3", + GroupID: &group.ID, + InputTokens: 10, + OutputTokens: 20, + CacheCreationTokens: 3, + CacheReadTokens: 4, + TotalCost: 1.5, + ActualCost: 1.2, + DurationMs: &d1, + CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)), + } + s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday") + + logOld := &model.UsageLog{ + UserID: userOld.ID, + ApiKeyID: apiKey1.ID, + AccountID: accNormal.ID, + Model: "claude-3", + InputTokens: 5, + OutputTokens: 6, + TotalCost: 0.7, + ActualCost: 0.7, + DurationMs: &d2, + CreatedAt: todayStart.Add(-1 * time.Hour), + } + s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld") + + logPerf := &model.UsageLog{ + UserID: userToday.ID, + ApiKeyID: apiKey1.ID, + AccountID: accNormal.ID, + Model: "claude-3", + InputTokens: 1, + OutputTokens: 2, + TotalCost: 0.1, + ActualCost: 0.1, + DurationMs: &d3, + CreatedAt: now.Add(-30 * time.Second), + } + s.Require().NoError(s.repo.Create(s.ctx, logPerf), "Create logPerf") + + stats, err := s.repo.GetDashboardStats(s.ctx) + s.Require().NoError(err, "GetDashboardStats") + + s.Require().Equal(int64(2), stats.TotalUsers, "TotalUsers mismatch") + s.Require().Equal(int64(1), stats.TodayNewUsers, "TodayNewUsers mismatch") + s.Require().Equal(int64(1), stats.ActiveUsers, "ActiveUsers mismatch") + s.Require().Equal(int64(2), stats.TotalApiKeys, "TotalApiKeys mismatch") + s.Require().Equal(int64(1), stats.ActiveApiKeys, "ActiveApiKeys mismatch") + s.Require().Equal(int64(4), stats.TotalAccounts, "TotalAccounts mismatch") + s.Require().Equal(int64(1), stats.ErrorAccounts, "ErrorAccounts mismatch") + s.Require().Equal(int64(1), stats.RateLimitAccounts, "RateLimitAccounts mismatch") + s.Require().Equal(int64(1), stats.OverloadAccounts, "OverloadAccounts mismatch") + + s.Require().Equal(int64(3), stats.TotalRequests, "TotalRequests mismatch") + s.Require().Equal(int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch") + s.Require().Equal(int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch") + s.Require().Equal(int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch") + s.Require().Equal(int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch") + s.Require().Equal(int64(51), stats.TotalTokens, "TotalTokens mismatch") + s.Require().Equal(2.3, stats.TotalCost, "TotalCost mismatch") + s.Require().Equal(2.0, stats.TotalActualCost, "TotalActualCost mismatch") + s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1") + s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0") + + wantRpm, wantTpm := s.repo.getPerformanceStats(s.ctx, 0) + s.Require().Equal(wantRpm, stats.Rpm, "Rpm mismatch") + s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch") +} + +// --- GetUserDashboardStats --- + +func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "userdash@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userdash"}) + + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) + + stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID) + s.Require().NoError(err, "GetUserDashboardStats") + s.Require().Equal(int64(1), stats.TotalApiKeys) + s.Require().Equal(int64(1), stats.TotalRequests) +} + +// --- GetAccountTodayStats --- + +func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctoday@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-today"}) + + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) + + stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID) + s.Require().NoError(err, "GetAccountTodayStats") + s.Require().Equal(int64(1), stats.Requests) + s.Require().Equal(int64(30), stats.Tokens) +} + +// --- GetBatchUserUsageStats --- + +func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { + user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch1@test.com"}) + user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch2@test.com"}) + apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"}) + apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batch"}) + + s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) + s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now()) + + stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}) + s.Require().NoError(err, "GetBatchUserUsageStats") + s.Require().Len(stats, 2) + s.Require().NotNil(stats[user1.ID]) + s.Require().NotNil(stats[user2.ID]) +} + +func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() { + stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}) + s.Require().NoError(err) + s.Require().Empty(stats) +} + +// --- GetBatchApiKeyUsageStats --- + +func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "batchkey@test.com"}) + apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"}) + apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batchkey"}) + + s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) + s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) + + stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}) + s.Require().NoError(err, "GetBatchApiKeyUsageStats") + s.Require().Len(stats, 2) +} + +func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() { + stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{}) + s.Require().NoError(err) + s.Require().Empty(stats) +} + +// --- GetGlobalStats --- + +func (s *UsageLogRepoSuite) TestGetGlobalStats() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "global@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-global"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour)) + + stats, err := s.repo.GetGlobalStats(s.ctx, base.Add(-1*time.Hour), base.Add(2*time.Hour)) + s.Require().NoError(err, "GetGlobalStats") + s.Require().Equal(int64(2), stats.TotalRequests) + s.Require().Equal(int64(25), stats.TotalInputTokens) + s.Require().Equal(int64(45), stats.TotalOutputTokens) +} + +func maxTime(a, b time.Time) time.Time { + if a.After(b) { + return a + } + return b +} + +// --- ListByUserAndTimeRange --- + +func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "timerange@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-timerange"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour)) + s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(2 * time.Hour) + logs, _, err := s.repo.ListByUserAndTimeRange(s.ctx, user.ID, startTime, endTime) + s.Require().NoError(err, "ListByUserAndTimeRange") + s.Require().Len(logs, 2) +} + +// --- ListByApiKeyAndTimeRange --- + +func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytimerange@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytimerange"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(30*time.Minute)) + s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(2 * time.Hour) + logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime) + s.Require().NoError(err, "ListByApiKeyAndTimeRange") + s.Require().Len(logs, 2) +} + +// --- ListByAccountAndTimeRange --- + +func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctimerange@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-acctimerange"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(45*time.Minute)) + s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(2 * time.Hour) + logs, _, err := s.repo.ListByAccountAndTimeRange(s.ctx, account.ID, startTime, endTime) + s.Require().NoError(err, "ListByAccountAndTimeRange") + s.Require().Len(logs, 2) +} + +// --- ListByModelAndTimeRange --- + +func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "modeltimerange@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modeltimerange"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + + // Create logs with different models + log1 := &model.UsageLog{ + UserID: user.ID, + ApiKeyID: apiKey.ID, + AccountID: account.ID, + Model: "claude-3-opus", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: base, + } + s.Require().NoError(s.repo.Create(s.ctx, log1)) + + log2 := &model.UsageLog{ + UserID: user.ID, + ApiKeyID: apiKey.ID, + AccountID: account.ID, + Model: "claude-3-opus", + InputTokens: 15, + OutputTokens: 25, + TotalCost: 0.6, + ActualCost: 0.6, + CreatedAt: base.Add(30 * time.Minute), + } + s.Require().NoError(s.repo.Create(s.ctx, log2)) + + log3 := &model.UsageLog{ + UserID: user.ID, + ApiKeyID: apiKey.ID, + AccountID: account.ID, + Model: "claude-3-sonnet", + InputTokens: 20, + OutputTokens: 30, + TotalCost: 0.7, + ActualCost: 0.7, + CreatedAt: base.Add(1 * time.Hour), + } + s.Require().NoError(s.repo.Create(s.ctx, log3)) + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(2 * time.Hour) + logs, _, err := s.repo.ListByModelAndTimeRange(s.ctx, "claude-3-opus", startTime, endTime) + s.Require().NoError(err, "ListByModelAndTimeRange") + s.Require().Len(logs, 2) +} + +// --- GetAccountWindowStats --- + +func (s *UsageLogRepoSuite) TestGetAccountWindowStats() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "windowstats@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-windowstats"}) + + now := time.Now() + windowStart := now.Add(-10 * time.Minute) + + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, now.Add(-5*time.Minute)) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, now.Add(-3*time.Minute)) + s.createUsageLog(user, apiKey, account, 20, 30, 0.7, now.Add(-30*time.Minute)) // outside window + + stats, err := s.repo.GetAccountWindowStats(s.ctx, account.ID, windowStart) + s.Require().NoError(err, "GetAccountWindowStats") + s.Require().Equal(int64(2), stats.Requests) + s.Require().Equal(int64(70), stats.Tokens) // (10+20) + (15+25) +} + +// --- GetUserUsageTrendByUserID --- + +func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrend"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour)) + s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(24*time.Hour)) // next day + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(48 * time.Hour) + trend, err := s.repo.GetUserUsageTrendByUserID(s.ctx, user.ID, startTime, endTime, "day") + s.Require().NoError(err, "GetUserUsageTrendByUserID") + s.Require().Len(trend, 2) // 2 different days +} + +func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrendhourly@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrendhourly"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour)) + s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(2*time.Hour)) + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(3 * time.Hour) + trend, err := s.repo.GetUserUsageTrendByUserID(s.ctx, user.ID, startTime, endTime, "hour") + s.Require().NoError(err, "GetUserUsageTrendByUserID hourly") + s.Require().Len(trend, 3) // 3 different hours +} + +// --- GetUserModelStats --- + +func (s *UsageLogRepoSuite) TestGetUserModelStats() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelstats@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelstats"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + + // Create logs with different models + log1 := &model.UsageLog{ + UserID: user.ID, + ApiKeyID: apiKey.ID, + AccountID: account.ID, + Model: "claude-3-opus", + InputTokens: 100, + OutputTokens: 200, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: base, + } + s.Require().NoError(s.repo.Create(s.ctx, log1)) + + log2 := &model.UsageLog{ + UserID: user.ID, + ApiKeyID: apiKey.ID, + AccountID: account.ID, + Model: "claude-3-sonnet", + InputTokens: 50, + OutputTokens: 100, + TotalCost: 0.2, + ActualCost: 0.2, + CreatedAt: base.Add(1 * time.Hour), + } + s.Require().NoError(s.repo.Create(s.ctx, log2)) + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(2 * time.Hour) + stats, err := s.repo.GetUserModelStats(s.ctx, user.ID, startTime, endTime) + s.Require().NoError(err, "GetUserModelStats") + s.Require().Len(stats, 2) + + // Should be ordered by total_tokens DESC + s.Require().Equal("claude-3-opus", stats[0].Model) + s.Require().Equal(int64(300), stats[0].TotalTokens) +} + +// --- GetUsageTrendWithFilters --- + +func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(24*time.Hour)) + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(48 * time.Hour) + + // Test with user filter + trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0) + s.Require().NoError(err, "GetUsageTrendWithFilters user filter") + s.Require().Len(trend, 2) + + // Test with apiKey filter + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID) + s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter") + s.Require().Len(trend, 2) + + // Test with both filters + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID) + s.Require().NoError(err, "GetUsageTrendWithFilters both filters") + s.Require().Len(trend, 2) +} + +func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters-h@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters-h"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour)) + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(3 * time.Hour) + + trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0) + s.Require().NoError(err, "GetUsageTrendWithFilters hourly") + s.Require().Len(trend, 2) +} + +// --- GetModelStatsWithFilters --- + +func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelfilters@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelfilters"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + + log1 := &model.UsageLog{ + UserID: user.ID, + ApiKeyID: apiKey.ID, + AccountID: account.ID, + Model: "claude-3-opus", + InputTokens: 100, + OutputTokens: 200, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: base, + } + s.Require().NoError(s.repo.Create(s.ctx, log1)) + + log2 := &model.UsageLog{ + UserID: user.ID, + ApiKeyID: apiKey.ID, + AccountID: account.ID, + Model: "claude-3-sonnet", + InputTokens: 50, + OutputTokens: 100, + TotalCost: 0.2, + ActualCost: 0.2, + CreatedAt: base.Add(1 * time.Hour), + } + s.Require().NoError(s.repo.Create(s.ctx, log2)) + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(2 * time.Hour) + + // Test with user filter + stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0) + s.Require().NoError(err, "GetModelStatsWithFilters user filter") + s.Require().Len(stats, 2) + + // Test with apiKey filter + stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0) + s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter") + s.Require().Len(stats, 2) + + // Test with account filter + stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID) + s.Require().NoError(err, "GetModelStatsWithFilters account filter") + s.Require().Len(stats, 2) +} + +// --- GetAccountUsageStats --- + +func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "accstats@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-accstats"}) + + base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC) + + // Create logs on different days + log1 := &model.UsageLog{ + UserID: user.ID, + ApiKeyID: apiKey.ID, + AccountID: account.ID, + Model: "claude-3-opus", + InputTokens: 100, + OutputTokens: 200, + TotalCost: 0.5, + ActualCost: 0.4, + CreatedAt: base.Add(12 * time.Hour), + } + s.Require().NoError(s.repo.Create(s.ctx, log1)) + + log2 := &model.UsageLog{ + UserID: user.ID, + ApiKeyID: apiKey.ID, + AccountID: account.ID, + Model: "claude-3-sonnet", + InputTokens: 50, + OutputTokens: 100, + TotalCost: 0.2, + ActualCost: 0.15, + CreatedAt: base.Add(36 * time.Hour), // next day + } + s.Require().NoError(s.repo.Create(s.ctx, log2)) + + startTime := base + endTime := base.Add(72 * time.Hour) + + resp, err := s.repo.GetAccountUsageStats(s.ctx, account.ID, startTime, endTime) + s.Require().NoError(err, "GetAccountUsageStats") + + s.Require().Len(resp.History, 2, "expected 2 days of history") + s.Require().Equal(int64(2), resp.Summary.TotalRequests) + s.Require().Equal(int64(450), resp.Summary.TotalTokens) + s.Require().Len(resp.Models, 2) +} + +func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() { + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-emptystats"}) + + base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC) + startTime := base + endTime := base.Add(72 * time.Hour) + + resp, err := s.repo.GetAccountUsageStats(s.ctx, account.ID, startTime, endTime) + s.Require().NoError(err, "GetAccountUsageStats empty") + + s.Require().Len(resp.History, 0) + s.Require().Equal(int64(0), resp.Summary.TotalRequests) +} + +// --- GetUserUsageTrend --- + +func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { + user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend1@test.com"}) + user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend2@test.com"}) + apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"}) + apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrends"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base) + s.createUsageLog(user2, apiKey2, account, 50, 100, 0.5, base) + s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour)) + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(48 * time.Hour) + + trend, err := s.repo.GetUserUsageTrend(s.ctx, startTime, endTime, "day", 10) + s.Require().NoError(err, "GetUserUsageTrend") + s.Require().GreaterOrEqual(len(trend), 2) +} + +// --- GetApiKeyUsageTrend --- + +func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrend@test.com"}) + apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"}) + apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrends"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base) + s.createUsageLog(user, apiKey2, account, 50, 100, 0.5, base) + s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour)) + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(48 * time.Hour) + + trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10) + s.Require().NoError(err, "GetApiKeyUsageTrend") + s.Require().GreaterOrEqual(len(trend), 2) +} + +func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrendh@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrendh"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base) + s.createUsageLog(user, apiKey, account, 50, 100, 0.5, base.Add(1*time.Hour)) + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(3 * time.Hour) + + trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10) + s.Require().NoError(err, "GetApiKeyUsageTrend hourly") + s.Require().Len(trend, 2) +} + +// --- ListWithFilters (additional filter tests) --- + +func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterskey@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterskey"}) + + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) + + filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID} + logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters) + s.Require().NoError(err, "ListWithFilters apiKey") + s.Require().Len(logs, 1) + s.Require().Equal(int64(1), page.Total) +} + +func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterstime@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterstime"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour)) + s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(2 * time.Hour) + filters := usagestats.UsageLogFilters{StartTime: &startTime, EndTime: &endTime} + logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters) + s.Require().NoError(err, "ListWithFilters time range") + s.Require().Len(logs, 2) + s.Require().Equal(int64(2), page.Total) +} + +func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterscombined@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterscombined"}) + + base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) + s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) + s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour)) + + startTime := base.Add(-1 * time.Hour) + endTime := base.Add(2 * time.Hour) + filters := usagestats.UsageLogFilters{ + UserID: user.ID, + ApiKeyID: apiKey.ID, + StartTime: &startTime, + EndTime: &endTime, + } + logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters) + s.Require().NoError(err, "ListWithFilters combined") + s.Require().Len(logs, 2) + s.Require().Equal(int64(2), page.Total) +} diff --git a/backend/internal/repository/user_repo_integration_test.go b/backend/internal/repository/user_repo_integration_test.go new file mode 100644 index 00000000..7efe2d5c --- /dev/null +++ b/backend/internal/repository/user_repo_integration_test.go @@ -0,0 +1,448 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/lib/pq" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +type UserRepoSuite struct { + suite.Suite + ctx context.Context + db *gorm.DB + repo *UserRepository +} + +func (s *UserRepoSuite) SetupTest() { + s.ctx = context.Background() + s.db = testTx(s.T()) + s.repo = NewUserRepository(s.db) +} + +func TestUserRepoSuite(t *testing.T) { + suite.Run(t, new(UserRepoSuite)) +} + +// --- Create / GetByID / GetByEmail / Update / Delete --- + +func (s *UserRepoSuite) TestCreate() { + user := &model.User{ + Email: "create@test.com", + Username: "testuser", + Role: model.RoleUser, + Status: model.StatusActive, + } + + err := s.repo.Create(s.ctx, user) + s.Require().NoError(err, "Create") + s.Require().NotZero(user.ID, "expected ID to be set") + + got, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal("create@test.com", got.Email) +} + +func (s *UserRepoSuite) TestGetByID_NotFound() { + _, err := s.repo.GetByID(s.ctx, 999999) + s.Require().Error(err, "expected error for non-existent ID") +} + +func (s *UserRepoSuite) TestGetByEmail() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "byemail@test.com"}) + + got, err := s.repo.GetByEmail(s.ctx, user.Email) + s.Require().NoError(err, "GetByEmail") + s.Require().Equal(user.ID, got.ID) +} + +func (s *UserRepoSuite) TestGetByEmail_NotFound() { + _, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com") + s.Require().Error(err, "expected error for non-existent email") +} + +func (s *UserRepoSuite) TestUpdate() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com", Username: "original"}) + + user.Username = "updated" + err := s.repo.Update(s.ctx, user) + s.Require().NoError(err, "Update") + + got, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err, "GetByID after update") + s.Require().Equal("updated", got.Username) +} + +func (s *UserRepoSuite) TestDelete() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) + + err := s.repo.Delete(s.ctx, user.ID) + s.Require().NoError(err, "Delete") + + _, err = s.repo.GetByID(s.ctx, user.ID) + s.Require().Error(err, "expected error after delete") +} + +// --- List / ListWithFilters --- + +func (s *UserRepoSuite) TestList() { + mustCreateUser(s.T(), s.db, &model.User{Email: "list1@test.com"}) + mustCreateUser(s.T(), s.db, &model.User{Email: "list2@test.com"}) + + users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "List") + s.Require().Len(users, 2) + s.Require().Equal(int64(2), page.Total) +} + +func (s *UserRepoSuite) TestListWithFilters_Status() { + mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com", Status: model.StatusActive}) + mustCreateUser(s.T(), s.db, &model.User{Email: "disabled@test.com", Status: model.StatusDisabled}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, "", "") + s.Require().NoError(err) + s.Require().Len(users, 1) + s.Require().Equal(model.StatusActive, users[0].Status) +} + +func (s *UserRepoSuite) TestListWithFilters_Role() { + mustCreateUser(s.T(), s.db, &model.User{Email: "user@test.com", Role: model.RoleUser}) + mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.RoleAdmin, "") + s.Require().NoError(err) + s.Require().Len(users, 1) + s.Require().Equal(model.RoleAdmin, users[0].Role) +} + +func (s *UserRepoSuite) TestListWithFilters_Search() { + mustCreateUser(s.T(), s.db, &model.User{Email: "alice@test.com", Username: "Alice"}) + mustCreateUser(s.T(), s.db, &model.User{Email: "bob@test.com", Username: "Bob"}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice") + s.Require().NoError(err) + s.Require().Len(users, 1) + s.Require().Contains(users[0].Email, "alice") +} + +func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() { + mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com", Username: "JohnDoe"}) + mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com", Username: "JaneSmith"}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john") + s.Require().NoError(err) + s.Require().Len(users, 1) + s.Require().Equal("JohnDoe", users[0].Username) +} + +func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() { + mustCreateUser(s.T(), s.db, &model.User{Email: "w1@test.com", Wechat: "wx_hello"}) + mustCreateUser(s.T(), s.db, &model.User{Email: "w2@test.com", Wechat: "wx_world"}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello") + s.Require().NoError(err) + s.Require().Len(users, 1) + s.Require().Equal("wx_hello", users[0].Wechat) +} + +func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub@test.com", Status: model.StatusActive}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sub"}) + + _ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(1 * time.Hour), + }) + _ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusExpired, + ExpiresAt: time.Now().Add(-1 * time.Hour), + }) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "sub@") + s.Require().NoError(err, "ListWithFilters") + s.Require().Len(users, 1, "expected 1 user") + s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription") + s.Require().NotNil(users[0].Subscriptions[0].Group, "expected subscription group preload") + s.Require().Equal(group.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch") +} + +func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() { + mustCreateUser(s.T(), s.db, &model.User{ + Email: "a@example.com", + Username: "Alice", + Wechat: "wx_a", + Role: model.RoleUser, + Status: model.StatusActive, + Balance: 10, + }) + target := mustCreateUser(s.T(), s.db, &model.User{ + Email: "b@example.com", + Username: "Bob", + Wechat: "wx_b", + Role: model.RoleAdmin, + Status: model.StatusActive, + Balance: 1, + }) + mustCreateUser(s.T(), s.db, &model.User{ + Email: "c@example.com", + Role: model.RoleAdmin, + Status: model.StatusDisabled, + }) + + users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, model.RoleAdmin, "b@") + s.Require().NoError(err, "ListWithFilters") + s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") + s.Require().Len(users, 1, "ListWithFilters len mismatch") + s.Require().Equal(target.ID, users[0].ID, "ListWithFilters result mismatch") +} + +// --- Balance operations --- + +func (s *UserRepoSuite) TestUpdateBalance() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "bal@test.com", Balance: 10}) + + err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5) + s.Require().NoError(err, "UpdateBalance") + + got, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().Equal(12.5, got.Balance) +} + +func (s *UserRepoSuite) TestUpdateBalance_Negative() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "balneg@test.com", Balance: 10}) + + err := s.repo.UpdateBalance(s.ctx, user.ID, -3) + s.Require().NoError(err, "UpdateBalance with negative") + + got, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().Equal(7.0, got.Balance) +} + +func (s *UserRepoSuite) TestDeductBalance() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "deduct@test.com", Balance: 10}) + + err := s.repo.DeductBalance(s.ctx, user.ID, 5) + s.Require().NoError(err, "DeductBalance") + + got, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().Equal(5.0, got.Balance) +} + +func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "insuf@test.com", Balance: 5}) + + err := s.repo.DeductBalance(s.ctx, user.ID, 999) + s.Require().Error(err, "expected error for insufficient balance") + s.Require().ErrorIs(err, gorm.ErrRecordNotFound) +} + +func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "exact@test.com", Balance: 10}) + + err := s.repo.DeductBalance(s.ctx, user.ID, 10) + s.Require().NoError(err, "DeductBalance exact amount") + + got, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().Zero(got.Balance) +} + +// --- Concurrency --- + +func (s *UserRepoSuite) TestUpdateConcurrency() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "conc@test.com", Concurrency: 5}) + + err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3) + s.Require().NoError(err, "UpdateConcurrency") + + got, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().Equal(8, got.Concurrency) +} + +func (s *UserRepoSuite) TestUpdateConcurrency_Negative() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "concneg@test.com", Concurrency: 5}) + + err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2) + s.Require().NoError(err, "UpdateConcurrency negative") + + got, err := s.repo.GetByID(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().Equal(3, got.Concurrency) +} + +// --- ExistsByEmail --- + +func (s *UserRepoSuite) TestExistsByEmail() { + mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) + + exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com") + s.Require().NoError(err, "ExistsByEmail") + s.Require().True(exists) + + notExists, err := s.repo.ExistsByEmail(s.ctx, "notexists@test.com") + s.Require().NoError(err) + s.Require().False(notExists) +} + +// --- RemoveGroupFromAllowedGroups --- + +func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() { + groupID := int64(42) + userA := mustCreateUser(s.T(), s.db, &model.User{ + Email: "a1@example.com", + AllowedGroups: pq.Int64Array{groupID, 7}, + }) + mustCreateUser(s.T(), s.db, &model.User{ + Email: "a2@example.com", + AllowedGroups: pq.Int64Array{7}, + }) + + affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, groupID) + s.Require().NoError(err, "RemoveGroupFromAllowedGroups") + s.Require().Equal(int64(1), affected, "expected 1 affected row") + + got, err := s.repo.GetByID(s.ctx, userA.ID) + s.Require().NoError(err, "GetByID") + for _, id := range got.AllowedGroups { + s.Require().NotEqual(groupID, id, "expected groupID to be removed from allowed_groups") + } +} + +func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() { + mustCreateUser(s.T(), s.db, &model.User{ + Email: "nomatch@test.com", + AllowedGroups: pq.Int64Array{1, 2, 3}, + }) + + affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999) + s.Require().NoError(err) + s.Require().Zero(affected, "expected no affected rows") +} + +// --- GetFirstAdmin --- + +func (s *UserRepoSuite) TestGetFirstAdmin() { + admin1 := mustCreateUser(s.T(), s.db, &model.User{ + Email: "admin1@example.com", + Role: model.RoleAdmin, + Status: model.StatusActive, + }) + mustCreateUser(s.T(), s.db, &model.User{ + Email: "admin2@example.com", + Role: model.RoleAdmin, + Status: model.StatusActive, + }) + + got, err := s.repo.GetFirstAdmin(s.ctx) + s.Require().NoError(err, "GetFirstAdmin") + s.Require().Equal(admin1.ID, got.ID, "GetFirstAdmin mismatch") +} + +func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() { + mustCreateUser(s.T(), s.db, &model.User{ + Email: "user@example.com", + Role: model.RoleUser, + Status: model.StatusActive, + }) + + _, err := s.repo.GetFirstAdmin(s.ctx) + s.Require().Error(err, "expected error when no admin exists") +} + +func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() { + mustCreateUser(s.T(), s.db, &model.User{ + Email: "disabled@example.com", + Role: model.RoleAdmin, + Status: model.StatusDisabled, + }) + activeAdmin := mustCreateUser(s.T(), s.db, &model.User{ + Email: "active@example.com", + Role: model.RoleAdmin, + Status: model.StatusActive, + }) + + got, err := s.repo.GetFirstAdmin(s.ctx) + s.Require().NoError(err, "GetFirstAdmin") + s.Require().Equal(activeAdmin.ID, got.ID, "should return only active admin") +} + +// --- Combined original test --- + +func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { + user1 := mustCreateUser(s.T(), s.db, &model.User{ + Email: "a@example.com", + Username: "Alice", + Wechat: "wx_a", + Role: model.RoleUser, + Status: model.StatusActive, + Balance: 10, + }) + user2 := mustCreateUser(s.T(), s.db, &model.User{ + Email: "b@example.com", + Username: "Bob", + Wechat: "wx_b", + Role: model.RoleAdmin, + Status: model.StatusActive, + Balance: 1, + }) + _ = mustCreateUser(s.T(), s.db, &model.User{ + Email: "c@example.com", + Role: model.RoleAdmin, + Status: model.StatusDisabled, + }) + + got, err := s.repo.GetByID(s.ctx, user1.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal(user1.Email, got.Email, "GetByID email mismatch") + + gotByEmail, err := s.repo.GetByEmail(s.ctx, user2.Email) + s.Require().NoError(err, "GetByEmail") + s.Require().Equal(user2.ID, gotByEmail.ID, "GetByEmail ID mismatch") + + got.Username = "Alice2" + s.Require().NoError(s.repo.Update(s.ctx, got), "Update") + got2, err := s.repo.GetByID(s.ctx, user1.ID) + s.Require().NoError(err, "GetByID after update") + s.Require().Equal("Alice2", got2.Username, "Update did not persist") + + s.Require().NoError(s.repo.UpdateBalance(s.ctx, user1.ID, 2.5), "UpdateBalance") + got3, err := s.repo.GetByID(s.ctx, user1.ID) + s.Require().NoError(err, "GetByID after UpdateBalance") + s.Require().Equal(12.5, got3.Balance, "UpdateBalance mismatch") + + s.Require().NoError(s.repo.DeductBalance(s.ctx, user1.ID, 5), "DeductBalance") + got4, err := s.repo.GetByID(s.ctx, user1.ID) + s.Require().NoError(err, "GetByID after DeductBalance") + s.Require().Equal(7.5, got4.Balance, "DeductBalance mismatch") + + err = s.repo.DeductBalance(s.ctx, user1.ID, 999) + s.Require().Error(err, "DeductBalance expected error for insufficient balance") + s.Require().ErrorIs(err, gorm.ErrRecordNotFound, "DeductBalance unexpected error") + + s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency") + got5, err := s.repo.GetByID(s.ctx, user1.ID) + s.Require().NoError(err, "GetByID after UpdateConcurrency") + s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch") + + params := pagination.PaginationParams{Page: 1, PageSize: 10} + users, page, err := s.repo.ListWithFilters(s.ctx, params, model.StatusActive, model.RoleAdmin, "b@") + s.Require().NoError(err, "ListWithFilters") + s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") + s.Require().Len(users, 1, "ListWithFilters len mismatch") + s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch") +} diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go new file mode 100644 index 00000000..9cecf4e8 --- /dev/null +++ b/backend/internal/repository/user_subscription_repo_integration_test.go @@ -0,0 +1,733 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/suite" + "gorm.io/gorm" +) + +type UserSubscriptionRepoSuite struct { + suite.Suite + ctx context.Context + db *gorm.DB + repo *UserSubscriptionRepository +} + +func (s *UserSubscriptionRepoSuite) SetupTest() { + s.ctx = context.Background() + s.db = testTx(s.T()) + s.repo = NewUserSubscriptionRepository(s.db) +} + +func TestUserSubscriptionRepoSuite(t *testing.T) { + suite.Run(t, new(UserSubscriptionRepoSuite)) +} + +// --- Create / GetByID / Update / Delete --- + +func (s *UserSubscriptionRepoSuite) TestCreate() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub-create@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-create"}) + + sub := &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + + err := s.repo.Create(s.ctx, sub) + s.Require().NoError(err, "Create") + s.Require().NotZero(sub.ID, "expected ID to be set") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal(sub.UserID, got.UserID) + s.Require().Equal(sub.GroupID, got.GroupID) +} + +func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "preload@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"}) + admin := mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin}) + + sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + AssignedBy: &admin.ID, + }) + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err, "GetByID") + s.Require().NotNil(got.User, "expected User preload") + s.Require().NotNil(got.Group, "expected Group preload") + s.Require().NotNil(got.AssignedByUser, "expected AssignedByUser preload") + s.Require().Equal(user.ID, got.User.ID) + s.Require().Equal(group.ID, got.Group.ID) + s.Require().Equal(admin.ID, got.AssignedByUser.ID) +} + +func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() { + _, err := s.repo.GetByID(s.ctx, 999999) + s.Require().Error(err, "expected error for non-existent ID") +} + +func (s *UserSubscriptionRepoSuite) TestUpdate() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-update"}) + sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + sub.Notes = "updated notes" + err := s.repo.Update(s.ctx, sub) + s.Require().NoError(err, "Update") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err, "GetByID after update") + s.Require().Equal("updated notes", got.Notes) +} + +func (s *UserSubscriptionRepoSuite) TestDelete() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delete"}) + sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + err := s.repo.Delete(s.ctx, sub.ID) + s.Require().NoError(err, "Delete") + + _, err = s.repo.GetByID(s.ctx, sub.ID) + s.Require().Error(err, "expected error after delete") +} + +// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID --- + +func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "byuser@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-byuser"}) + sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + got, err := s.repo.GetByUserIDAndGroupID(s.ctx, user.ID, group.ID) + s.Require().NoError(err, "GetByUserIDAndGroupID") + s.Require().Equal(sub.ID, got.ID) + s.Require().NotNil(got.Group, "expected Group preload") +} + +func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() { + _, err := s.repo.GetByUserIDAndGroupID(s.ctx, 999999, 999999) + s.Require().Error(err, "expected error for non-existent pair") +} + +func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-active"}) + + // Create active subscription (future expiry) + active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(2 * time.Hour), + }) + + got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID) + s.Require().NoError(err, "GetActiveByUserIDAndGroupID") + s.Require().Equal(active.ID, got.ID) +} + +func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "expired@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-expired"}) + + // Create expired subscription (past expiry but active status) + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(-2 * time.Hour), + }) + + _, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID) + s.Require().Error(err, "expected error for expired subscription") +} + +// --- ListByUserID / ListActiveByUserID --- + +func (s *UserSubscriptionRepoSuite) TestListByUserID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"}) + g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list1"}) + g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list2"}) + + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: g1.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: g2.ID, + Status: model.SubscriptionStatusExpired, + ExpiresAt: time.Now().Add(-24 * time.Hour), + }) + + subs, err := s.repo.ListByUserID(s.ctx, user.ID) + s.Require().NoError(err, "ListByUserID") + s.Require().Len(subs, 2) + for _, sub := range subs { + s.Require().NotNil(sub.Group, "expected Group preload") + } +} + +func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "listactive@test.com"}) + g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act1"}) + g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act2"}) + + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: g1.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: g2.ID, + Status: model.SubscriptionStatusExpired, + ExpiresAt: time.Now().Add(-24 * time.Hour), + }) + + subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID) + s.Require().NoError(err, "ListActiveByUserID") + s.Require().Len(subs, 1) + s.Require().Equal(model.SubscriptionStatusActive, subs[0].Status) +} + +// --- ListByGroupID --- + +func (s *UserSubscriptionRepoSuite) TestListByGroupID() { + user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com"}) + user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listgrp"}) + + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user1.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user2.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + subs, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err, "ListByGroupID") + s.Require().Len(subs, 2) + s.Require().Equal(int64(2), page.Total) + for _, sub := range subs { + s.Require().NotNil(sub.User, "expected User preload") + s.Require().NotNil(sub.Group, "expected Group preload") + } +} + +// --- List with filters --- + +func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "list@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) + + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "") + s.Require().NoError(err, "List") + s.Require().Len(subs, 1) + s.Require().Equal(int64(1), page.Total) +} + +func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { + user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter1@test.com"}) + user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter2@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-filter"}) + + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user1.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user2.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "") + s.Require().NoError(err) + s.Require().Len(subs, 1) + s.Require().Equal(user1.ID, subs[0].UserID) +} + +func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "grpfilter@test.com"}) + g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f1"}) + g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f2"}) + + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: g1.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: g2.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "") + s.Require().NoError(err) + s.Require().Len(subs, 1) + s.Require().Equal(g1.ID, subs[0].GroupID) +} + +func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "statfilter@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-stat"}) + + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusExpired, + ExpiresAt: time.Now().Add(-24 * time.Hour), + }) + + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, model.SubscriptionStatusExpired) + s.Require().NoError(err) + s.Require().Len(subs, 1) + s.Require().Equal(model.SubscriptionStatusExpired, subs[0].Status) +} + +// --- Usage tracking --- + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "usage@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-usage"}) + sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + err := s.repo.IncrementUsage(s.ctx, sub.ID, 1.25) + s.Require().NoError(err, "IncrementUsage") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().Equal(1.25, got.DailyUsageUSD) + s.Require().Equal(1.25, got.WeeklyUsageUSD) + s.Require().Equal(1.25, got.MonthlyUsageUSD) +} + +func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "accum@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-accum"}) + sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 1.0)) + s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 2.5)) + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().Equal(3.5, got.DailyUsageUSD) +} + +func (s *UserSubscriptionRepoSuite) TestActivateWindows() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "activate@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-activate"}) + sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + activateAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + err := s.repo.ActivateWindows(s.ctx, sub.ID, activateAt) + s.Require().NoError(err, "ActivateWindows") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().NotNil(got.DailyWindowStart) + s.Require().NotNil(got.WeeklyWindowStart) + s.Require().NotNil(got.MonthlyWindowStart) + s.Require().True(got.DailyWindowStart.Equal(activateAt)) +} + +func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetd@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetd"}) + sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + DailyUsageUSD: 10.0, + WeeklyUsageUSD: 20.0, + }) + + resetAt := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC) + err := s.repo.ResetDailyUsage(s.ctx, sub.ID, resetAt) + s.Require().NoError(err, "ResetDailyUsage") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().Zero(got.DailyUsageUSD) + s.Require().Equal(20.0, got.WeeklyUsageUSD, "weekly should remain unchanged") + s.Require().True(got.DailyWindowStart.Equal(resetAt)) +} + +func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetw@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetw"}) + sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + WeeklyUsageUSD: 15.0, + MonthlyUsageUSD: 30.0, + }) + + resetAt := time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC) + err := s.repo.ResetWeeklyUsage(s.ctx, sub.ID, resetAt) + s.Require().NoError(err, "ResetWeeklyUsage") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().Zero(got.WeeklyUsageUSD) + s.Require().Equal(30.0, got.MonthlyUsageUSD, "monthly should remain unchanged") + s.Require().True(got.WeeklyWindowStart.Equal(resetAt)) +} + +func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetm@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetm"}) + sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + MonthlyUsageUSD: 100.0, + }) + + resetAt := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC) + err := s.repo.ResetMonthlyUsage(s.ctx, sub.ID, resetAt) + s.Require().NoError(err, "ResetMonthlyUsage") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().Zero(got.MonthlyUsageUSD) + s.Require().True(got.MonthlyWindowStart.Equal(resetAt)) +} + +// --- UpdateStatus / ExtendExpiry / UpdateNotes --- + +func (s *UserSubscriptionRepoSuite) TestUpdateStatus() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "status@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-status"}) + sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + err := s.repo.UpdateStatus(s.ctx, sub.ID, model.SubscriptionStatusExpired) + s.Require().NoError(err, "UpdateStatus") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().Equal(model.SubscriptionStatusExpired, got.Status) +} + +func (s *UserSubscriptionRepoSuite) TestExtendExpiry() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "extend@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-extend"}) + sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + newExpiry := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + err := s.repo.ExtendExpiry(s.ctx, sub.ID, newExpiry) + s.Require().NoError(err, "ExtendExpiry") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().True(got.ExpiresAt.Equal(newExpiry)) +} + +func (s *UserSubscriptionRepoSuite) TestUpdateNotes() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "notes@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-notes"}) + sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + err := s.repo.UpdateNotes(s.ctx, sub.ID, "VIP user") + s.Require().NoError(err, "UpdateNotes") + + got, err := s.repo.GetByID(s.ctx, sub.ID) + s.Require().NoError(err) + s.Require().Equal("VIP user", got.Notes) +} + +// --- ListExpired / BatchUpdateExpiredStatus --- + +func (s *UserSubscriptionRepoSuite) TestListExpired() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "listexp@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listexp"}) + + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(-24 * time.Hour), + }) + + expired, err := s.repo.ListExpired(s.ctx) + s.Require().NoError(err, "ListExpired") + s.Require().Len(expired, 1) +} + +func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "batch@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-batch"}) + + active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(-24 * time.Hour), + }) + + affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx) + s.Require().NoError(err, "BatchUpdateExpiredStatus") + s.Require().Equal(int64(1), affected) + + gotActive, _ := s.repo.GetByID(s.ctx, active.ID) + s.Require().Equal(model.SubscriptionStatusActive, gotActive.Status) + + gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID) + s.Require().Equal(model.SubscriptionStatusExpired, gotExpired.Status) +} + +// --- ExistsByUserIDAndGroupID --- + +func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-exists"}) + + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + + exists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, group.ID) + s.Require().NoError(err, "ExistsByUserIDAndGroupID") + s.Require().True(exists) + + notExists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, 999999) + s.Require().NoError(err) + s.Require().False(notExists) +} + +// --- CountByGroupID / CountActiveByGroupID --- + +func (s *UserSubscriptionRepoSuite) TestCountByGroupID() { + user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt1@test.com"}) + user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt2@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) + + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user1.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user2.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusExpired, + ExpiresAt: time.Now().Add(-24 * time.Hour), + }) + + count, err := s.repo.CountByGroupID(s.ctx, group.ID) + s.Require().NoError(err, "CountByGroupID") + s.Require().Equal(int64(2), count) +} + +func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() { + user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact1@test.com"}) + user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact2@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-cntact"}) + + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user1.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user2.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time + }) + + count, err := s.repo.CountActiveByGroupID(s.ctx, group.ID) + s.Require().NoError(err, "CountActiveByGroupID") + s.Require().Equal(int64(1), count, "only future expiry counts as active") +} + +// --- DeleteByGroupID --- + +func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "delgrp@test.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delgrp"}) + + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + }) + mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusExpired, + ExpiresAt: time.Now().Add(-24 * time.Hour), + }) + + affected, err := s.repo.DeleteByGroupID(s.ctx, group.ID) + s.Require().NoError(err, "DeleteByGroupID") + s.Require().Equal(int64(2), affected) + + count, _ := s.repo.CountByGroupID(s.ctx, group.ID) + s.Require().Zero(count) +} + +// --- Combined original test --- + +func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() { + user := mustCreateUser(s.T(), s.db, &model.User{Email: "subr@example.com"}) + group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-subr"}) + + active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(2 * time.Hour), + }) + expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + UserID: user.ID, + GroupID: group.ID, + Status: model.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(-2 * time.Hour), + }) + + got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID) + s.Require().NoError(err, "GetActiveByUserIDAndGroupID") + s.Require().Equal(active.ID, got.ID, "expected active subscription") + + activateAt := time.Now().Add(-25 * time.Hour) + s.Require().NoError(s.repo.ActivateWindows(s.ctx, active.ID, activateAt), "ActivateWindows") + s.Require().NoError(s.repo.IncrementUsage(s.ctx, active.ID, 1.25), "IncrementUsage") + + after, err := s.repo.GetByID(s.ctx, active.ID) + s.Require().NoError(err, "GetByID") + s.Require().Equal(1.25, after.DailyUsageUSD, "DailyUsageUSD mismatch") + s.Require().Equal(1.25, after.WeeklyUsageUSD, "WeeklyUsageUSD mismatch") + s.Require().Equal(1.25, after.MonthlyUsageUSD, "MonthlyUsageUSD mismatch") + s.Require().NotNil(after.DailyWindowStart, "expected DailyWindowStart activated") + s.Require().NotNil(after.WeeklyWindowStart, "expected WeeklyWindowStart activated") + s.Require().NotNil(after.MonthlyWindowStart, "expected MonthlyWindowStart activated") + + resetAt := time.Now().Truncate(time.Microsecond) // truncate to microsecond for DB precision + s.Require().NoError(s.repo.ResetDailyUsage(s.ctx, active.ID, resetAt), "ResetDailyUsage") + afterReset, err := s.repo.GetByID(s.ctx, active.ID) + s.Require().NoError(err, "GetByID after reset") + s.Require().Equal(0.0, afterReset.DailyUsageUSD, "expected daily usage reset to 0") + s.Require().NotNil(afterReset.DailyWindowStart, "expected DailyWindowStart not nil") + s.Require().True(afterReset.DailyWindowStart.Equal(resetAt), "expected daily window start updated") + + affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx) + s.Require().NoError(err, "BatchUpdateExpiredStatus") + s.Require().Equal(int64(1), affected, "expected 1 affected row") + updated, err := s.repo.GetByID(s.ctx, expiredActive.ID) + s.Require().NoError(err, "GetByID expired") + s.Require().Equal(model.SubscriptionStatusExpired, updated.Status, "expected status expired") +}