merge: 合并 upstream/main 解决 PR #37 冲突

- 删除 backend/internal/model/account.go 符合重构方向
- 合并最新的项目结构重构
- 包含 SSE 格式解析修复
- 更新依赖和配置文件
This commit is contained in:
IanShaw027
2025-12-26 21:56:08 +08:00
118 changed files with 6077 additions and 3478 deletions

View File

@@ -23,6 +23,8 @@ linters:
desc: "service must not import repository" desc: "service must not import repository"
- pkg: gorm.io/gorm - pkg: gorm.io/gorm
desc: "service must not import gorm" desc: "service must not import gorm"
- pkg: github.com/redis/go-redis/v9
desc: "service must not import redis"
handler-no-repository: handler-no-repository:
list-mode: original list-mode: original
files: files:
@@ -30,6 +32,10 @@ linters:
deny: deny:
- pkg: github.com/Wei-Shaw/sub2api/internal/repository - pkg: github.com/Wei-Shaw/sub2api/internal/repository
desc: "handler must not import repository" desc: "handler must not import repository"
- pkg: gorm.io/gorm
desc: "handler must not import gorm"
- pkg: github.com/redis/go-redis/v9
desc: "handler must not import redis"
errcheck: errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`. # Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default. # Such cases aren't reported by default.

View File

@@ -48,8 +48,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier) turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
emailQueueService := service.ProvideEmailQueueService(emailService) emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
authHandler := handler.NewAuthHandler(authService)
userService := service.NewUserService(userRepository) userService := service.NewUserService(userRepository)
authHandler := handler.NewAuthHandler(authService, userService)
userHandler := handler.NewUserHandler(userService) userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(db) apiKeyRepository := repository.NewApiKeyRepository(db)
groupRepository := repository.NewGroupRepository(db) groupRepository := repository.NewGroupRepository(db)

View File

@@ -22,12 +22,14 @@ require (
golang.org/x/net v0.47.0 golang.org/x/net v0.47.0
golang.org/x/term v0.37.0 golang.org/x/term v0.37.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/datatypes v1.2.0
gorm.io/driver/postgres v1.5.4 gorm.io/driver/postgres v1.5.4
gorm.io/gorm v1.25.5 gorm.io/gorm v1.25.5
) )
require ( require (
dario.cat/mergo v1.0.2 // indirect dario.cat/mergo v1.0.2 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect github.com/andybalholm/brotli v1.2.0 // indirect
@@ -57,6 +59,7 @@ require (
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect github.com/google/subcommands v1.2.0 // indirect
@@ -64,8 +67,8 @@ require (
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
github.com/jackc/pgx/v5 v5.5.4 // indirect github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
@@ -132,4 +135,5 @@ require (
google.golang.org/grpc v1.75.1 // indirect google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
gorm.io/driver/mysql v1.5.2 // indirect
) )

View File

@@ -1,5 +1,7 @@
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
@@ -77,10 +79,17 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@@ -104,10 +113,10 @@ github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo=
github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk= github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8= github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
@@ -135,8 +144,12 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE=
github.com/microsoft/go-mssqldb v0.17.0/go.mod h1:OkoNGhGEs8EZqchVTtochlXruEhEOaO4S0d2sB5aeGQ=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
@@ -319,8 +332,17 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco=
gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04=
gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs=
gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8=
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0=
gorm.io/driver/sqlserver v1.4.1/go.mod h1:DJ4P+MeZbc5rvY58PnmN1Lnyvb5gw5NPzGshHDnJLig=
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=

View File

@@ -4,7 +4,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -107,7 +107,7 @@ type BulkUpdateAccountsRequest struct {
// AccountWithConcurrency extends Account with real-time concurrency info // AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct { type AccountWithConcurrency struct {
*model.Account *dto.Account
CurrentConcurrency int `json:"current_concurrency"` CurrentConcurrency int `json:"current_concurrency"`
} }
@@ -142,7 +142,7 @@ func (h *AccountHandler) List(c *gin.Context) {
result := make([]AccountWithConcurrency, len(accounts)) result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts { for i := range accounts {
result[i] = AccountWithConcurrency{ result[i] = AccountWithConcurrency{
Account: &accounts[i], Account: dto.AccountFromService(&accounts[i]),
CurrentConcurrency: concurrencyCounts[accounts[i].ID], CurrentConcurrency: concurrencyCounts[accounts[i].ID],
} }
} }
@@ -165,7 +165,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
return return
} }
response.Success(c, account) response.Success(c, dto.AccountFromService(account))
} }
// Create handles creating a new account // Create handles creating a new account
@@ -193,7 +193,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
return return
} }
response.Success(c, account) response.Success(c, dto.AccountFromService(account))
} }
// Update handles updating an account // Update handles updating an account
@@ -227,7 +227,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
return return
} }
response.Success(c, account) response.Success(c, dto.AccountFromService(account))
} }
// Delete handles deleting an account // Delete handles deleting an account
@@ -398,7 +398,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
return return
} }
response.Success(c, updatedAccount) response.Success(c, dto.AccountFromService(updatedAccount))
} }
// GetStats handles getting account statistics // GetStats handles getting account statistics
@@ -447,7 +447,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
return return
} }
response.Success(c, account) response.Success(c, dto.AccountFromService(account))
} }
// BatchCreate handles batch creating accounts // BatchCreate handles batch creating accounts
@@ -823,7 +823,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
return return
} }
response.Success(c, account) response.Success(c, dto.AccountFromService(account))
} }
// GetAvailableModels handles getting available models for an account // GetAvailableModels handles getting available models for an account

View File

@@ -1,11 +1,12 @@
package admin package admin
import ( import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"strconv"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -3,7 +3,7 @@ package admin
import ( import (
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -69,7 +69,11 @@ func (h *GroupHandler) List(c *gin.Context) {
return return
} }
response.Paginated(c, groups, total, page, pageSize) outGroups := make([]dto.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
}
response.Paginated(c, outGroups, total, page, pageSize)
} }
// GetAll handles getting all active groups without pagination // GetAll handles getting all active groups without pagination
@@ -77,7 +81,7 @@ func (h *GroupHandler) List(c *gin.Context) {
func (h *GroupHandler) GetAll(c *gin.Context) { func (h *GroupHandler) GetAll(c *gin.Context) {
platform := c.Query("platform") platform := c.Query("platform")
var groups []model.Group var groups []service.Group
var err error var err error
if platform != "" { if platform != "" {
@@ -91,7 +95,11 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
return return
} }
response.Success(c, groups) outGroups := make([]dto.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
}
response.Success(c, outGroups)
} }
// GetByID handles getting a group by ID // GetByID handles getting a group by ID
@@ -109,7 +117,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
return return
} }
response.Success(c, group) response.Success(c, dto.GroupFromService(group))
} }
// Create handles creating a new group // Create handles creating a new group
@@ -137,7 +145,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
return return
} }
response.Success(c, group) response.Success(c, dto.GroupFromService(group))
} }
// Update handles updating a group // Update handles updating a group
@@ -172,7 +180,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
return return
} }
response.Success(c, group) response.Success(c, dto.GroupFromService(group))
} }
// Delete handles deleting a group // Delete handles deleting a group
@@ -229,5 +237,9 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
return return
} }
response.Paginated(c, keys, total, page, pageSize) outKeys := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, outKeys, total, page, pageSize)
} }

View File

@@ -3,6 +3,7 @@ package admin
import ( import (
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -163,7 +164,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
return return
} }
response.Success(c, updatedAccount) response.Success(c, dto.AccountFromService(updatedAccount))
} }
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info // CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
@@ -224,5 +225,5 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
return return
} }
response.Success(c, account) response.Success(c, dto.AccountFromService(account))
} }

View File

@@ -4,6 +4,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -57,7 +58,11 @@ func (h *ProxyHandler) List(c *gin.Context) {
return return
} }
response.Paginated(c, proxies, total, page, pageSize) out := make([]dto.Proxy, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyFromService(&proxies[i]))
}
response.Paginated(c, out, total, page, pageSize)
} }
// GetAll handles getting all active proxies without pagination // GetAll handles getting all active proxies without pagination
@@ -72,7 +77,11 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, proxies) out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
}
response.Success(c, out)
return return
} }
@@ -82,7 +91,11 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
return return
} }
response.Success(c, proxies) out := make([]dto.Proxy, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyFromService(&proxies[i]))
}
response.Success(c, out)
} }
// GetByID handles getting a proxy by ID // GetByID handles getting a proxy by ID
@@ -100,7 +113,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
return return
} }
response.Success(c, proxy) response.Success(c, dto.ProxyFromService(proxy))
} }
// Create handles creating a new proxy // Create handles creating a new proxy
@@ -125,7 +138,7 @@ func (h *ProxyHandler) Create(c *gin.Context) {
return return
} }
response.Success(c, proxy) response.Success(c, dto.ProxyFromService(proxy))
} }
// Update handles updating a proxy // Update handles updating a proxy
@@ -157,7 +170,7 @@ func (h *ProxyHandler) Update(c *gin.Context) {
return return
} }
response.Success(c, proxy) response.Success(c, dto.ProxyFromService(proxy))
} }
// Delete handles deleting a proxy // Delete handles deleting a proxy
@@ -233,7 +246,11 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
return return
} }
response.Paginated(c, accounts, total, page, pageSize) out := make([]dto.Account, 0, len(accounts))
for i := range accounts {
out = append(out, *dto.AccountFromService(&accounts[i]))
}
response.Paginated(c, out, total, page, pageSize)
} }
// BatchCreateProxyItem represents a single proxy in batch create request // BatchCreateProxyItem represents a single proxy in batch create request

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -47,7 +48,11 @@ func (h *RedeemHandler) List(c *gin.Context) {
return return
} }
response.Paginated(c, codes, total, page, pageSize) out := make([]dto.RedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
}
response.Paginated(c, out, total, page, pageSize)
} }
// GetByID handles getting a redeem code by ID // GetByID handles getting a redeem code by ID
@@ -65,7 +70,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
return return
} }
response.Success(c, code) response.Success(c, dto.RedeemCodeFromService(code))
} }
// Generate handles generating new redeem codes // Generate handles generating new redeem codes
@@ -89,7 +94,11 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return return
} }
response.Success(c, codes) out := make([]dto.RedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
}
response.Success(c, out)
} }
// Delete handles deleting a redeem code // Delete handles deleting a redeem code
@@ -148,7 +157,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
return return
} }
response.Success(c, code) response.Success(c, dto.RedeemCodeFromService(code))
} }
// GetStats handles getting redeem code statistics // GetStats handles getting redeem code statistics

View File

@@ -1,7 +1,7 @@
package admin package admin
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -31,7 +31,28 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
return return
} }
response.Success(c, settings) response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SmtpHost: settings.SmtpHost,
SmtpPort: settings.SmtpPort,
SmtpUsername: settings.SmtpUsername,
SmtpPassword: settings.SmtpPassword,
SmtpFrom: settings.SmtpFrom,
SmtpFromName: settings.SmtpFromName,
SmtpUseTLS: settings.SmtpUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
})
} }
// UpdateSettingsRequest 更新设置请求 // UpdateSettingsRequest 更新设置请求
@@ -87,7 +108,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SmtpPort = 587 req.SmtpPort = 587
} }
settings := &model.SystemSettings{ settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled, RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled,
SmtpHost: req.SmtpHost, SmtpHost: req.SmtpHost,
@@ -122,7 +143,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return return
} }
response.Success(c, updatedSettings) response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SmtpHost: updatedSettings.SmtpHost,
SmtpPort: updatedSettings.SmtpPort,
SmtpUsername: updatedSettings.SmtpUsername,
SmtpPassword: updatedSettings.SmtpPassword,
SmtpFrom: updatedSettings.SmtpFrom,
SmtpFromName: updatedSettings.SmtpFromName,
SmtpUseTLS: updatedSettings.SmtpUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
ApiBaseUrl: updatedSettings.ApiBaseUrl,
ContactInfo: updatedSettings.ContactInfo,
DocUrl: updatedSettings.DocUrl,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
})
} }
// TestSmtpRequest 测试SMTP连接请求 // TestSmtpRequest 测试SMTP连接请求

View File

@@ -3,9 +3,10 @@ package admin
import ( import (
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -82,7 +83,11 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
return return
} }
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination)) out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
} }
// GetByID handles getting a subscription by ID // GetByID handles getting a subscription by ID
@@ -100,7 +105,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
return return
} }
response.Success(c, subscription) response.Success(c, dto.UserSubscriptionFromService(subscription))
} }
// GetProgress handles getting subscription usage progress // GetProgress handles getting subscription usage progress
@@ -145,7 +150,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
return return
} }
response.Success(c, subscription) response.Success(c, dto.UserSubscriptionFromService(subscription))
} }
// BulkAssign handles bulk assigning subscriptions to multiple users // BulkAssign handles bulk assigning subscriptions to multiple users
@@ -172,7 +177,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
return return
} }
response.Success(c, result) response.Success(c, dto.BulkAssignResultFromService(result))
} }
// Extend handles extending a subscription // Extend handles extending a subscription
@@ -196,7 +201,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return return
} }
response.Success(c, subscription) response.Success(c, dto.UserSubscriptionFromService(subscription))
} }
// Revoke handles revoking a subscription // Revoke handles revoking a subscription
@@ -234,7 +239,11 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
return return
} }
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination)) out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
} }
// ListByUser handles listing subscriptions for a specific user // ListByUser handles listing subscriptions for a specific user
@@ -252,15 +261,18 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
return return
} }
response.Success(c, subscriptions) out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
} }
// Helper function to get admin ID from context // Helper function to get admin ID from context
func getAdminIDFromContext(c *gin.Context) int64 { func getAdminIDFromContext(c *gin.Context) int64 {
if user, exists := c.Get("user"); exists { subject, ok := middleware2.GetAuthSubjectFromContext(c)
if u, ok := user.(*model.User); ok && u != nil { if !ok {
return u.ID
}
}
return 0 return 0
} }
return subject.UserID
}

View File

@@ -4,6 +4,7 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
@@ -94,7 +95,11 @@ func (h *UsageHandler) List(c *gin.Context) {
return return
} }
response.Paginated(c, records, result.Total, page, pageSize) out := make([]dto.UsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromService(&records[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
} }
// Stats handles getting usage statistics with filters // Stats handles getting usage statistics with filters

View File

@@ -3,6 +3,7 @@ package admin
import ( import (
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -68,7 +69,11 @@ func (h *UserHandler) List(c *gin.Context) {
return return
} }
response.Paginated(c, users, total, page, pageSize) out := make([]dto.User, 0, len(users))
for i := range users {
out = append(out, *dto.UserFromService(&users[i]))
}
response.Paginated(c, out, total, page, pageSize)
} }
// GetByID handles getting a user by ID // GetByID handles getting a user by ID
@@ -86,7 +91,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
return return
} }
response.Success(c, user) response.Success(c, dto.UserFromService(user))
} }
// Create handles creating a new user // Create handles creating a new user
@@ -113,7 +118,7 @@ func (h *UserHandler) Create(c *gin.Context) {
return return
} }
response.Success(c, user) response.Success(c, dto.UserFromService(user))
} }
// Update handles updating a user // Update handles updating a user
@@ -148,7 +153,7 @@ func (h *UserHandler) Update(c *gin.Context) {
return return
} }
response.Success(c, user) response.Success(c, dto.UserFromService(user))
} }
// Delete handles deleting a user // Delete handles deleting a user
@@ -190,7 +195,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return return
} }
response.Success(c, user) response.Success(c, dto.UserFromService(user))
} }
// GetUserAPIKeys handles getting user's API keys // GetUserAPIKeys handles getting user's API keys
@@ -210,7 +215,11 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
return return
} }
response.Paginated(c, keys, total, page, pageSize) out := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, out, total, page, pageSize)
} }
// GetUserUsage handles getting user's usage statistics // GetUserUsage handles getting user's usage statistics

View File

@@ -3,9 +3,10 @@ package handler
import ( import (
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -40,42 +41,34 @@ type UpdateAPIKeyRequest struct {
// List handles listing user's API keys with pagination // List handles listing user's API keys with pagination
// GET /api/v1/api-keys // GET /api/v1/api-keys
func (h *APIKeyHandler) List(c *gin.Context) { func (h *APIKeyHandler) List(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params) keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Paginated(c, keys, result.Total, page, pageSize) out := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
} }
// GetByID handles getting a single API key // GetByID handles getting a single API key
// GET /api/v1/api-keys/:id // GET /api/v1/api-keys/:id
func (h *APIKeyHandler) GetByID(c *gin.Context) { func (h *APIKeyHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
@@ -92,26 +85,20 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
} }
// 验证所有权 // 验证所有权
if key.UserID != user.ID { if key.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this key") response.Forbidden(c, "Not authorized to access this key")
return return
} }
response.Success(c, key) response.Success(c, dto.ApiKeyFromService(key))
} }
// Create handles creating a new API key // Create handles creating a new API key
// POST /api/v1/api-keys // POST /api/v1/api-keys
func (h *APIKeyHandler) Create(c *gin.Context) { func (h *APIKeyHandler) Create(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
@@ -126,27 +113,21 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
GroupID: req.GroupID, GroupID: req.GroupID,
CustomKey: req.CustomKey, CustomKey: req.CustomKey,
} }
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq) key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, key) response.Success(c, dto.ApiKeyFromService(key))
} }
// Update handles updating an API key // Update handles updating an API key
// PUT /api/v1/api-keys/:id // PUT /api/v1/api-keys/:id
func (h *APIKeyHandler) Update(c *gin.Context) { func (h *APIKeyHandler) Update(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
@@ -171,27 +152,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
svcReq.Status = &req.Status svcReq.Status = &req.Status
} }
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq) key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, key) response.Success(c, dto.ApiKeyFromService(key))
} }
// Delete handles deleting an API key // Delete handles deleting an API key
// DELETE /api/v1/api-keys/:id // DELETE /api/v1/api-keys/:id
func (h *APIKeyHandler) Delete(c *gin.Context) { func (h *APIKeyHandler) Delete(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
@@ -201,7 +176,7 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
return return
} }
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID) err = h.apiKeyService.Delete(c.Request.Context(), keyID, subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -213,23 +188,21 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
// GetAvailableGroups 获取用户可以绑定的分组列表 // GetAvailableGroups 获取用户可以绑定的分组列表
// GET /api/v1/groups/available // GET /api/v1/groups/available
func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) { func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists { if !ok {
response.Unauthorized(c, "User not authenticated") response.Unauthorized(c, "User not authenticated")
return return
} }
user, ok := userValue.(*model.User) groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, groups) out := make([]dto.Group, 0, len(groups))
for i := range groups {
out = append(out, *dto.GroupFromService(&groups[i]))
}
response.Success(c, out)
} }

View File

@@ -1,8 +1,9 @@
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -11,12 +12,14 @@ import (
// AuthHandler handles authentication-related requests // AuthHandler handles authentication-related requests
type AuthHandler struct { type AuthHandler struct {
authService *service.AuthService authService *service.AuthService
userService *service.UserService
} }
// NewAuthHandler creates a new AuthHandler // NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *service.AuthService) *AuthHandler { func NewAuthHandler(authService *service.AuthService, userService *service.UserService) *AuthHandler {
return &AuthHandler{ return &AuthHandler{
authService: authService, authService: authService,
userService: userService,
} }
} }
@@ -51,7 +54,7 @@ type LoginRequest struct {
type AuthResponse struct { type AuthResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
User *model.User `json:"user"` User *dto.User `json:"user"`
} }
// Register handles user registration // Register handles user registration
@@ -80,7 +83,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
response.Success(c, AuthResponse{ response.Success(c, AuthResponse{
AccessToken: token, AccessToken: token,
TokenType: "Bearer", TokenType: "Bearer",
User: user, User: dto.UserFromService(user),
}) })
} }
@@ -135,24 +138,24 @@ func (h *AuthHandler) Login(c *gin.Context) {
response.Success(c, AuthResponse{ response.Success(c, AuthResponse{
AccessToken: token, AccessToken: token,
TokenType: "Bearer", TokenType: "Bearer",
User: user, User: dto.UserFromService(user),
}) })
} }
// GetCurrentUser handles getting current authenticated user // GetCurrentUser handles getting current authenticated user
// GET /api/v1/auth/me // GET /api/v1/auth/me
func (h *AuthHandler) GetCurrentUser(c *gin.Context) { func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists { if !ok {
response.Unauthorized(c, "User not authenticated") response.Unauthorized(c, "User not authenticated")
return return
} }
user, ok := userValue.(*model.User) user, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if !ok { if err != nil {
response.InternalError(c, "Invalid user context") response.ErrorFrom(c, err)
return return
} }
response.Success(c, user) response.Success(c, dto.UserFromService(user))
} }

View File

@@ -0,0 +1,310 @@
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
func UserFromServiceShallow(u *service.User) *User {
if u == nil {
return nil
}
return &User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: u.AllowedGroups,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func UserFromService(u *service.User) *User {
if u == nil {
return nil
}
out := UserFromServiceShallow(u)
if len(u.ApiKeys) > 0 {
out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys))
for i := range u.ApiKeys {
k := u.ApiKeys[i]
out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k))
}
}
if len(u.Subscriptions) > 0 {
out.Subscriptions = make([]UserSubscription, 0, len(u.Subscriptions))
for i := range u.Subscriptions {
s := u.Subscriptions[i]
out.Subscriptions = append(out.Subscriptions, *UserSubscriptionFromService(&s))
}
}
return out
}
func ApiKeyFromService(k *service.ApiKey) *ApiKey {
if k == nil {
return nil
}
return &ApiKey{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User),
Group: GroupFromServiceShallow(k.Group),
}
}
func GroupFromServiceShallow(g *service.Group) *Group {
if g == nil {
return nil
}
return &Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
}
}
func GroupFromService(g *service.Group) *Group {
if g == nil {
return nil
}
out := GroupFromServiceShallow(g)
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
for i := range g.AccountGroups {
ag := g.AccountGroups[i]
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
}
}
return out
}
func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
}
return &Account{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: a.Credentials,
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs,
}
}
func AccountFromService(a *service.Account) *Account {
if a == nil {
return nil
}
out := AccountFromServiceShallow(a)
out.Proxy = ProxyFromService(a.Proxy)
if len(a.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(a.AccountGroups))
for i := range a.AccountGroups {
ag := a.AccountGroups[i]
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
}
}
if len(a.Groups) > 0 {
out.Groups = make([]*Group, 0, len(a.Groups))
for _, g := range a.Groups {
out.Groups = append(out.Groups, GroupFromServiceShallow(g))
}
}
return out
}
func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup {
if ag == nil {
return nil
}
return &AccountGroup{
AccountID: ag.AccountID,
GroupID: ag.GroupID,
Priority: ag.Priority,
CreatedAt: ag.CreatedAt,
Account: AccountFromServiceShallow(ag.Account),
Group: GroupFromServiceShallow(ag.Group),
}
}
func ProxyFromService(p *service.Proxy) *Proxy {
if p == nil {
return nil
}
return &Proxy{
ID: p.ID,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWithAccountCount {
if p == nil {
return nil
}
return &ProxyWithAccountCount{
Proxy: *ProxyFromService(&p.Proxy),
AccountCount: p.AccountCount,
}
}
func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
if rc == nil {
return nil
}
return &RedeemCode{
ID: rc.ID,
Code: rc.Code,
Type: rc.Type,
Value: rc.Value,
Status: rc.Status,
UsedBy: rc.UsedBy,
UsedAt: rc.UsedAt,
Notes: rc.Notes,
CreatedAt: rc.CreatedAt,
GroupID: rc.GroupID,
ValidityDays: rc.ValidityDays,
User: UserFromServiceShallow(rc.User),
Group: GroupFromServiceShallow(rc.Group),
}
}
func UsageLogFromService(l *service.UsageLog) *UsageLog {
if l == nil {
return nil
}
return &UsageLog{
ID: l.ID,
UserID: l.UserID,
ApiKeyID: l.ApiKeyID,
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens,
OutputTokens: l.OutputTokens,
CacheCreationTokens: l.CacheCreationTokens,
CacheReadTokens: l.CacheReadTokens,
CacheCreation5mTokens: l.CacheCreation5mTokens,
CacheCreation1hTokens: l.CacheCreation1hTokens,
InputCost: l.InputCost,
OutputCost: l.OutputCost,
CacheCreationCost: l.CacheCreationCost,
CacheReadCost: l.CacheReadCost,
TotalCost: l.TotalCost,
ActualCost: l.ActualCost,
RateMultiplier: l.RateMultiplier,
BillingType: l.BillingType,
Stream: l.Stream,
DurationMs: l.DurationMs,
FirstTokenMs: l.FirstTokenMs,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
ApiKey: ApiKeyFromService(l.ApiKey),
Account: AccountFromService(l.Account),
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),
}
}
func SettingFromService(s *service.Setting) *Setting {
if s == nil {
return nil
}
return &Setting{
ID: s.ID,
Key: s.Key,
Value: s.Value,
UpdatedAt: s.UpdatedAt,
}
}
func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscription {
if sub == nil {
return nil
}
return &UserSubscription{
ID: sub.ID,
UserID: sub.UserID,
GroupID: sub.GroupID,
StartsAt: sub.StartsAt,
ExpiresAt: sub.ExpiresAt,
Status: sub.Status,
DailyWindowStart: sub.DailyWindowStart,
WeeklyWindowStart: sub.WeeklyWindowStart,
MonthlyWindowStart: sub.MonthlyWindowStart,
DailyUsageUSD: sub.DailyUsageUSD,
WeeklyUsageUSD: sub.WeeklyUsageUSD,
MonthlyUsageUSD: sub.MonthlyUsageUSD,
AssignedBy: sub.AssignedBy,
AssignedAt: sub.AssignedAt,
Notes: sub.Notes,
CreatedAt: sub.CreatedAt,
UpdatedAt: sub.UpdatedAt,
User: UserFromServiceShallow(sub.User),
Group: GroupFromServiceShallow(sub.Group),
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
}
}
func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult {
if r == nil {
return nil
}
subs := make([]UserSubscription, 0, len(r.Subscriptions))
for i := range r.Subscriptions {
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
}
return &BulkAssignResult{
SuccessCount: r.SuccessCount,
FailedCount: r.FailedCount,
Subscriptions: subs,
Errors: r.Errors,
}
}

View File

@@ -0,0 +1,43 @@
package dto
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
Version string `json:"version"`
}

View File

@@ -0,0 +1,219 @@
package dto
import "time"
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
Status string `json:"status"`
AllowedGroups []int64 `json:"allowed_groups"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ApiKeys []ApiKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
type ApiKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Key string `json:"key"`
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
type Group struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
Status string `json:"status"`
SubscriptionType string `json:"subscription_type"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
}
type Account struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Schedulable bool `json:"schedulable"`
RateLimitedAt *time.Time `json:"rate_limited_at"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"`
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `json:"session_window_status"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
GroupIDs []int64 `json:"group_ids,omitempty"`
Groups []*Group `json:"groups,omitempty"`
}
type AccountGroup struct {
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
Priority int `json:"priority"`
CreatedAt time.Time `json:"created_at"`
Account *Account `json:"account,omitempty"`
Group *Group `json:"group,omitempty"`
}
type Proxy struct {
ID int64 `json:"id"`
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"-"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
}
type RedeemCode struct {
ID int64 `json:"id"`
Code string `json:"code"`
Type string `json:"type"`
Value float64 `json:"value"`
Status string `json:"status"`
UsedBy *int64 `json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
type UsageLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationTokens int `json:"cache_creation_tokens"`
CacheReadTokens int `json:"cache_read_tokens"`
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
BillingType int8 `json:"billing_type"`
Stream bool `json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
ApiKey *ApiKey `json:"api_key,omitempty"`
Account *Account `json:"account,omitempty"`
Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"`
}
type Setting struct {
ID int64 `json:"id"`
Key string `json:"key"`
Value string `json:"value"`
UpdatedAt time.Time `json:"updated_at"`
}
type UserSubscription struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
GroupID int64 `json:"group_id"`
StartsAt time.Time `json:"starts_at"`
ExpiresAt time.Time `json:"expires_at"`
Status string `json:"status"`
DailyWindowStart *time.Time `json:"daily_window_start"`
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
DailyUsageUSD float64 `json:"daily_usage_usd"`
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
AssignedBy *int64 `json:"assigned_by"`
AssignedAt time.Time `json:"assigned_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
}
type BulkAssignResult struct {
SuccessCount int `json:"success_count"`
FailedCount int `json:"failed_count"`
Subscriptions []UserSubscription `json:"subscriptions"`
Errors []string `json:"errors"`
}

View File

@@ -10,7 +10,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -55,7 +54,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
user, ok := middleware2.GetUserFromContext(c) subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return return
@@ -90,8 +89,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. 检查wait队列是否已满 // 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(user.Concurrency) maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if err != nil { if err != nil {
log.Printf("Increment wait count failed: %v", err) log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed // On error, allow request to proceed
@@ -100,10 +99,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 确保在函数退出时减少wait计数 // 确保在函数退出时减少wait计数
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID) defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. 首先获取用户并发槽位 // 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, req.Stream, &streamStarted) userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted)
if err != nil { if err != nil {
log.Printf("User concurrency acquire failed: %v", err) log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
@@ -114,7 +113,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 2. 【新增】Wait后二次检查余额/订阅 // 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err) log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return return
@@ -151,7 +150,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 3. 获取账号并发槽位 // 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, req.Stream, &streamStarted) accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
@@ -181,7 +180,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
ApiKey: apiKey, ApiKey: apiKey,
User: user, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
}); err != nil { }); err != nil {
@@ -221,7 +220,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return return
} }
user, ok := middleware2.GetUserFromContext(c) subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return return
@@ -246,7 +245,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
} }
// 余额模式:返回钱包余额 // 余额模式:返回钱包余额
latestUser, err := h.userService.GetByID(c.Request.Context(), user.ID) latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info") h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return return
@@ -264,7 +263,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
// 逻辑: // 逻辑:
// 1. 如果日/周/月任一限额达到100%返回0 // 1. 如果日/周/月任一限额达到100%返回0
// 2. 否则返回所有已配置周期中剩余额度的最小值 // 2. 否则返回所有已配置周期中剩余额度的最小值
func (h *GatewayHandler) calculateSubscriptionRemaining(group *model.Group, sub *model.UserSubscription) float64 { func (h *GatewayHandler) calculateSubscriptionRemaining(group *service.Group, sub *service.UserSubscription) float64 {
var remainingValues []float64 var remainingValues []float64
// 检查日限额 // 检查日限额
@@ -357,7 +356,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return return
} }
user, ok := middleware2.GetUserFromContext(c) _, ok = middleware2.GetAuthSubjectFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return return
@@ -389,7 +388,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility订阅/余额) // 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额 // 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error()) h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
return return
} }

View File

@@ -6,7 +6,6 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -69,11 +68,11 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait. // For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun. // streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context() ctx := c.Request.Context()
// Try to acquire immediately // Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency) result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -83,17 +82,17 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.
} }
// Need to wait - handle streaming ping if needed // Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted) return h.waitForSlotWithPing(c, "user", userID, maxConcurrency, isStream, streamStarted)
} }
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary. // AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait. // For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun. // streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) { func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context() ctx := c.Request.Context()
// Try to acquire immediately // Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency) result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -103,7 +102,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *
} }
// Need to wait - handle streaming ping if needed // Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted) return h.waitForSlotWithPing(c, "account", accountID, maxConcurrency, isStream, streamStarted)
} }
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.

View File

@@ -46,7 +46,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
user, ok := middleware2.GetUserFromContext(c) subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok { if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return return
@@ -94,8 +94,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full // 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(user.Concurrency) maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if err != nil { if err != nil {
log.Printf("Increment wait count failed: %v", err) log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed // On error, allow request to proceed
@@ -104,10 +104,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return return
} }
// Ensure wait count is decremented when function exits // Ensure wait count is decremented when function exits
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID) defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. First acquire user concurrency slot // 1. First acquire user concurrency slot
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, reqStream, &streamStarted) userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil { if err != nil {
log.Printf("User concurrency acquire failed: %v", err) log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
@@ -118,7 +118,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} }
// 2. Re-check billing eligibility after wait // 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err) log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return return
@@ -138,7 +138,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot // 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, reqStream, &streamStarted) accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil { if err != nil {
log.Printf("Account concurrency acquire failed: %v", err) log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted) h.handleConcurrencyError(c, err, "account", streamStarted)
@@ -163,7 +163,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result, Result: result,
ApiKey: apiKey, ApiKey: apiKey,
User: user, User: apiKey.User,
Account: account, Account: account,
Subscription: subscription, Subscription: subscription,
}); err != nil { }); err != nil {

View File

@@ -1,8 +1,9 @@
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -37,15 +38,9 @@ type RedeemResponse struct {
// Redeem handles redeeming a code // Redeem handles redeeming a code
// POST /api/v1/redeem // POST /api/v1/redeem
func (h *RedeemHandler) Redeem(c *gin.Context) { func (h *RedeemHandler) Redeem(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
@@ -55,38 +50,36 @@ func (h *RedeemHandler) Redeem(c *gin.Context) {
return return
} }
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code) result, err := h.redeemService.Redeem(c.Request.Context(), subject.UserID, req.Code)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, result) response.Success(c, dto.RedeemCodeFromService(result))
} }
// GetHistory returns the user's redemption history // GetHistory returns the user's redemption history
// GET /api/v1/redeem/history // GET /api/v1/redeem/history
func (h *RedeemHandler) GetHistory(c *gin.Context) { func (h *RedeemHandler) GetHistory(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
// Default limit is 25 // Default limit is 25
limit := 25 limit := 25
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit) codes, err := h.redeemService.GetUserHistory(c.Request.Context(), subject.UserID, limit)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, codes) out := make([]dto.RedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
}
response.Success(c, out)
} }

View File

@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -30,6 +31,17 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
return return
} }
settings.Version = h.version response.Success(c, dto.PublicSettings{
response.Success(c, settings) RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
Version: h.version,
})
} }

View File

@@ -1,8 +1,9 @@
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -25,7 +26,7 @@ type SubscriptionSummaryItem struct {
// SubscriptionProgressInfo represents subscription with progress info // SubscriptionProgressInfo represents subscription with progress info
type SubscriptionProgressInfo struct { type SubscriptionProgressInfo struct {
Subscription *model.UserSubscription `json:"subscription"` Subscription *dto.UserSubscription `json:"subscription"`
Progress *service.SubscriptionProgress `json:"progress"` Progress *service.SubscriptionProgress `json:"progress"`
} }
@@ -44,68 +45,58 @@ func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *S
// List handles listing current user's subscriptions // List handles listing current user's subscriptions
// GET /api/v1/subscriptions // GET /api/v1/subscriptions
func (h *SubscriptionHandler) List(c *gin.Context) { func (h *SubscriptionHandler) List(c *gin.Context) {
user, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists { if !ok {
response.Unauthorized(c, "User not found in context") response.Unauthorized(c, "User not found in context")
return return
} }
u, ok := user.(*model.User) subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), subject.UserID)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, subscriptions) out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
} }
// GetActive handles getting current user's active subscriptions // GetActive handles getting current user's active subscriptions
// GET /api/v1/subscriptions/active // GET /api/v1/subscriptions/active
func (h *SubscriptionHandler) GetActive(c *gin.Context) { func (h *SubscriptionHandler) GetActive(c *gin.Context) {
user, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists { if !ok {
response.Unauthorized(c, "User not found in context") response.Unauthorized(c, "User not found in context")
return return
} }
u, ok := user.(*model.User) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, subscriptions) out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
} }
// GetProgress handles getting subscription progress for current user // GetProgress handles getting subscription progress for current user
// GET /api/v1/subscriptions/progress // GET /api/v1/subscriptions/progress
func (h *SubscriptionHandler) GetProgress(c *gin.Context) { func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
user, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists { if !ok {
response.Unauthorized(c, "User not found in context") response.Unauthorized(c, "User not found in context")
return return
} }
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions with progress // Get all active subscriptions with progress
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -120,7 +111,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
continue continue
} }
result = append(result, SubscriptionProgressInfo{ result = append(result, SubscriptionProgressInfo{
Subscription: sub, Subscription: dto.UserSubscriptionFromService(sub),
Progress: progress, Progress: progress,
}) })
} }
@@ -131,20 +122,14 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
// GetSummary handles getting a summary of current user's subscription status // GetSummary handles getting a summary of current user's subscription status
// GET /api/v1/subscriptions/summary // GET /api/v1/subscriptions/summary
func (h *SubscriptionHandler) GetSummary(c *gin.Context) { func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
user, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists { if !ok {
response.Unauthorized(c, "User not found in context") response.Unauthorized(c, "User not found in context")
return return
} }
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions // Get all active subscriptions
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@@ -4,10 +4,11 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -30,15 +31,9 @@ func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.
// List handles listing usage records with pagination // List handles listing usage records with pagination
// GET /api/v1/usage // GET /api/v1/usage
func (h *UsageHandler) List(c *gin.Context) { func (h *UsageHandler) List(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
@@ -58,7 +53,7 @@ func (h *UsageHandler) List(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if apiKey.UserID != user.ID { if apiKey.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this API key's usage records") response.Forbidden(c, "Not authorized to access this API key's usage records")
return return
} }
@@ -67,35 +62,33 @@ func (h *UsageHandler) List(c *gin.Context) {
} }
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
var records []model.UsageLog var records []service.UsageLog
var result *pagination.PaginationResult var result *pagination.PaginationResult
var err error var err error
if apiKeyID > 0 { if apiKeyID > 0 {
records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params) records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params)
} else { } else {
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params) records, result, err = h.usageService.ListByUser(c.Request.Context(), subject.UserID, params)
} }
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Paginated(c, records, result.Total, page, pageSize) out := make([]dto.UsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromService(&records[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
} }
// GetByID handles getting a single usage record // GetByID handles getting a single usage record
// GET /api/v1/usage/:id // GET /api/v1/usage/:id
func (h *UsageHandler) GetByID(c *gin.Context) { func (h *UsageHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
@@ -112,26 +105,20 @@ func (h *UsageHandler) GetByID(c *gin.Context) {
} }
// 验证所有权 // 验证所有权
if record.UserID != user.ID { if record.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this record") response.Forbidden(c, "Not authorized to access this record")
return return
} }
response.Success(c, record) response.Success(c, dto.UsageLogFromService(record))
} }
// Stats handles getting usage statistics // Stats handles getting usage statistics
// GET /api/v1/usage/stats // GET /api/v1/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) { func (h *UsageHandler) Stats(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
@@ -149,7 +136,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
response.NotFound(c, "API key not found") response.NotFound(c, "API key not found")
return return
} }
if apiKey.UserID != user.ID { if apiKey.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this API key's statistics") response.Forbidden(c, "Not authorized to access this API key's statistics")
return return
} }
@@ -201,7 +188,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if apiKeyID > 0 { if apiKeyID > 0 {
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime) stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
} else { } else {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime) stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
} }
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
@@ -245,19 +232,13 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
// DashboardStats handles getting user dashboard statistics // DashboardStats handles getting user dashboard statistics
// GET /api/v1/usage/dashboard/stats // GET /api/v1/usage/dashboard/stats
func (h *UsageHandler) DashboardStats(c *gin.Context) { func (h *UsageHandler) DashboardStats(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists { if !ok {
response.Unauthorized(c, "User not authenticated") response.Unauthorized(c, "User not authenticated")
return return
} }
user, ok := userValue.(*model.User) stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -269,22 +250,16 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
// DashboardTrend handles getting user usage trend data // DashboardTrend handles getting user usage trend data
// GET /api/v1/usage/dashboard/trend // GET /api/v1/usage/dashboard/trend
func (h *UsageHandler) DashboardTrend(c *gin.Context) { func (h *UsageHandler) DashboardTrend(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
startTime, endTime := parseUserTimeRange(c) startTime, endTime := parseUserTimeRange(c)
granularity := c.DefaultQuery("granularity", "day") granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity) trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), subject.UserID, startTime, endTime, granularity)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -301,21 +276,15 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
// DashboardModels handles getting user model usage statistics // DashboardModels handles getting user model usage statistics
// GET /api/v1/usage/dashboard/models // GET /api/v1/usage/dashboard/models
func (h *UsageHandler) DashboardModels(c *gin.Context) { func (h *UsageHandler) DashboardModels(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
startTime, endTime := parseUserTimeRange(c) startTime, endTime := parseUserTimeRange(c)
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime) stats, err := h.usageService.GetUserModelStats(c.Request.Context(), subject.UserID, startTime, endTime)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -336,15 +305,9 @@ type BatchApiKeysUsageRequest struct {
// DashboardApiKeysUsage handles getting usage stats for user's own API keys // DashboardApiKeysUsage handles getting usage stats for user's own API keys
// POST /api/v1/usage/dashboard/api-keys-usage // POST /api/v1/usage/dashboard/api-keys-usage
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) { func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
@@ -360,7 +323,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
} }
// Verify ownership of all requested API keys // Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000}) userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, pagination.PaginationParams{Page: 1, PageSize: 1000})
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return

View File

@@ -1,8 +1,9 @@
package handler package handler
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -35,19 +36,13 @@ type UpdateProfileRequest struct {
// GetProfile handles getting user profile // GetProfile handles getting user profile
// GET /api/v1/users/me // GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) { func (h *UserHandler) GetProfile(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists { if !ok {
response.Unauthorized(c, "User not authenticated") response.Unauthorized(c, "User not authenticated")
return return
} }
user, ok := userValue.(*model.User) userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -56,21 +51,15 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
// 清空notes字段普通用户不应看到备注 // 清空notes字段普通用户不应看到备注
userData.Notes = "" userData.Notes = ""
response.Success(c, userData) response.Success(c, dto.UserFromService(userData))
} }
// ChangePassword handles changing user password // ChangePassword handles changing user password
// POST /api/v1/users/me/password // POST /api/v1/users/me/password
func (h *UserHandler) ChangePassword(c *gin.Context) { func (h *UserHandler) ChangePassword(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
@@ -84,7 +73,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
CurrentPassword: req.OldPassword, CurrentPassword: req.OldPassword,
NewPassword: req.NewPassword, NewPassword: req.NewPassword,
} }
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq) err := h.userService.ChangePassword(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -96,15 +85,9 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
// UpdateProfile handles updating user profile // UpdateProfile handles updating user profile
// PUT /api/v1/users/me // PUT /api/v1/users/me
func (h *UserHandler) UpdateProfile(c *gin.Context) { func (h *UserHandler) UpdateProfile(c *gin.Context) {
userValue, exists := c.Get("user") subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok { if !ok {
response.InternalError(c, "Invalid user context") response.Unauthorized(c, "User not authenticated")
return return
} }
@@ -118,7 +101,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
Username: req.Username, Username: req.Username,
Wechat: req.Wechat, Wechat: req.Wechat,
} }
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq) updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
@@ -127,5 +110,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
// 清空notes字段普通用户不应看到备注 // 清空notes字段普通用户不应看到备注
updatedUser.Notes = "" updatedUser.Notes = ""
response.Success(c, updatedUser) response.Success(c, dto.UserFromService(updatedUser))
} }

View File

@@ -2,8 +2,8 @@ package infrastructure
import ( import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/repository"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
@@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 自动迁移(始终执行,确保数据库结构与代码同步) // 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的 // GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil { if err := repository.AutoMigrate(db); err != nil {
return nil, err return nil, err
} }

View File

@@ -1,450 +0,0 @@
package model
import (
"database/sql/driver"
"encoding/json"
"errors"
"strconv"
"time"
"gorm.io/gorm"
)
// JSONB 用于存储JSONB数据
type JSONB map[string]any
func (j JSONB) Value() (driver.Value, error) {
if j == nil {
return nil, nil
}
return json.Marshal(j)
}
func (j *JSONB) Scan(value any) error {
if value == nil {
*j = nil
return nil
}
bytes, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(bytes, j)
}
type Account struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Platform string `gorm:"size:50;not null" json:"platform"` // anthropic/openai/gemini
Type string `gorm:"size:20;not null" json:"type"` // oauth/apikey
Credentials JSONB `gorm:"type:jsonb;default:'{}'" json:"credentials"` // 凭证(加密存储)
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
ProxyID *int64 `gorm:"index" json:"proxy_id"`
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100越小越高
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
ErrorMessage string `gorm:"type:text" json:"error_message"`
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 调度控制
Schedulable bool `gorm:"default:true;not null" json:"schedulable"`
// 限流状态 (429)
RateLimitedAt *time.Time `gorm:"index" json:"rate_limited_at"`
RateLimitResetAt *time.Time `gorm:"index" json:"rate_limit_reset_at"`
// 过载状态 (529)
OverloadUntil *time.Time `gorm:"index" json:"overload_until"`
// 5小时时间窗口
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `gorm:"size:20" json:"session_window_status"` // allowed/allowed_warning/rejected
// 关联
Proxy *Proxy `gorm:"foreignKey:ProxyID" json:"proxy,omitempty"`
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
Groups []*Group `gorm:"-" json:"groups,omitempty"`
}
func (Account) TableName() string {
return "accounts"
}
// IsActive 检查是否激活
func (a *Account) IsActive() bool {
return a.Status == "active"
}
// IsSchedulable 检查账号是否可调度
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
}
now := time.Now()
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false
}
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
return true
}
// IsRateLimited 检查是否处于限流状态
func (a *Account) IsRateLimited() bool {
if a.RateLimitResetAt == nil {
return false
}
return time.Now().Before(*a.RateLimitResetAt)
}
// IsOverloaded 检查是否处于过载状态
func (a *Account) IsOverloaded() bool {
if a.OverloadUntil == nil {
return false
}
return time.Now().Before(*a.OverloadUntil)
}
// IsOAuth 检查是否为OAuth类型账号包括oauth和setup-token
func (a *Account) IsOAuth() bool {
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
}
// CanGetUsage 检查账号是否可以获取usage信息只有oauth类型可以setup-token没有profile权限
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}
// GetCredential 获取凭证字段
func (a *Account) GetCredential(key string) string {
if a.Credentials == nil {
return ""
}
if v, ok := a.Credentials[key]; ok {
switch vv := v.(type) {
case string:
return vv
case json.Number:
return vv.String()
case float64:
// JSON numbers decode to float64; keep integer formatting for integer-like values.
i := int64(vv)
if vv == float64(i) {
return strconv.FormatInt(i, 10)
}
return strconv.FormatFloat(vv, 'f', -1, 64)
case float32:
f := float64(vv)
i := int64(f)
if f == float64(i) {
return strconv.FormatInt(i, 10)
}
return strconv.FormatFloat(f, 'f', -1, 64)
case int:
return strconv.FormatInt(int64(vv), 10)
case int64:
return strconv.FormatInt(vv, 10)
case int32:
return strconv.FormatInt(int64(vv), 10)
case uint:
return strconv.FormatUint(uint64(vv), 10)
case uint64:
return strconv.FormatUint(vv, 10)
case uint32:
return strconv.FormatUint(uint64(vv), 10)
}
}
return ""
}
// GetModelMapping 获取模型映射配置
// 返回格式: map[请求模型名]实际模型名
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["model_mapping"]
if !ok || raw == nil {
return nil
}
// 处理map[string]interface{}类型
if m, ok := raw.(map[string]any); ok {
result := make(map[string]string)
for k, v := range m {
if s, ok := v.(string); ok {
result[k] = s
}
}
if len(result) > 0 {
return result
}
}
return nil
}
// IsModelSupported 检查请求的模型是否被该账号支持
// 如果没有设置模型映射,则支持所有模型
func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
return true // 没有映射配置,支持所有模型
}
_, exists := mapping[requestedModel]
return exists
}
// GetMappedModel 获取映射后的实际模型名
// 如果没有映射,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
return requestedModel
}
if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel
}
return requestedModel
}
// GetBaseURL 获取API基础URL用于apikey类型账号
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey {
return ""
}
baseURL := a.GetCredential("base_url")
if baseURL == "" {
return "https://api.anthropic.com" // 默认URL
}
return baseURL
}
// GetExtraString 从Extra字段获取字符串值
func (a *Account) GetExtraString(key string) string {
if a.Extra == nil {
return ""
}
if v, ok := a.Extra[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// IsCustomErrorCodesEnabled 检查是否启用自定义错误码功能(仅适用于 apikey 类型)
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// GetCustomErrorCodes 获取自定义错误码列表
func (a *Account) GetCustomErrorCodes() []int {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["custom_error_codes"]
if !ok || raw == nil {
return nil
}
// 处理 []interface{} 类型JSON反序列化后的格式
if arr, ok := raw.([]any); ok {
result := make([]int, 0, len(arr))
for _, v := range arr {
// JSON 数字默认解析为 float64
if f, ok := v.(float64); ok {
result = append(result, int(f))
}
}
return result
}
return nil
}
// ShouldHandleErrorCode 检查指定错误码是否应该被处理(停止调度/标记限流等)
// 如果未启用自定义错误码或列表为空,返回 true使用默认策略
// 如果启用且列表非空,只有在列表中的错误码才返回 true
func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
if !a.IsCustomErrorCodesEnabled() {
return true // 未启用,使用默认策略
}
codes := a.GetCustomErrorCodes()
if len(codes) == 0 {
return true // 启用但列表为空fallback到默认策略
}
// 检查是否在自定义列表中
for _, code := range codes {
if code == statusCode {
return true
}
}
return false
}
// IsInterceptWarmupEnabled 检查是否启用预热请求拦截
// 启用后标题生成、Warmup等预热请求将返回mock响应不消耗上游token
func (a *Account) IsInterceptWarmupEnabled() bool {
if a.Credentials == nil {
return false
}
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// =============== OpenAI 相关方法 ===============
// IsOpenAI 检查是否为 OpenAI 平台账号
func (a *Account) IsOpenAI() bool {
return a.Platform == PlatformOpenAI
}
// IsAnthropic 检查是否为 Anthropic 平台账号
func (a *Account) IsAnthropic() bool {
return a.Platform == PlatformAnthropic
}
// IsGemini 检查是否为 Gemini 平台账号
func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini
}
// IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号
func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth
}
// IsOpenAIApiKey 检查是否为 OpenAI API Key 类型账号Response 账号)
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
}
// GetOpenAIBaseURL 获取 OpenAI API 基础 URL
// 对于 API Key 类型账号,从 credentials 中获取 base_url
// 对于 OAuth 类型账号,返回默认的 OpenAI API URL
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeApiKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL
}
}
return "https://api.openai.com" // OpenAI 默认 API URL
}
// GetOpenAIAccessToken 获取 OpenAI 访问令牌
func (a *Account) GetOpenAIAccessToken() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("access_token")
}
// GetOpenAIRefreshToken 获取 OpenAI 刷新令牌
func (a *Account) GetOpenAIRefreshToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("refresh_token")
}
// GetOpenAIIDToken 获取 OpenAI ID TokenJWT包含用户信息
func (a *Account) GetOpenAIIDToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("id_token")
}
// GetOpenAIApiKey 获取 OpenAI API Key用于 Response 账号)
func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIApiKey() {
return ""
}
return a.GetCredential("api_key")
}
// GetOpenAIUserAgent 获取 OpenAI 自定义 User-Agent
// 返回空字符串表示透传原始 User-Agent
func (a *Account) GetOpenAIUserAgent() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("user_agent")
}
// GetChatGPTAccountID 获取 ChatGPT 账号 ID从 ID Token 解析)
func (a *Account) GetChatGPTAccountID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_account_id")
}
// GetChatGPTUserID 获取 ChatGPT 用户 ID从 ID Token 解析)
func (a *Account) GetChatGPTUserID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_user_id")
}
// GetOpenAIOrganizationID 获取 OpenAI 组织 ID
func (a *Account) GetOpenAIOrganizationID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("organization_id")
}
// GetOpenAITokenExpiresAt 获取 OpenAI Token 过期时间
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if !a.IsOpenAIOAuth() {
return nil
}
expiresAtStr := a.GetCredential("expires_at")
if expiresAtStr == "" {
return nil
}
// 尝试解析时间
t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil {
// 尝试解析为 Unix 时间戳
if v, ok := a.Credentials["expires_at"].(float64); ok {
t = time.Unix(int64(v), 0)
return &t
}
return nil
}
return &t
}
// IsOpenAITokenExpired 检查 OpenAI Token 是否过期
func (a *Account) IsOpenAITokenExpired() bool {
expiresAt := a.GetOpenAITokenExpiresAt()
if expiresAt == nil {
return false // 没有过期时间信息,假设未过期
}
// 提前 60 秒认为过期,便于刷新
return time.Now().Add(60 * time.Second).After(*expiresAt)
}

View File

@@ -1,20 +0,0 @@
package model
import (
"time"
)
type AccountGroup struct {
AccountID int64 `gorm:"primaryKey" json:"account_id"`
GroupID int64 `gorm:"primaryKey" json:"group_id"`
Priority int `gorm:"default:50;not null" json:"priority"` // 分组内优先级
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 关联
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (AccountGroup) TableName() string {
return "account_groups"
}

View File

@@ -1,32 +0,0 @@
package model
import (
"time"
"gorm.io/gorm"
)
type ApiKey struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
Key string `gorm:"uniqueIndex;size:128;not null" json:"key"` // sk-xxx
Name string `gorm:"size:100;not null" json:"name"`
GroupID *int64 `gorm:"index" json:"group_id"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (ApiKey) TableName() string {
return "api_keys"
}
// IsActive 检查是否激活
func (k *ApiKey) IsActive() bool {
return k.Status == "active"
}

View File

@@ -1,73 +0,0 @@
package model
import (
"time"
"gorm.io/gorm"
)
// 订阅类型常量
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
type Group struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
Description string `gorm:"type:text" json:"description"`
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
// 订阅功能字段
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"monthly_limit_usd"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
AccountGroups []AccountGroup `gorm:"foreignKey:GroupID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
AccountCount int64 `gorm:"-" json:"account_count,omitempty"`
}
func (Group) TableName() string {
return "groups"
}
// IsActive 检查是否激活
func (g *Group) IsActive() bool {
return g.Status == "active"
}
// IsSubscriptionType 检查是否为订阅类型分组
func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额)
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
// HasDailyLimit 检查是否有日限额
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
// HasWeeklyLimit 检查是否有周限额
func (g *Group) HasWeeklyLimit() bool {
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
}
// HasMonthlyLimit 检查是否有月限额
func (g *Group) HasMonthlyLimit() bool {
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
}

View File

@@ -1,64 +0,0 @@
package model
import (
"gorm.io/gorm"
)
// AutoMigrate 自动迁移所有模型
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&User{},
&ApiKey{},
&Group{},
&Account{},
&AccountGroup{},
&Proxy{},
&RedeemCode{},
&UsageLog{},
&Setting{},
&UserSubscription{},
)
}
// 状态常量
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// 角色常量
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// 平台常量
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
)
// 账号类型常量
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeApiKey = "apikey" // API Key类型账号
)
// 卡密类型常量
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
)
// 管理员调整类型常量
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)

View File

@@ -1,45 +0,0 @@
package model
import (
"fmt"
"time"
"gorm.io/gorm"
)
type Proxy struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Protocol string `gorm:"size:20;not null" json:"protocol"` // http/https/socks5
Host string `gorm:"size:255;not null" json:"host"`
Port int `gorm:"not null" json:"port"`
Username string `gorm:"size:100" json:"username"`
Password string `gorm:"size:100" json:"-"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
func (Proxy) TableName() string {
return "proxies"
}
// IsActive 检查是否激活
func (p *Proxy) IsActive() bool {
return p.Status == "active"
}
// URL 返回代理URL
func (p *Proxy) URL() string {
if p.Username != "" && p.Password != "" {
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
}
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
}
// ProxyWithAccountCount extends Proxy with account count information
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
}

View File

@@ -1,50 +0,0 @@
package model
import (
"crypto/rand"
"encoding/hex"
"time"
)
type RedeemCode struct {
ID int64 `gorm:"primaryKey" json:"id"`
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
UsedBy *int64 `gorm:"index" json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 订阅类型专用字段
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
// 关联
User *User `gorm:"foreignKey:UsedBy" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (RedeemCode) TableName() string {
return "redeem_codes"
}
// IsUsed 检查是否已使用
func (r *RedeemCode) IsUsed() bool {
return r.Status == "used"
}
// CanUse 检查是否可以使用
func (r *RedeemCode) CanUse() bool {
return r.Status == "unused"
}
// GenerateRedeemCode 生成唯一的兑换码
func GenerateRedeemCode() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}

View File

@@ -1,104 +0,0 @@
package model
import (
"time"
)
// Setting 系统设置模型Key-Value存储
type Setting struct {
ID int64 `gorm:"primaryKey" json:"id"`
Key string `gorm:"uniqueIndex;size:100;not null" json:"key"`
Value string `gorm:"type:text;not null" json:"value"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
}
func (Setting) TableName() string {
return "settings"
}
// 设置Key常量
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
)
// 管理员 API Key 前缀(与用户 sk- 前缀区分)
const AdminApiKeyPrefix = "admin-"
// SystemSettings 系统设置结构体用于API响应
type SystemSettings struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
// 邮件服务设置
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"` // 不返回明文密码
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` // 不返回明文密钥
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
// PublicSettings 公开设置(无需登录即可获取)
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
Version string `json:"version"`
}

View File

@@ -1,67 +0,0 @@
package model
import (
"time"
)
// 消费类型常量
const (
BillingTypeBalance int8 = 0 // 钱包余额
BillingTypeSubscription int8 = 1 // 订阅套餐
)
type UsageLog struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
ApiKeyID int64 `gorm:"index;not null" json:"api_key_id"`
AccountID int64 `gorm:"index;not null" json:"account_id"`
RequestID string `gorm:"size:64" json:"request_id"`
Model string `gorm:"size:100;index;not null" json:"model"`
// 订阅关联(可选)
GroupID *int64 `gorm:"index" json:"group_id"`
SubscriptionID *int64 `gorm:"index" json:"subscription_id"`
// Token使用量4类
InputTokens int `gorm:"default:0;not null" json:"input_tokens"`
OutputTokens int `gorm:"default:0;not null" json:"output_tokens"`
CacheCreationTokens int `gorm:"default:0;not null" json:"cache_creation_tokens"`
CacheReadTokens int `gorm:"default:0;not null" json:"cache_read_tokens"`
// 详细的缓存创建分类
CacheCreation5mTokens int `gorm:"default:0;not null" json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"default:0;not null" json:"cache_creation_1h_tokens"`
// 费用USD
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"`
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
// 元数据
BillingType int8 `gorm:"type:smallint;default:0;not null" json:"billing_type"` // 0=余额 1=订阅
Stream bool `gorm:"default:false;not null" json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"` // 首字时间(流式请求)
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
ApiKey *ApiKey `gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"`
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
Subscription *UserSubscription `gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"`
}
func (UsageLog) TableName() string {
return "usage_logs"
}
// TotalTokens 总token数
func (u *UsageLog) TotalTokens() int {
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
}

View File

@@ -1,78 +0,0 @@
package model
import (
"time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type User struct {
ID int64 `gorm:"primaryKey" json:"id"`
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
Username string `gorm:"size:100;default:''" json:"username"`
Wechat string `gorm:"size:100;default:''" json:"wechat"`
Notes string `gorm:"type:text;default:''" json:"notes"`
PasswordHash string `gorm:"size:255;not null" json:"-"`
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
Concurrency int `gorm:"default:5;not null" json:"concurrency"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
AllowedGroups pq.Int64Array `gorm:"type:bigint[]" json:"allowed_groups"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
Subscriptions []UserSubscription `gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
}
func (User) TableName() string {
return "users"
}
// IsAdmin 检查是否管理员
func (u *User) IsAdmin() bool {
return u.Role == "admin"
}
// IsActive 检查是否激活
func (u *User) IsActive() bool {
return u.Status == "active"
}
// CanBindGroup 检查是否可以绑定指定分组
// 对于标准类型分组:
// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组
// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
// 如果设置了 allowed_groups 且不为空,只能绑定指定的分组
if len(u.AllowedGroups) > 0 {
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
}
// 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组
return !isExclusive
}
// SetPassword 设置密码(哈希存储)
func (u *User) SetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.PasswordHash = string(hash)
return nil
}
// CheckPassword 验证密码
func (u *User) CheckPassword(password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
return err == nil
}

View File

@@ -1,157 +0,0 @@
package model
import (
"time"
)
// 订阅状态常量
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)
// UserSubscription 用户订阅模型
type UserSubscription struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
GroupID int64 `gorm:"index;not null" json:"group_id"`
// 订阅有效期
StartsAt time.Time `gorm:"not null" json:"starts_at"`
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/expired/suspended
// 滑动窗口起始时间nil = 未激活)
DailyWindowStart *time.Time `json:"daily_window_start"`
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
// 当前窗口已用额度USD基于 total_cost 计算)
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"daily_usage_usd"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"monthly_usage_usd"`
// 管理员分配信息
AssignedBy *int64 `gorm:"index" json:"assigned_by"`
AssignedAt time.Time `gorm:"not null" json:"assigned_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
AssignedByUser *User `gorm:"foreignKey:AssignedBy" json:"assigned_by_user,omitempty"`
}
func (UserSubscription) TableName() string {
return "user_subscriptions"
}
// IsActive 检查订阅是否有效状态为active且未过期
func (s *UserSubscription) IsActive() bool {
return s.Status == SubscriptionStatusActive && time.Now().Before(s.ExpiresAt)
}
// IsExpired 检查订阅是否已过期
func (s *UserSubscription) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
}
// DaysRemaining 返回订阅剩余天数
func (s *UserSubscription) DaysRemaining() int {
if s.IsExpired() {
return 0
}
return int(time.Until(s.ExpiresAt).Hours() / 24)
}
// IsWindowActivated 检查窗口是否已激活
func (s *UserSubscription) IsWindowActivated() bool {
return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil
}
// NeedsDailyReset 检查日窗口是否需要重置
func (s *UserSubscription) NeedsDailyReset() bool {
if s.DailyWindowStart == nil {
return false
}
return time.Since(*s.DailyWindowStart) >= 24*time.Hour
}
// NeedsWeeklyReset 检查周窗口是否需要重置
func (s *UserSubscription) NeedsWeeklyReset() bool {
if s.WeeklyWindowStart == nil {
return false
}
return time.Since(*s.WeeklyWindowStart) >= 7*24*time.Hour
}
// NeedsMonthlyReset 检查月窗口是否需要重置
func (s *UserSubscription) NeedsMonthlyReset() bool {
if s.MonthlyWindowStart == nil {
return false
}
return time.Since(*s.MonthlyWindowStart) >= 30*24*time.Hour
}
// DailyResetTime 返回日窗口重置时间
func (s *UserSubscription) DailyResetTime() *time.Time {
if s.DailyWindowStart == nil {
return nil
}
t := s.DailyWindowStart.Add(24 * time.Hour)
return &t
}
// WeeklyResetTime 返回周窗口重置时间
func (s *UserSubscription) WeeklyResetTime() *time.Time {
if s.WeeklyWindowStart == nil {
return nil
}
t := s.WeeklyWindowStart.Add(7 * 24 * time.Hour)
return &t
}
// MonthlyResetTime 返回月窗口重置时间
func (s *UserSubscription) MonthlyResetTime() *time.Time {
if s.MonthlyWindowStart == nil {
return nil
}
t := s.MonthlyWindowStart.Add(30 * 24 * time.Hour)
return &t
}
// CheckDailyLimit 检查是否超出日限额
func (s *UserSubscription) CheckDailyLimit(group *Group, additionalCost float64) bool {
if !group.HasDailyLimit() {
return true // 无限制
}
return s.DailyUsageUSD+additionalCost <= *group.DailyLimitUSD
}
// CheckWeeklyLimit 检查是否超出周限额
func (s *UserSubscription) CheckWeeklyLimit(group *Group, additionalCost float64) bool {
if !group.HasWeeklyLimit() {
return true // 无限制
}
return s.WeeklyUsageUSD+additionalCost <= *group.WeeklyLimitUSD
}
// CheckMonthlyLimit 检查是否超出月限额
func (s *UserSubscription) CheckMonthlyLimit(group *Group, additionalCost float64) bool {
if !group.HasMonthlyLimit() {
return true // 无限制
}
return s.MonthlyUsageUSD+additionalCost <= *group.MonthlyLimitUSD
}
// CheckAllLimits 检查所有限额
func (s *UserSubscription) CheckAllLimits(group *Group, additionalCost float64) (daily, weekly, monthly bool) {
daily = s.CheckDailyLimit(group, additionalCost)
weekly = s.CheckWeeklyLimit(group, additionalCost)
monthly = s.CheckMonthlyLimit(group, additionalCost)
return
}

View File

@@ -5,10 +5,10 @@ import (
"errors" "errors"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
@@ -21,69 +21,66 @@ func NewAccountRepository(db *gorm.DB) service.AccountRepository {
return &accountRepository{db: db} return &accountRepository{db: db}
} }
func (r *accountRepository) Create(ctx context.Context, account *model.Account) error { func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
return r.db.WithContext(ctx).Create(account).Error m := accountModelFromService(account)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyAccountModelToService(account, m)
}
return err
} }
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) { func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) {
var account model.Account var m accountModel
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil) return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
} }
// 填充 GroupIDs 和 Groups 虚拟字段 return accountModelToService(&m), nil
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
account.Groups = make([]*model.Group, 0, len(account.AccountGroups))
for _, ag := range account.AccountGroups {
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
return &account, nil
} }
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) { func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
if crsAccountID == "" { if crsAccountID == "" {
return nil, nil return nil, nil
} }
var account model.Account var m accountModel
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&m).Error
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil return nil, nil
} }
return nil, err return nil, err
} }
return &account, nil return accountModelToService(&m), nil
} }
func (r *accountRepository) Update(ctx context.Context, account *model.Account) error { func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
return r.db.WithContext(ctx).Save(account).Error m := accountModelFromService(account)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyAccountModelToService(account, m)
}
return err
} }
func (r *accountRepository) Delete(ctx context.Context, id int64) error { func (r *accountRepository) Delete(ctx context.Context, id int64) error {
// 先删除账号与分组的绑定关系 if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&accountGroupModel{}).Error; err != nil {
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return err return err
} }
// 再删除账号 return r.db.WithContext(ctx).Delete(&accountModel{}, id).Error
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
} }
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "") return r.ListWithFilters(ctx, params, "", "", "", "")
} }
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) { var accounts []accountModel
var accounts []model.Account
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.Account{}) db := r.db.WithContext(ctx).Model(&accountModel{})
// Apply filters
if platform != "" { if platform != "" {
db = db.Where("platform = ?", platform) db = db.Where("platform = ?", platform)
} }
@@ -106,67 +103,84 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return nil, nil, err return nil, nil, err
} }
// 填充每个 Account 的虚拟字段GroupIDs 和 Groups outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts { for i := range accounts {
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups)) outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
accounts[i].Groups = make([]*model.Group, 0, len(accounts[i].AccountGroups))
for _, ag := range accounts[i].AccountGroups {
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
if ag.Group != nil {
accounts[i].Groups = append(accounts[i].Groups, ag.Group)
}
}
} }
pages := int(total) / params.Limit() return outAccounts, paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
} }
return accounts, &pagination.PaginationResult{ func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
Total: total, var accounts []accountModel
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, model.StatusActive). Where("account_groups.group_id = ? AND accounts.status = ?", groupID, service.StatusActive).
Preload("Proxy"). Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC"). Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error Find(&accounts).Error
return accounts, err if err != nil {
return nil, err
} }
func (r *accountRepository) ListActive(ctx context.Context) ([]model.Account, error) { outAccounts := make([]service.Account, 0, len(accounts))
var accounts []model.Account for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive). Where("status = ?", service.StatusActive).
Preload("Proxy"). Preload("Proxy").
Order("priority ASC"). Order("priority ASC").
Find(&accounts).Error Find(&accounts).Error
return accounts, err if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, service.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error { func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Update("last_used_at", now).Error
} }
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"status": model.StatusError, "status": service.StatusError,
"error_message": errorMsg, "error_message": errorMsg,
}).Error }).Error
} }
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &model.AccountGroup{ ag := &accountGroupModel{
AccountID: accountID, AccountID: accountID,
GroupID: groupID, GroupID: groupID,
Priority: priority, Priority: priority,
@@ -176,131 +190,148 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error { func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID). return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
Delete(&model.AccountGroup{}).Error Delete(&accountGroupModel{}).Error
} }
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) { func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
var groups []model.Group var groups []groupModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.group_id = groups.id"). Joins("JOIN account_groups ON account_groups.group_id = groups.id").
Where("account_groups.account_id = ?", accountID). Where("account_groups.account_id = ?", accountID).
Find(&groups).Error Find(&groups).Error
return groups, err if err != nil {
return nil, err
} }
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) { outGroups := make([]service.Group, 0, len(groups))
var accounts []model.Account for i := range groups {
err := r.db.WithContext(ctx). outGroups = append(outGroups, *groupModelToService(&groups[i]))
Where("platform = ? AND status = ?", platform, model.StatusActive). }
Preload("Proxy"). return outGroups, nil
Order("priority ASC").
Find(&accounts).Error
return accounts, err
} }
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
// 删除现有绑定 if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&accountGroupModel{}).Error; err != nil {
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
return err return err
} }
// 添加新绑定 if len(groupIDs) == 0 {
if len(groupIDs) > 0 { return nil
accountGroups := make([]model.AccountGroup, 0, len(groupIDs)) }
accountGroups := make([]accountGroupModel, 0, len(groupIDs))
for i, groupID := range groupIDs { for i, groupID := range groupIDs {
accountGroups = append(accountGroups, model.AccountGroup{ accountGroups = append(accountGroups, accountGroupModel{
AccountID: accountID, AccountID: accountID,
GroupID: groupID, GroupID: groupID,
Priority: i + 1, // 使用索引作为优先级 Priority: i + 1,
}) })
} }
return r.db.WithContext(ctx).Create(&accountGroups).Error return r.db.WithContext(ctx).Create(&accountGroups).Error
} }
return nil func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
} var accounts []accountModel
// ListSchedulable 获取所有可调度的账号
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ? AND schedulable = ?", model.StatusActive, true). Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now). Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now). Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy"). Preload("Proxy").
Order("priority ASC"). Order("priority ASC").
Find(&accounts).Error Find(&accounts).Error
return accounts, err if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
// ListSchedulableByGroupID 按组获取可调度的账号 func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) { var accounts []accountModel
var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID). Where("account_groups.group_id = ?", groupID).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true). Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now). Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now). Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy"). Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC"). Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error Find(&accounts).Error
return accounts, err if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
// ListSchedulableByPlatform 按平台获取可调度的账号 func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) { var accounts []accountModel
var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("platform = ?", platform). Where("platform = ?", platform).
Where("status = ? AND schedulable = ?", model.StatusActive, true). Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now). Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now). Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy"). Preload("Proxy").
Order("priority ASC"). Order("priority ASC").
Find(&accounts).Error Find(&accounts).Error
return accounts, err if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号 func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) { var accounts []accountModel
var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID). Where("account_groups.group_id = ?", groupID).
Where("accounts.platform = ?", platform). Where("accounts.platform = ?", platform).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true). Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now). Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now). Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy"). Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC"). Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error Find(&accounts).Error
return accounts, err if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
} }
// SetRateLimited 标记账号为限流状态(429)
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"rate_limited_at": now, "rate_limited_at": now,
"rate_limit_reset_at": resetAt, "rate_limit_reset_at": resetAt,
}).Error }).Error
} }
// SetOverloaded 标记账号为过载状态(529)
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("overload_until", until).Error Update("overload_until", until).Error
} }
// ClearRateLimit 清除账号的限流状态
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error { func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"rate_limited_at": nil, "rate_limited_at": nil,
"rate_limit_reset_at": nil, "rate_limit_reset_at": nil,
@@ -308,7 +339,6 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
}).Error }).Error
} }
// UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]any{ updates := map[string]any{
"session_window_status": status, "session_window_status": status,
@@ -319,45 +349,35 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
if end != nil { if end != nil {
updates["session_window_end"] = end updates["session_window_end"] = end
} }
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Updates(updates).Error return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Updates(updates).Error
} }
// SetSchedulable 设置账号的调度开关
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("schedulable", schedulable).Error Update("schedulable", schedulable).Error
} }
// UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 { if len(updates) == 0 {
return nil return nil
} }
// Get current account to preserve existing Extra data var account accountModel
var account model.Account
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil { if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
return err return err
} }
// Initialize Extra if nil
if account.Extra == nil { if account.Extra == nil {
account.Extra = make(model.JSONB) account.Extra = datatypes.JSONMap{}
} }
// Merge updates into existing Extra
for k, v := range updates { for k, v := range updates {
account.Extra[k] = v account.Extra[k] = v
} }
// Save updated Extra return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("extra", account.Extra).Error Update("extra", account.Extra).Error
} }
// BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them.
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 { if len(ids) == 0 {
return 0, nil return 0, nil
@@ -381,10 +401,10 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
updateMap["status"] = *updates.Status updateMap["status"] = *updates.Status
} }
if len(updates.Credentials) > 0 { if len(updates.Credentials) > 0 {
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials) updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", datatypes.JSONMap(updates.Credentials))
} }
if len(updates.Extra) > 0 { if len(updates.Extra) > 0 {
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra) updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", datatypes.JSONMap(updates.Extra))
} }
if len(updateMap) == 0 { if len(updateMap) == 0 {
@@ -392,10 +412,178 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
} }
result := r.db.WithContext(ctx). result := r.db.WithContext(ctx).
Model(&model.Account{}). Model(&accountModel{}).
Where("id IN ?", ids). Where("id IN ?", ids).
Clauses(clause.Returning{}). Clauses(clause.Returning{}).
Updates(updateMap) Updates(updateMap)
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
type accountModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"size:100;not null"`
Platform string `gorm:"size:50;not null"`
Type string `gorm:"size:20;not null"`
Credentials datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
Extra datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
ProxyID *int64 `gorm:"index"`
Concurrency int `gorm:"default:3;not null"`
Priority int `gorm:"default:50;not null"`
Status string `gorm:"size:20;default:active;not null"`
ErrorMessage string `gorm:"type:text"`
LastUsedAt *time.Time `gorm:"index"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
Schedulable bool `gorm:"default:true;not null"`
RateLimitedAt *time.Time `gorm:"index"`
RateLimitResetAt *time.Time `gorm:"index"`
OverloadUntil *time.Time `gorm:"index"`
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string `gorm:"size:20"`
Proxy *proxyModel `gorm:"foreignKey:ProxyID"`
AccountGroups []accountGroupModel `gorm:"foreignKey:AccountID"`
}
func (accountModel) TableName() string { return "accounts" }
type accountGroupModel struct {
AccountID int64 `gorm:"primaryKey"`
GroupID int64 `gorm:"primaryKey"`
Priority int `gorm:"default:50;not null"`
CreatedAt time.Time `gorm:"not null"`
Account *accountModel `gorm:"foreignKey:AccountID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (accountGroupModel) TableName() string { return "account_groups" }
func accountGroupModelToService(m *accountGroupModel) *service.AccountGroup {
if m == nil {
return nil
}
return &service.AccountGroup{
AccountID: m.AccountID,
GroupID: m.GroupID,
Priority: m.Priority,
CreatedAt: m.CreatedAt,
Account: accountModelToService(m.Account),
Group: groupModelToService(m.Group),
}
}
func accountModelToService(m *accountModel) *service.Account {
if m == nil {
return nil
}
var credentials map[string]any
if m.Credentials != nil {
credentials = map[string]any(m.Credentials)
}
var extra map[string]any
if m.Extra != nil {
extra = map[string]any(m.Extra)
}
account := &service.Account{
ID: m.ID,
Name: m.Name,
Platform: m.Platform,
Type: m.Type,
Credentials: credentials,
Extra: extra,
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
Status: m.Status,
ErrorMessage: m.ErrorMessage,
LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,
RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil,
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: m.SessionWindowStatus,
Proxy: proxyModelToService(m.Proxy),
}
if len(m.AccountGroups) > 0 {
account.AccountGroups = make([]service.AccountGroup, 0, len(m.AccountGroups))
account.GroupIDs = make([]int64, 0, len(m.AccountGroups))
account.Groups = make([]*service.Group, 0, len(m.AccountGroups))
for i := range m.AccountGroups {
ag := accountGroupModelToService(&m.AccountGroups[i])
if ag == nil {
continue
}
account.AccountGroups = append(account.AccountGroups, *ag)
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
}
return account
}
func accountModelFromService(a *service.Account) *accountModel {
if a == nil {
return nil
}
var credentials datatypes.JSONMap
if a.Credentials != nil {
credentials = datatypes.JSONMap(a.Credentials)
}
var extra datatypes.JSONMap
if a.Extra != nil {
extra = datatypes.JSONMap(a.Extra)
}
return &accountModel{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: credentials,
Extra: extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
}
}
func applyAccountModelToService(account *service.Account, m *accountModel) {
if account == nil || m == nil {
return
}
account.ID = m.ID
account.CreatedAt = m.CreatedAt
account.UpdatedAt = m.UpdatedAt
}

View File

@@ -7,10 +7,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -34,11 +34,16 @@ func TestAccountRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete --- // --- Create / GetByID / Update / Delete ---
func (s *AccountRepoSuite) TestCreate() { func (s *AccountRepoSuite) TestCreate() {
account := &model.Account{ account := &service.Account{
Name: "test-create", Name: "test-create",
Platform: model.PlatformAnthropic, Platform: service.PlatformAnthropic,
Type: model.AccountTypeOAuth, Type: service.AccountTypeOAuth,
Status: model.StatusActive, Status: service.StatusActive,
Credentials: map[string]any{},
Extra: map[string]any{},
Concurrency: 3,
Priority: 50,
Schedulable: true,
} }
err := s.repo.Create(s.ctx, account) err := s.repo.Create(s.ctx, account)
@@ -56,7 +61,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() {
} }
func (s *AccountRepoSuite) TestUpdate() { func (s *AccountRepoSuite) TestUpdate() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "original"}) account := accountModelToService(mustCreateAccount(s.T(), s.db, &accountModel{Name: "original"}))
account.Name = "updated" account.Name = "updated"
err := s.repo.Update(s.ctx, account) err := s.repo.Update(s.ctx, account)
@@ -68,7 +73,7 @@ func (s *AccountRepoSuite) TestUpdate() {
} }
func (s *AccountRepoSuite) TestDelete() { func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "to-delete"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, account.ID) err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
@@ -78,23 +83,23 @@ func (s *AccountRepoSuite) TestDelete() {
} }
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() { func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
err := s.repo.Delete(s.ctx, account.ID) err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete should cascade remove bindings") s.Require().NoError(err, "Delete should cascade remove bindings")
var count int64 var count int64
s.db.Model(&model.AccountGroup{}).Where("account_id = ?", account.ID).Count(&count) s.db.Model(&accountGroupModel{}).Where("account_id = ?", account.ID).Count(&count)
s.Require().Zero(count, "expected bindings to be removed") s.Require().Zero(count, "expected bindings to be removed")
} }
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *AccountRepoSuite) TestList() { func (s *AccountRepoSuite) TestList() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc2"}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc2"})
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
@@ -111,53 +116,53 @@ func (s *AccountRepoSuite) TestListWithFilters() {
status string status string
search string search string
wantCount int wantCount int
validate func(accounts []model.Account) validate func(accounts []service.Account)
}{ }{
{ {
name: "filter_by_platform", name: "filter_by_platform",
setup: func(db *gorm.DB) { setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic}) mustCreateAccount(s.T(), db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic})
mustCreateAccount(s.T(), db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI}) mustCreateAccount(s.T(), db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI})
}, },
platform: model.PlatformOpenAI, platform: service.PlatformOpenAI,
wantCount: 1, wantCount: 1,
validate: func(accounts []model.Account) { validate: func(accounts []service.Account) {
s.Require().Equal(model.PlatformOpenAI, accounts[0].Platform) s.Require().Equal(service.PlatformOpenAI, accounts[0].Platform)
}, },
}, },
{ {
name: "filter_by_type", name: "filter_by_type",
setup: func(db *gorm.DB) { setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "t1", Type: model.AccountTypeOAuth}) mustCreateAccount(s.T(), db, &accountModel{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), db, &model.Account{Name: "t2", Type: model.AccountTypeApiKey}) mustCreateAccount(s.T(), db, &accountModel{Name: "t2", Type: service.AccountTypeApiKey})
}, },
accType: model.AccountTypeApiKey, accType: service.AccountTypeApiKey,
wantCount: 1, wantCount: 1,
validate: func(accounts []model.Account) { validate: func(accounts []service.Account) {
s.Require().Equal(model.AccountTypeApiKey, accounts[0].Type) s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
}, },
}, },
{ {
name: "filter_by_status", name: "filter_by_status",
setup: func(db *gorm.DB) { setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "s1", Status: model.StatusActive}) mustCreateAccount(s.T(), db, &accountModel{Name: "s1", Status: service.StatusActive})
mustCreateAccount(s.T(), db, &model.Account{Name: "s2", Status: model.StatusDisabled}) mustCreateAccount(s.T(), db, &accountModel{Name: "s2", Status: service.StatusDisabled})
}, },
status: model.StatusDisabled, status: service.StatusDisabled,
wantCount: 1, wantCount: 1,
validate: func(accounts []model.Account) { validate: func(accounts []service.Account) {
s.Require().Equal(model.StatusDisabled, accounts[0].Status) s.Require().Equal(service.StatusDisabled, accounts[0].Status)
}, },
}, },
{ {
name: "filter_by_search", name: "filter_by_search",
setup: func(db *gorm.DB) { setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "alpha-account"}) mustCreateAccount(s.T(), db, &accountModel{Name: "alpha-account"})
mustCreateAccount(s.T(), db, &model.Account{Name: "beta-account"}) mustCreateAccount(s.T(), db, &accountModel{Name: "beta-account"})
}, },
search: "alpha", search: "alpha",
wantCount: 1, wantCount: 1,
validate: func(accounts []model.Account) { validate: func(accounts []service.Account) {
s.Require().Contains(accounts[0].Name, "alpha") s.Require().Contains(accounts[0].Name, "alpha")
}, },
}, },
@@ -185,9 +190,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// --- ListByGroup / ListActive / ListByPlatform --- // --- ListByGroup / ListActive / ListByPlatform ---
func (s *AccountRepoSuite) TestListByGroup() { func (s *AccountRepoSuite) TestListByGroup() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Status: model.StatusActive}) acc1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Status: service.StatusActive})
acc2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Status: model.StatusActive}) acc2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Status: service.StatusActive})
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2) mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
@@ -199,8 +204,8 @@ func (s *AccountRepoSuite) TestListByGroup() {
} }
func (s *AccountRepoSuite) TestListActive() { func (s *AccountRepoSuite) TestListActive() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "active1", Status: model.StatusActive}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "active1", Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "inactive1", Status: model.StatusDisabled}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "inactive1", Status: service.StatusDisabled})
accounts, err := s.repo.ListActive(s.ctx) accounts, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive") s.Require().NoError(err, "ListActive")
@@ -209,22 +214,22 @@ func (s *AccountRepoSuite) TestListActive() {
} }
func (s *AccountRepoSuite) TestListByPlatform() { func (s *AccountRepoSuite) TestListByPlatform() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p1", Platform: model.PlatformAnthropic, Status: model.StatusActive}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p2", Platform: model.PlatformOpenAI, Status: model.StatusActive}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
accounts, err := s.repo.ListByPlatform(s.ctx, model.PlatformAnthropic) accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListByPlatform") s.Require().NoError(err, "ListByPlatform")
s.Require().Len(accounts, 1) s.Require().Len(accounts, 1)
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform) s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
} }
// --- Preload and VirtualFields --- // --- Preload and VirtualFields ---
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
account := mustCreateAccount(s.T(), s.db, &model.Account{ account := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc1", Name: "acc1",
ProxyID: &proxy.ID, ProxyID: &proxy.ID,
}) })
@@ -252,9 +257,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups --- // --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() { func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"}) g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"}) g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc"})
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup") s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
groups, err := s.repo.GetGroups(s.ctx, account.ID) groups, err := s.repo.GetGroups(s.ctx, account.ID)
@@ -274,8 +279,8 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
} }
func (s *AccountRepoSuite) TestBindGroups_EmptyList() { func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-empty"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty") s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
@@ -289,13 +294,13 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
func (s *AccountRepoSuite) TestListSchedulable() { func (s *AccountRepoSuite) TestListSchedulable() {
now := time.Now() now := time.Now()
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true}) okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute) future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future}) overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
sched, err := s.repo.ListSchedulable(s.ctx) sched, err := s.repo.ListSchedulable(s.ctx)
@@ -307,16 +312,16 @@ func (s *AccountRepoSuite) TestListSchedulable() {
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() { func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
now := time.Now() now := time.Now()
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true}) okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute) future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future}) overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.db, &model.Account{Name: "rl", Schedulable: true}) rateLimited := mustCreateAccount(s.T(), s.db, &accountModel{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited") s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
@@ -334,30 +339,30 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu
} }
func (s *AccountRepoSuite) TestListSchedulableByPlatform() { func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, model.PlatformAnthropic) accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(accounts, 1) s.Require().Len(accounts, 1)
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform) s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
} }
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() { func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sp"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true}) a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true}) a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2) mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, model.PlatformAnthropic) accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(accounts, 1) s.Require().Len(accounts, 1)
s.Require().Equal(a1.ID, accounts[0].ID) s.Require().Equal(a1.ID, accounts[0].ID)
} }
func (s *AccountRepoSuite) TestSetSchedulable() { func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-sched", Schedulable: true}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-sched", Schedulable: true})
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false)) s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
@@ -369,7 +374,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() {
// --- SetOverloaded / SetRateLimited / ClearRateLimit --- // --- SetOverloaded / SetRateLimited / ClearRateLimit ---
func (s *AccountRepoSuite) TestSetOverloaded() { func (s *AccountRepoSuite) TestSetOverloaded() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-over"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-over"})
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC) until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until)) s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
@@ -381,7 +386,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() {
} }
func (s *AccountRepoSuite) TestSetRateLimited() { func (s *AccountRepoSuite) TestSetRateLimited() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-rl"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-rl"})
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC) resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt)) s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
@@ -394,7 +399,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() {
} }
func (s *AccountRepoSuite) TestClearRateLimit() { func (s *AccountRepoSuite) TestClearRateLimit() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-clear"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-clear"})
until := time.Now().Add(1 * time.Hour) until := time.Now().Add(1 * time.Hour)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until)) s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until)) s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
@@ -411,7 +416,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
// --- UpdateLastUsed --- // --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() { func (s *AccountRepoSuite) TestUpdateLastUsed() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-used"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-used"})
s.Require().Nil(account.LastUsedAt) s.Require().Nil(account.LastUsedAt)
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID)) s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
@@ -424,20 +429,20 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() {
// --- SetError --- // --- SetError ---
func (s *AccountRepoSuite) TestSetError() { func (s *AccountRepoSuite) TestSetError() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-err", Status: model.StatusActive}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-err", Status: service.StatusActive})
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong")) s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
got, err := s.repo.GetByID(s.ctx, account.ID) got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(model.StatusError, got.Status) s.Require().Equal(service.StatusError, got.Status)
s.Require().Equal("something went wrong", got.ErrorMessage) s.Require().Equal("something went wrong", got.ErrorMessage)
} }
// --- UpdateSessionWindow --- // --- UpdateSessionWindow ---
func (s *AccountRepoSuite) TestUpdateSessionWindow() { func (s *AccountRepoSuite) TestUpdateSessionWindow() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-win"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-win"})
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC) start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC) end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
@@ -453,9 +458,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() {
// --- UpdateExtra --- // --- UpdateExtra ---
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() { func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
account := mustCreateAccount(s.T(), s.db, &model.Account{ account := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc-extra", Name: "acc-extra",
Extra: model.JSONB{"a": "1"}, Extra: datatypes.JSONMap{"a": "1"},
}) })
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra") s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
@@ -466,12 +471,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
} }
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() { func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-extra-empty"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-extra-empty"})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{})) s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
} }
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-nil-extra", Extra: nil}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-nil-extra", Extra: nil})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"})) s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
got, err := s.repo.GetByID(s.ctx, account.ID) got, err := s.repo.GetByID(s.ctx, account.ID)
@@ -483,9 +488,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
func (s *AccountRepoSuite) TestGetByCRSAccountID() { func (s *AccountRepoSuite) TestGetByCRSAccountID() {
crsID := "crs-12345" crsID := "crs-12345"
mustCreateAccount(s.T(), s.db, &model.Account{ mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc-crs", Name: "acc-crs",
Extra: model.JSONB{"crs_account_id": crsID}, Extra: datatypes.JSONMap{"crs_account_id": crsID},
}) })
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID) got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
@@ -509,8 +514,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
// --- BulkUpdate --- // --- BulkUpdate ---
func (s *AccountRepoSuite) TestBulkUpdate() { func (s *AccountRepoSuite) TestBulkUpdate() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk1", Priority: 1}) a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1}) a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk2", Priority: 1})
newPriority := 99 newPriority := 99
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{ affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
@@ -526,13 +531,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() {
} }
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() { func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{ a1 := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bulk-cred", Name: "bulk-cred",
Credentials: model.JSONB{"existing": "value"}, Credentials: datatypes.JSONMap{"existing": "value"},
}) })
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{ _, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Credentials: model.JSONB{"new_key": "new_value"}, Credentials: datatypes.JSONMap{"new_key": "new_value"},
}) })
s.Require().NoError(err) s.Require().NoError(err)
@@ -542,13 +547,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
} }
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() { func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{ a1 := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bulk-extra", Name: "bulk-extra",
Extra: model.JSONB{"existing": "val"}, Extra: datatypes.JSONMap{"existing": "val"},
}) })
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{ _, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Extra: model.JSONB{"new_key": "new_val"}, Extra: datatypes.JSONMap{"new_key": "new_val"},
}) })
s.Require().NoError(err) s.Require().NoError(err)
@@ -564,14 +569,14 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
} }
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() { func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"}) a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk-empty"})
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{}) affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Zero(affected) s.Require().Zero(affected)
} }
func idsOfAccounts(accounts []model.Account) []int64 { func idsOfAccounts(accounts []service.Account) []int64 {
out := make([]int64, 0, len(accounts)) out := make([]int64, 0, len(accounts))
for i := range accounts { for i := range accounts {
out = append(out, accounts[i].ID) out = append(out, accounts[i].ID)

View File

@@ -2,6 +2,7 @@ package repository
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
@@ -14,6 +15,11 @@ const (
apiKeyRateLimitDuration = 24 * time.Hour apiKeyRateLimitDuration = 24 * time.Hour
) )
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
func apiKeyRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
}
type apiKeyCache struct { type apiKeyCache struct {
rdb *redis.Client rdb *redis.Client
} }
@@ -23,12 +29,16 @@ func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
} }
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) key := apiKeyRateLimitKey(userID)
return c.rdb.Get(ctx, key).Int() count, err := c.rdb.Get(ctx, key).Int()
if errors.Is(err, redis.Nil) {
return 0, nil
}
return count, err
} }
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error { func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) key := apiKeyRateLimitKey(userID)
pipe := c.rdb.Pipeline() pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key) pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration) pipe.Expire(ctx, key, apiKeyRateLimitDuration)
@@ -37,7 +47,7 @@ func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID in
} }
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error { func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) key := apiKeyRateLimitKey(userID)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }

View File

@@ -23,13 +23,14 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
}{ }{
{ {
name: "missing_key_returns_redis_nil", name: "missing_key_returns_zero_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) { fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1) userID := int64(1)
_, err := cache.GetCreateAttemptCount(ctx, userID) count, err := cache.GetCreateAttemptCount(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing key") require.NoError(s.T(), err, "expected nil error for missing key")
require.Equal(s.T(), 0, count, "expected zero count for missing key")
}, },
}, },
{ {
@@ -58,8 +59,9 @@ func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID)) require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount") require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
_, err := cache.GetCreateAttemptCount(ctx, userID) count, err := cache.GetCreateAttemptCount(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after delete") require.NoError(s.T(), err, "expected nil error after delete")
require.Equal(s.T(), 0, count, "expected zero count after delete")
}, },
}, },
} }

View File

@@ -0,0 +1,46 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestApiKeyRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "apikey:ratelimit:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "apikey:ratelimit:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "apikey:ratelimit:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "apikey:ratelimit:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := apiKeyRateLimitKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -2,10 +2,10 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
@@ -19,42 +19,51 @@ func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return &apiKeyRepository{db: db} return &apiKeyRepository{db: db}
} }
func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error { func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
err := r.db.WithContext(ctx).Create(key).Error m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return translatePersistenceError(err, nil, service.ErrApiKeyExists) return translatePersistenceError(err, nil, service.ErrApiKeyExists)
} }
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
var key model.ApiKey var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil) return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
} }
return &key, nil return apiKeyModelToService(&m), nil
} }
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) { func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
var apiKey model.ApiKey var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil) return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
} }
return &apiKey, nil return apiKeyModelToService(&m), nil
} }
func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error { func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return err
} }
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error
} }
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []apiKeyModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID) db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
@@ -64,36 +73,31 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() outKeys := make([]service.ApiKey, 0, len(keys))
if int(total)%params.Limit() > 0 { for i := range keys {
pages++ outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
} }
return keys, &pagination.PaginationResult{ return outKeys, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
return count, err return count, err
} }
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) { func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []apiKeyModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID) db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
@@ -103,24 +107,19 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() outKeys := make([]service.ApiKey, 0, len(keys))
if int(total)%params.Limit() > 0 { for i := range keys {
pages++ outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
} }
return keys, &pagination.PaginationResult{ return outKeys, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
// SearchApiKeys searches API keys by user ID and/or keyword (name) // SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) { func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
var keys []model.ApiKey var keys []apiKeyModel
db := r.db.WithContext(ctx).Model(&model.ApiKey{}) db := r.db.WithContext(ctx).Model(&apiKeyModel{})
if userID > 0 { if userID > 0 {
db = db.Where("user_id = ?", userID) db = db.Where("user_id = ?", userID)
@@ -135,12 +134,16 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
return nil, err return nil, err
} }
return keys, nil outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return outKeys, nil
} }
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil // ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.ApiKey{}). result := r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("group_id = ?", groupID). Where("group_id = ?", groupID).
Update("group_id", nil) Update("group_id", nil)
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
@@ -149,6 +152,66 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
// CountByGroupID 获取分组的 API Key 数量 // CountByGroupID 获取分组的 API Key 数量
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err return count, err
} }
type apiKeyModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
Key string `gorm:"uniqueIndex;size:128;not null"`
Name string `gorm:"size:100;not null"`
GroupID *int64 `gorm:"index"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (apiKeyModel) TableName() string { return "api_keys" }
func apiKeyModelToService(m *apiKeyModel) *service.ApiKey {
if m == nil {
return nil
}
return &service.ApiKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
GroupID: m.GroupID,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
}
}
func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel {
if k == nil {
return nil
}
return &apiKeyModel{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
}
}
func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) {
if key == nil || m == nil {
return
}
key.ID = m.ID
key.CreatedAt = m.CreatedAt
key.UpdatedAt = m.UpdatedAt
}

View File

@@ -6,8 +6,8 @@ import (
"context" "context"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -32,13 +32,13 @@ func TestApiKeyRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByKey --- // --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() { func (s *ApiKeyRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
key := &model.ApiKey{ key := &service.ApiKey{
UserID: user.ID, UserID: user.ID,
Key: "sk-create-test", Key: "sk-create-test",
Name: "Test Key", Name: "Test Key",
Status: model.StatusActive, Status: service.StatusActive,
} }
err := s.repo.Create(s.ctx, key) err := s.repo.Create(s.ctx, key)
@@ -56,15 +56,15 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
} }
func (s *ApiKeyRepoSuite) TestGetByKey() { func (s *ApiKeyRepoSuite) TestGetByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbykey@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbykey@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-key"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-key"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-getbykey", Key: "sk-getbykey",
Name: "My Key", Name: "My Key",
GroupID: &group.ID, GroupID: &group.ID,
Status: model.StatusActive, Status: service.StatusActive,
}) })
got, err := s.repo.GetByKey(s.ctx, key.Key) got, err := s.repo.GetByKey(s.ctx, key.Key)
@@ -84,16 +84,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
// --- Update --- // --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() { func (s *ApiKeyRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-update", Key: "sk-update",
Name: "Original", Name: "Original",
Status: model.StatusActive, Status: service.StatusActive,
}) }))
key.Name = "Renamed" key.Name = "Renamed"
key.Status = model.StatusDisabled key.Status = service.StatusDisabled
err := s.repo.Update(s.ctx, key) err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update") s.Require().NoError(err, "Update")
@@ -102,18 +102,18 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
s.Require().Equal("sk-update", got.Key, "Update should not change key") s.Require().Equal("sk-update", got.Key, "Update should not change key")
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id") s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got.Name) s.Require().Equal("Renamed", got.Name)
s.Require().Equal(model.StatusDisabled, got.Status) s.Require().Equal(service.StatusDisabled, got.Status)
} }
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() { func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargroup@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-clear-group", Key: "sk-clear-group",
Name: "Group Key", Name: "Group Key",
GroupID: &group.ID, GroupID: &group.ID,
}) }))
key.GroupID = nil key.GroupID = nil
err := s.repo.Update(s.ctx, key) err := s.repo.Update(s.ctx, key)
@@ -127,8 +127,8 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete --- // --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() { func (s *ApiKeyRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-delete", Key: "sk-delete",
Name: "Delete Me", Name: "Delete Me",
@@ -144,9 +144,9 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID --- // --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() { func (s *ApiKeyRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID") s.Require().NoError(err, "ListByUserID")
@@ -155,9 +155,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
} }
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() { func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "paging@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "paging@test.com"})
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
mustCreateApiKey(s.T(), s.db, &model.ApiKey{ mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-page-" + string(rune('a'+i)), Key: "sk-page-" + string(rune('a'+i)),
Name: "Key", Name: "Key",
@@ -172,9 +172,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
} }
func (s *ApiKeyRepoSuite) TestCountByUserID() { func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "count@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "count@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-1", Name: "K1"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-2", Name: "K2"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
count, err := s.repo.CountByUserID(s.ctx, user.ID) count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID") s.Require().NoError(err, "CountByUserID")
@@ -184,12 +184,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID --- // --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() { func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbygroup@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbygroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID") s.Require().NoError(err, "ListByGroupID")
@@ -200,10 +200,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
} }
func (s *ApiKeyRepoSuite) TestCountByGroupID() { func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "countgroup@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "countgroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
count, err := s.repo.CountByGroupID(s.ctx, group.ID) count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID") s.Require().NoError(err, "CountByGroupID")
@@ -213,8 +213,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey --- // --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() { func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-exists", Name: "K"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-exists", Name: "K"})
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists") exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey") s.Require().NoError(err, "ExistsByKey")
@@ -228,9 +228,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
// --- SearchApiKeys --- // --- SearchApiKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() { func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "search@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "search@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10) found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys") s.Require().NoError(err, "SearchApiKeys")
@@ -239,9 +239,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
} }
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() { func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnokw@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnokw@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-1", Name: "K1"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-2", Name: "K2"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10) found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err) s.Require().NoError(err)
@@ -249,8 +249,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
} }
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() { func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnouid@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnouid@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10) found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err) s.Require().NoError(err)
@@ -260,12 +260,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
// --- ClearGroupIDByGroupID --- // --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() { func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargrp@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear-bulk"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear-bulk"})
k1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID}) k1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID}) k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID) affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID") s.Require().NoError(err, "ClearGroupIDByGroupID")
@@ -283,16 +283,16 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) --- // --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "k@example.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "k@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-k"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-k"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-test-1", Key: "sk-test-1",
Name: "My Key", Name: "My Key",
GroupID: &group.ID, GroupID: &group.ID,
Status: model.StatusActive, Status: service.StatusActive,
}) }))
got, err := s.repo.GetByKey(s.ctx, key.Key) got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey") s.Require().NoError(err, "GetByKey")
@@ -303,7 +303,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(group.ID, got.Group.ID) s.Require().Equal(group.ID, got.Group.ID)
key.Name = "Renamed" key.Name = "Renamed"
key.Status = model.StatusDisabled key.Status = service.StatusDisabled
key.GroupID = nil key.GroupID = nil
s.Require().NoError(s.repo.Update(s.ctx, key), "Update") s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
@@ -312,7 +312,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key") s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id") s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got2.Name) s.Require().Equal("Renamed", got2.Name)
s.Require().Equal(model.StatusDisabled, got2.Status) s.Require().Equal(service.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID) s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
@@ -330,7 +330,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(key.ID, found[0].ID) s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID // ClearGroupIDByGroupID
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID, UserID: user.ID,
Key: "sk-test-2", Key: "sk-test-2",
Name: "Group Key", Name: "Group Key",

View File

@@ -0,0 +1,20 @@
package repository
import "gorm.io/gorm"
// AutoMigrate runs schema migrations for all repository persistence models.
// Persistence models are defined within individual `*_repo.go` files.
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&userModel{},
&apiKeyModel{},
&groupModel{},
&accountModel{},
&accountGroupModel{},
&proxyModel{},
&redeemCodeModel{},
&usageLogModel{},
&settingModel{},
&userSubscriptionModel{},
)
}

View File

@@ -18,6 +18,16 @@ const (
billingCacheTTL = 5 * time.Minute billingCacheTTL = 5 * time.Minute
) )
// billingBalanceKey generates the Redis key for user balance cache.
func billingBalanceKey(userID int64) string {
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
}
// billingSubKey generates the Redis key for subscription cache.
func billingSubKey(userID, groupID int64) string {
return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
}
const ( const (
subFieldStatus = "status" subFieldStatus = "status"
subFieldExpiresAt = "expires_at" subFieldExpiresAt = "expires_at"
@@ -62,7 +72,7 @@ func NewBillingCache(rdb *redis.Client) service.BillingCache {
} }
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) { func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) key := billingBalanceKey(userID)
val, err := c.rdb.Get(ctx, key).Result() val, err := c.rdb.Get(ctx, key).Result()
if err != nil { if err != nil {
return 0, err return 0, err
@@ -71,12 +81,12 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
} }
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error { func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) key := billingBalanceKey(userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err() return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
} }
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) key := billingBalanceKey(userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result() _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) { if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err) log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
@@ -85,12 +95,12 @@ func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amou
} }
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error { func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) key := billingBalanceKey(userID)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) { func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) key := billingSubKey(userID, groupID)
result, err := c.rdb.HGetAll(ctx, key).Result() result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -140,7 +150,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
return nil return nil
} }
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) key := billingSubKey(userID, groupID)
fields := map[string]any{ fields := map[string]any{
subFieldStatus: data.Status, subFieldStatus: data.Status,
@@ -159,7 +169,7 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
} }
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) key := billingSubKey(userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result() _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) { if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err) log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
@@ -168,6 +178,6 @@ func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, grou
} }
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) key := billingSubKey(userID, groupID)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }

View File

@@ -0,0 +1,87 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestBillingBalanceKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "billing:balance:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "billing:balance:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "billing:balance:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "billing:balance:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingBalanceKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}
func TestBillingSubKey(t *testing.T) {
tests := []struct {
name string
userID int64
groupID int64
expected string
}{
{
name: "normal_ids",
userID: 123,
groupID: 456,
expected: "billing:sub:123:456",
},
{
name: "zero_ids",
userID: 0,
groupID: 0,
expected: "billing:sub:0:0",
},
{
name: "negative_ids",
userID: -1,
groupID: -2,
expected: "billing:sub:-1:-2",
},
{
name: "max_int64_ids",
userID: math.MaxInt64,
groupID: math.MaxInt64,
expected: "billing:sub:9223372036854775807:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingSubKey(tc.userID, tc.groupID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -11,6 +11,11 @@ import (
const verifyCodeKeyPrefix = "verify_code:" const verifyCodeKeyPrefix = "verify_code:"
// verifyCodeKey generates the Redis key for email verification code.
func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
}
type emailCache struct { type emailCache struct {
rdb *redis.Client rdb *redis.Client
} }
@@ -20,7 +25,7 @@ func NewEmailCache(rdb *redis.Client) service.EmailCache {
} }
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) { func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
key := verifyCodeKeyPrefix + email key := verifyCodeKey(email)
val, err := c.rdb.Get(ctx, key).Result() val, err := c.rdb.Get(ctx, key).Result()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -33,7 +38,7 @@ func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*se
} }
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error { func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
key := verifyCodeKeyPrefix + email key := verifyCodeKey(email)
val, err := json.Marshal(data) val, err := json.Marshal(data)
if err != nil { if err != nil {
return err return err
@@ -42,6 +47,6 @@ func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data
} }
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error { func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
key := verifyCodeKeyPrefix + email key := verifyCodeKey(email)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }

View File

@@ -0,0 +1,45 @@
//go:build unit
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestVerifyCodeKey(t *testing.T) {
tests := []struct {
name string
email string
expected string
}{
{
name: "normal_email",
email: "user@example.com",
expected: "verify_code:user@example.com",
},
{
name: "empty_email",
email: "",
expected: "verify_code:",
},
{
name: "email_with_plus",
email: "user+tag@example.com",
expected: "verify_code:user+tag@example.com",
},
{
name: "email_with_special_chars",
email: "user.name+tag@sub.domain.com",
expected: "verify_code:user.name+tag@sub.domain.com",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := verifyCodeKey(tc.email)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -6,21 +6,25 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
) )
func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User { func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel {
t.Helper() t.Helper()
if u.PasswordHash == "" { if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash" u.PasswordHash = "test-password-hash"
} }
if u.Role == "" { if u.Role == "" {
u.Role = model.RoleUser u.Role = service.RoleUser
} }
if u.Status == "" { if u.Status == "" {
u.Status = model.StatusActive u.Status = service.StatusActive
}
if u.Concurrency == 0 {
u.Concurrency = 5
} }
if u.CreatedAt.IsZero() { if u.CreatedAt.IsZero() {
u.CreatedAt = time.Now() u.CreatedAt = time.Now()
@@ -32,16 +36,16 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
return u return u
} }
func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group { func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel {
t.Helper() t.Helper()
if g.Platform == "" { if g.Platform == "" {
g.Platform = model.PlatformAnthropic g.Platform = service.PlatformAnthropic
} }
if g.Status == "" { if g.Status == "" {
g.Status = model.StatusActive g.Status = service.StatusActive
} }
if g.SubscriptionType == "" { if g.SubscriptionType == "" {
g.SubscriptionType = model.SubscriptionTypeStandard g.SubscriptionType = service.SubscriptionTypeStandard
} }
if g.CreatedAt.IsZero() { if g.CreatedAt.IsZero() {
g.CreatedAt = time.Now() g.CreatedAt = time.Now()
@@ -53,7 +57,7 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
return g return g
} }
func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy { func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel {
t.Helper() t.Helper()
if p.Protocol == "" { if p.Protocol == "" {
p.Protocol = "http" p.Protocol = "http"
@@ -65,7 +69,7 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
p.Port = 8080 p.Port = 8080
} }
if p.Status == "" { if p.Status == "" {
p.Status = model.StatusActive p.Status = service.StatusActive
} }
if p.CreatedAt.IsZero() { if p.CreatedAt.IsZero() {
p.CreatedAt = time.Now() p.CreatedAt = time.Now()
@@ -77,25 +81,25 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
return p return p
} }
func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Account { func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel {
t.Helper() t.Helper()
if a.Platform == "" { if a.Platform == "" {
a.Platform = model.PlatformAnthropic a.Platform = service.PlatformAnthropic
} }
if a.Type == "" { if a.Type == "" {
a.Type = model.AccountTypeOAuth a.Type = service.AccountTypeOAuth
} }
if a.Status == "" { if a.Status == "" {
a.Status = model.StatusActive a.Status = service.StatusActive
} }
if !a.Schedulable { if !a.Schedulable {
a.Schedulable = true a.Schedulable = true
} }
if a.Credentials == nil { if a.Credentials == nil {
a.Credentials = model.JSONB{} a.Credentials = datatypes.JSONMap{}
} }
if a.Extra == nil { if a.Extra == nil {
a.Extra = model.JSONB{} a.Extra = datatypes.JSONMap{}
} }
if a.CreatedAt.IsZero() { if a.CreatedAt.IsZero() {
a.CreatedAt = time.Now() a.CreatedAt = time.Now()
@@ -107,10 +111,10 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Accou
return a return a
} }
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey { func mustCreateApiKey(t *testing.T, db *gorm.DB, k *apiKeyModel) *apiKeyModel {
t.Helper() t.Helper()
if k.Status == "" { if k.Status == "" {
k.Status = model.StatusActive k.Status = service.StatusActive
} }
if k.CreatedAt.IsZero() { if k.CreatedAt.IsZero() {
k.CreatedAt = time.Now() k.CreatedAt = time.Now()
@@ -122,13 +126,13 @@ func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey
return k return k
} }
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model.RedeemCode { func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *redeemCodeModel) *redeemCodeModel {
t.Helper() t.Helper()
if c.Status == "" { if c.Status == "" {
c.Status = model.StatusUnused c.Status = service.StatusUnused
} }
if c.Type == "" { if c.Type == "" {
c.Type = model.RedeemTypeBalance c.Type = service.RedeemTypeBalance
} }
if c.CreatedAt.IsZero() { if c.CreatedAt.IsZero() {
c.CreatedAt = time.Now() c.CreatedAt = time.Now()
@@ -137,10 +141,10 @@ func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model
return c return c
} }
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription) *model.UserSubscription { func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel) *userSubscriptionModel {
t.Helper() t.Helper()
if s.Status == "" { if s.Status == "" {
s.Status = model.SubscriptionStatusActive s.Status = service.SubscriptionStatusActive
} }
now := time.Now() now := time.Now()
if s.StartsAt.IsZero() { if s.StartsAt.IsZero() {
@@ -164,9 +168,10 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) { func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
t.Helper() t.Helper()
require.NoError(t, db.Create(&model.AccountGroup{ require.NoError(t, db.Create(&accountGroupModel{
AccountID: accountID, AccountID: accountID,
GroupID: groupID, GroupID: groupID,
Priority: priority, Priority: priority,
CreatedAt: time.Now(),
}).Error, "create account_group") }).Error, "create account_group")
} }

View File

@@ -2,10 +2,10 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
@@ -20,38 +20,50 @@ func NewGroupRepository(db *gorm.DB) service.GroupRepository {
return &groupRepository{db: db} return &groupRepository{db: db}
} }
func (r *groupRepository) Create(ctx context.Context, group *model.Group) error { func (r *groupRepository) Create(ctx context.Context, group *service.Group) error {
err := r.db.WithContext(ctx).Create(group).Error m := groupModelFromService(group)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyGroupModelToService(group, m)
}
return translatePersistenceError(err, nil, service.ErrGroupExists) return translatePersistenceError(err, nil, service.ErrGroupExists)
} }
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) { func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
var group model.Group var m groupModel
err := r.db.WithContext(ctx).First(&group, id).Error err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil) return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
} }
return &group, nil group := groupModelToService(&m)
count, _ := r.GetAccountCount(ctx, group.ID)
group.AccountCount = count
return group, nil
} }
func (r *groupRepository) Update(ctx context.Context, group *model.Group) error { func (r *groupRepository) Update(ctx context.Context, group *service.Group) error {
return r.db.WithContext(ctx).Save(group).Error m := groupModelFromService(group)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyGroupModelToService(group, m)
}
return err
} }
func (r *groupRepository) Delete(ctx context.Context, id int64) error { func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error return r.db.WithContext(ctx).Delete(&groupModel{}, id).Error
} }
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil) return r.ListWithFilters(ctx, params, "", "", nil)
} }
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive // ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) { func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
var groups []model.Group var groups []groupModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.Group{}) db := r.db.WithContext(ctx).Model(&groupModel{})
// Apply filters // Apply filters
if platform != "" { if platform != "" {
@@ -72,68 +84,71 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
// 获取每个分组的账号数量 outGroups := make([]service.Group, 0, len(groups))
for i := range groups { for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID) outGroups = append(outGroups, *groupModelToService(&groups[i]))
groups[i].AccountCount = count
} }
pages := int(total) / params.Limit() // 获取每个分组的账号数量
if int(total)%params.Limit() > 0 { for i := range outGroups {
pages++ count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
} }
return groups, &pagination.PaginationResult{ return outGroups, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) { func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
var groups []model.Group var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Order("id ASC").Find(&groups).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 获取每个分组的账号数量 outGroups := make([]service.Group, 0, len(groups))
for i := range groups { for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID) outGroups = append(outGroups, *groupModelToService(&groups[i]))
groups[i].AccountCount = count
} }
return groups, nil // 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
}
return outGroups, nil
} }
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) { func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
var groups []model.Group var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", service.StatusActive, platform).Order("id ASC").Find(&groups).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 获取每个分组的账号数量 outGroups := make([]service.Group, 0, len(groups))
for i := range groups { for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID) outGroups = append(outGroups, *groupModelToService(&groups[i]))
groups[i].AccountCount = count
} }
return groups, nil // 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
}
return outGroups, nil
} }
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) { func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error err := r.db.WithContext(ctx).Model(&groupModel{}).Where("name = ?", name).Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error err := r.db.WithContext(ctx).Table("account_groups").Where("group_id = ?", groupID).Count(&count).Error
return count, err return count, err
} }
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系 // DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{}) result := r.db.WithContext(ctx).Exec("DELETE FROM account_groups WHERE group_id = ?", groupID)
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
@@ -145,46 +160,42 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
var affectedUserIDs []int64 var affectedUserIDs []int64
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
var subscriptions []model.UserSubscription
if err := r.db.WithContext(ctx). if err := r.db.WithContext(ctx).
Model(&model.UserSubscription{}). Table("user_subscriptions").
Where("group_id = ?", id). Where("group_id = ?", id).
Select("user_id"). Pluck("user_id", &affectedUserIDs).Error; err != nil {
Find(&subscriptions).Error; err != nil {
return nil, err return nil, err
} }
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
} }
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 删除订阅类型分组的订阅记录 // 1. 删除订阅类型分组的订阅记录
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil { if err := tx.Exec("DELETE FROM user_subscriptions WHERE group_id = ?", id).Error; err != nil {
return err return err
} }
} }
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil // 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil { if err := tx.Exec("UPDATE api_keys SET group_id = NULL WHERE group_id = ?", id).Error; err != nil {
return err return err
} }
// 3. 从 users.allowed_groups 数组中移除该分组 ID // 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}). if err := tx.Exec(
Where("? = ANY(allowed_groups)", id). "UPDATE users SET allowed_groups = array_remove(allowed_groups, ?) WHERE ? = ANY(allowed_groups)",
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil { id, id,
).Error; err != nil {
return err return err
} }
// 4. 删除 account_groups 中间表的数据 // 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil { if err := tx.Exec("DELETE FROM account_groups WHERE group_id = ?", id).Error; err != nil {
return err return err
} }
// 5. 删除分组本身(带锁,避免并发写) // 5. 删除分组本身(带锁,避免并发写)
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil { if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&groupModel{}, id).Error; err != nil {
return err return err
} }
@@ -196,3 +207,75 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return affectedUserIDs, nil return affectedUserIDs, nil
} }
type groupModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"uniqueIndex;size:100;not null"`
Description string `gorm:"type:text"`
Platform string `gorm:"size:50;default:anthropic;not null"`
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null"`
IsExclusive bool `gorm:"default:false;not null"`
Status string `gorm:"size:20;default:active;not null"`
SubscriptionType string `gorm:"size:20;default:standard;not null"`
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (groupModel) TableName() string { return "groups" }
func groupModelToService(m *groupModel) *service.Group {
if m == nil {
return nil
}
return &service.Group{
ID: m.ID,
Name: m.Name,
Description: m.Description,
Platform: m.Platform,
RateMultiplier: m.RateMultiplier,
IsExclusive: m.IsExclusive,
Status: m.Status,
SubscriptionType: m.SubscriptionType,
DailyLimitUSD: m.DailyLimitUSD,
WeeklyLimitUSD: m.WeeklyLimitUSD,
MonthlyLimitUSD: m.MonthlyLimitUSD,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func groupModelFromService(sg *service.Group) *groupModel {
if sg == nil {
return nil
}
return &groupModel{
ID: sg.ID,
Name: sg.Name,
Description: sg.Description,
Platform: sg.Platform,
RateMultiplier: sg.RateMultiplier,
IsExclusive: sg.IsExclusive,
Status: sg.Status,
SubscriptionType: sg.SubscriptionType,
DailyLimitUSD: sg.DailyLimitUSD,
WeeklyLimitUSD: sg.WeeklyLimitUSD,
MonthlyLimitUSD: sg.MonthlyLimitUSD,
CreatedAt: sg.CreatedAt,
UpdatedAt: sg.UpdatedAt,
}
}
func applyGroupModelToService(group *service.Group, m *groupModel) {
if group == nil || m == nil {
return
}
group.ID = m.ID
group.CreatedAt = m.CreatedAt
group.UpdatedAt = m.UpdatedAt
}

View File

@@ -6,8 +6,8 @@ import (
"context" "context"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -32,10 +32,10 @@ func TestGroupRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete --- // --- Create / GetByID / Update / Delete ---
func (s *GroupRepoSuite) TestCreate() { func (s *GroupRepoSuite) TestCreate() {
group := &model.Group{ group := &service.Group{
Name: "test-create", Name: "test-create",
Platform: model.PlatformAnthropic, Platform: service.PlatformAnthropic,
Status: model.StatusActive, Status: service.StatusActive,
} }
err := s.repo.Create(s.ctx, group) err := s.repo.Create(s.ctx, group)
@@ -53,7 +53,7 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
} }
func (s *GroupRepoSuite) TestUpdate() { func (s *GroupRepoSuite) TestUpdate() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "original"}) group := groupModelToService(mustCreateGroup(s.T(), s.db, &groupModel{Name: "original"}))
group.Name = "updated" group.Name = "updated"
err := s.repo.Update(s.ctx, group) err := s.repo.Update(s.ctx, group)
@@ -65,7 +65,7 @@ func (s *GroupRepoSuite) TestUpdate() {
} }
func (s *GroupRepoSuite) TestDelete() { func (s *GroupRepoSuite) TestDelete() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "to-delete"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, group.ID) err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
@@ -77,8 +77,8 @@ func (s *GroupRepoSuite) TestDelete() {
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *GroupRepoSuite) TestList() { func (s *GroupRepoSuite) TestList() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
@@ -87,28 +87,28 @@ func (s *GroupRepoSuite) TestList() {
} }
func (s *GroupRepoSuite) TestListWithFilters_Platform() { func (s *GroupRepoSuite) TestListWithFilters_Platform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformOpenAI, "", nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().Equal(model.PlatformOpenAI, groups[0].Platform) s.Require().Equal(service.PlatformOpenAI, groups[0].Platform)
} }
func (s *GroupRepoSuite) TestListWithFilters_Status() { func (s *GroupRepoSuite) TestListWithFilters_Status() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Status: model.StatusActive}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Status: model.StatusDisabled}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Status: service.StatusDisabled})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, nil) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().Equal(model.StatusDisabled, groups[0].Status) s.Require().Equal(service.StatusDisabled, groups[0].Status)
} }
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", IsExclusive: false}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", IsExclusive: false})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", IsExclusive: true}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", IsExclusive: true})
isExclusive := true isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive) groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
@@ -118,24 +118,24 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
} }
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := mustCreateGroup(s.T(), s.db, &model.Group{ g1 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g1", Name: "g1",
Platform: model.PlatformAnthropic, Platform: service.PlatformAnthropic,
Status: model.StatusActive, Status: service.StatusActive,
}) })
g2 := mustCreateGroup(s.T(), s.db, &model.Group{ g2 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g2", Name: "g2",
Platform: model.PlatformAnthropic, Platform: service.PlatformAnthropic,
Status: model.StatusActive, Status: service.StatusActive,
IsExclusive: true, IsExclusive: true,
}) })
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"}) a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1) mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1) mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
isExclusive := true isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformAnthropic, model.StatusActive, &isExclusive) groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
@@ -146,8 +146,8 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
// --- ListActive / ListActiveByPlatform --- // --- ListActive / ListActiveByPlatform ---
func (s *GroupRepoSuite) TestListActive() { func (s *GroupRepoSuite) TestListActive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "active1", Status: model.StatusActive}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "active1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "inactive1", Status: model.StatusDisabled}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "inactive1", Status: service.StatusDisabled})
groups, err := s.repo.ListActive(s.ctx) groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive") s.Require().NoError(err, "ListActive")
@@ -156,11 +156,11 @@ func (s *GroupRepoSuite) TestListActive() {
} }
func (s *GroupRepoSuite) TestListActiveByPlatform() { func (s *GroupRepoSuite) TestListActiveByPlatform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic, Status: model.StatusActive}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI, Status: model.StatusActive}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g3", Platform: model.PlatformAnthropic, Status: model.StatusDisabled}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "g3", Platform: service.PlatformAnthropic, Status: service.StatusDisabled})
groups, err := s.repo.ListActiveByPlatform(s.ctx, model.PlatformAnthropic) groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListActiveByPlatform") s.Require().NoError(err, "ListActiveByPlatform")
s.Require().Len(groups, 1) s.Require().Len(groups, 1)
s.Require().Equal("g1", groups[0].Name) s.Require().Equal("g1", groups[0].Name)
@@ -169,7 +169,7 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
// --- ExistsByName --- // --- ExistsByName ---
func (s *GroupRepoSuite) TestExistsByName() { func (s *GroupRepoSuite) TestExistsByName() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "existing-group"}) mustCreateGroup(s.T(), s.db, &groupModel{Name: "existing-group"})
exists, err := s.repo.ExistsByName(s.ctx, "existing-group") exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
s.Require().NoError(err, "ExistsByName") s.Require().NoError(err, "ExistsByName")
@@ -183,9 +183,9 @@ func (s *GroupRepoSuite) TestExistsByName() {
// --- GetAccountCount --- // --- GetAccountCount ---
func (s *GroupRepoSuite) TestGetAccountCount() { func (s *GroupRepoSuite) TestGetAccountCount() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"}) a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"}) a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2) mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
@@ -195,7 +195,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
} }
func (s *GroupRepoSuite) TestGetAccountCount_Empty() { func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
count, err := s.repo.GetAccountCount(s.ctx, group.ID) count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err) s.Require().NoError(err)
@@ -205,8 +205,8 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
// --- DeleteAccountGroupsByGroupID --- // --- DeleteAccountGroupsByGroupID ---
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"}) g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"}) a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1) mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID) affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
@@ -219,10 +219,10 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
} }
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-multi"}) g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-multi"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"}) a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"}) a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
a3 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) a3 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1) mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2) mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3) mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)

View File

@@ -15,6 +15,11 @@ const (
fingerprintTTL = 24 * time.Hour fingerprintTTL = 24 * time.Hour
) )
// fingerprintKey generates the Redis key for account fingerprint cache.
func fingerprintKey(accountID int64) string {
return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
}
type identityCache struct { type identityCache struct {
rdb *redis.Client rdb *redis.Client
} }
@@ -24,7 +29,7 @@ func NewIdentityCache(rdb *redis.Client) service.IdentityCache {
} }
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*service.Fingerprint, error) { func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*service.Fingerprint, error) {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID) key := fingerprintKey(accountID)
val, err := c.rdb.Get(ctx, key).Result() val, err := c.rdb.Get(ctx, key).Result()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -37,7 +42,7 @@ func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*s
} }
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *service.Fingerprint) error { func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *service.Fingerprint) error {
key := fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID) key := fingerprintKey(accountID)
val, err := json.Marshal(fp) val, err := json.Marshal(fp)
if err != nil { if err != nil {
return err return err

View File

@@ -0,0 +1,46 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestFingerprintKey(t *testing.T) {
tests := []struct {
name string
accountID int64
expected string
}{
{
name: "normal_account_id",
accountID: 123,
expected: "fingerprint:123",
},
{
name: "zero_account_id",
accountID: 0,
expected: "fingerprint:0",
},
{
name: "negative_account_id",
accountID: -1,
expected: "fingerprint:-1",
},
{
name: "max_int64",
accountID: math.MaxInt64,
expected: "fingerprint:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := fingerprintKey(tc.accountID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -15,7 +15,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@@ -94,7 +93,7 @@ func TestMain(m *testing.M) {
log.Printf("failed to open gorm db: %v", err) log.Printf("failed to open gorm db: %v", err)
os.Exit(1) os.Exit(1)
} }
if err := model.AutoMigrate(integrationDB); err != nil { if err := AutoMigrate(integrationDB); err != nil {
log.Printf("failed to automigrate db: %v", err) log.Printf("failed to automigrate db: %v", err)
os.Exit(1) os.Exit(1)
} }

View File

@@ -0,0 +1,16 @@
package repository
import "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
func paginationResultFromTotal(total int64, params pagination.PaginationParams) *pagination.PaginationResult {
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}
}

View File

@@ -2,10 +2,10 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
@@ -19,37 +19,47 @@ func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
return &proxyRepository{db: db} return &proxyRepository{db: db}
} }
func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error { func (r *proxyRepository) Create(ctx context.Context, proxy *service.Proxy) error {
return r.db.WithContext(ctx).Create(proxy).Error m := proxyModelFromService(proxy)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyProxyModelToService(proxy, m)
}
return err
} }
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
var proxy model.Proxy var m proxyModel
err := r.db.WithContext(ctx).First(&proxy, id).Error err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil) return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
} }
return &proxy, nil return proxyModelToService(&m), nil
} }
func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error { func (r *proxyRepository) Update(ctx context.Context, proxy *service.Proxy) error {
return r.db.WithContext(ctx).Save(proxy).Error m := proxyModelFromService(proxy)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyProxyModelToService(proxy, m)
}
return err
} }
func (r *proxyRepository) Delete(ctx context.Context, id int64) error { func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error return r.db.WithContext(ctx).Delete(&proxyModel{}, id).Error
} }
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query // ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) { func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
var proxies []model.Proxy var proxies []proxyModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.Proxy{}) db := r.db.WithContext(ctx).Model(&proxyModel{})
// Apply filters // Apply filters
if protocol != "" { if protocol != "" {
@@ -71,29 +81,31 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() outProxies := make([]service.Proxy, 0, len(proxies))
if int(total)%params.Limit() > 0 { for i := range proxies {
pages++ outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
} }
return proxies, &pagination.PaginationResult{ return outProxies, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) { func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
var proxies []model.Proxy var proxies []proxyModel
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Find(&proxies).Error
return proxies, err if err != nil {
return nil, err
}
outProxies := make([]service.Proxy, 0, len(proxies))
for i := range proxies {
outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
}
return outProxies, nil
} }
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists // ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Proxy{}). err := r.db.WithContext(ctx).Model(&proxyModel{}).
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password). Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
Count(&count).Error Count(&count).Error
if err != nil { if err != nil {
@@ -105,7 +117,7 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy // CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Account{}). err := r.db.WithContext(ctx).Table("accounts").
Where("proxy_id = ?", proxyID). Where("proxy_id = ?", proxyID).
Count(&count).Error Count(&count).Error
return count, err return count, err
@@ -119,7 +131,7 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
} }
var results []result var results []result
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Model(&model.Account{}). Table("accounts").
Select("proxy_id, COUNT(*) as count"). Select("proxy_id, COUNT(*) as count").
Where("proxy_id IS NOT NULL"). Where("proxy_id IS NOT NULL").
Group("proxy_id"). Group("proxy_id").
@@ -136,10 +148,10 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
} }
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending // ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) { func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
var proxies []model.Proxy var proxies []proxyModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive). Where("status = ?", service.StatusActive).
Order("created_at DESC"). Order("created_at DESC").
Find(&proxies).Error Find(&proxies).Error
if err != nil { if err != nil {
@@ -153,13 +165,78 @@ func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]mod
} }
// Build result with account counts // Build result with account counts
result := make([]model.ProxyWithAccountCount, len(proxies)) result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i, proxy := range proxies { for i := range proxies {
result[i] = model.ProxyWithAccountCount{ proxy := proxyModelToService(&proxies[i])
Proxy: proxy, if proxy == nil {
AccountCount: counts[proxy.ID], continue
} }
result = append(result, service.ProxyWithAccountCount{
Proxy: *proxy,
AccountCount: counts[proxy.ID],
})
} }
return result, nil return result, nil
} }
type proxyModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"size:100;not null"`
Protocol string `gorm:"size:20;not null"`
Host string `gorm:"size:255;not null"`
Port int `gorm:"not null"`
Username string `gorm:"size:100"`
Password string `gorm:"size:100"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (proxyModel) TableName() string { return "proxies" }
func proxyModelToService(m *proxyModel) *service.Proxy {
if m == nil {
return nil
}
return &service.Proxy{
ID: m.ID,
Name: m.Name,
Protocol: m.Protocol,
Host: m.Host,
Port: m.Port,
Username: m.Username,
Password: m.Password,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func proxyModelFromService(p *service.Proxy) *proxyModel {
if p == nil {
return nil
}
return &proxyModel{
ID: p.ID,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
func applyProxyModelToService(proxy *service.Proxy, m *proxyModel) {
if proxy == nil || m == nil {
return
}
proxy.ID = m.ID
proxy.CreatedAt = m.CreatedAt
proxy.UpdatedAt = m.UpdatedAt
}

View File

@@ -7,8 +7,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -33,12 +33,12 @@ func TestProxyRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete --- // --- Create / GetByID / Update / Delete ---
func (s *ProxyRepoSuite) TestCreate() { func (s *ProxyRepoSuite) TestCreate() {
proxy := &model.Proxy{ proxy := &service.Proxy{
Name: "test-create", Name: "test-create",
Protocol: "http", Protocol: "http",
Host: "127.0.0.1", Host: "127.0.0.1",
Port: 8080, Port: 8080,
Status: model.StatusActive, Status: service.StatusActive,
} }
err := s.repo.Create(s.ctx, proxy) err := s.repo.Create(s.ctx, proxy)
@@ -56,7 +56,7 @@ func (s *ProxyRepoSuite) TestGetByID_NotFound() {
} }
func (s *ProxyRepoSuite) TestUpdate() { func (s *ProxyRepoSuite) TestUpdate() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"}) proxy := proxyModelToService(mustCreateProxy(s.T(), s.db, &proxyModel{Name: "original"}))
proxy.Name = "updated" proxy.Name = "updated"
err := s.repo.Update(s.ctx, proxy) err := s.repo.Update(s.ctx, proxy)
@@ -68,7 +68,7 @@ func (s *ProxyRepoSuite) TestUpdate() {
} }
func (s *ProxyRepoSuite) TestDelete() { func (s *ProxyRepoSuite) TestDelete() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"}) proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, proxy.ID) err := s.repo.Delete(s.ctx, proxy.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
@@ -80,8 +80,8 @@ func (s *ProxyRepoSuite) TestDelete() {
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *ProxyRepoSuite) TestList() { func (s *ProxyRepoSuite) TestList() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
@@ -90,8 +90,8 @@ func (s *ProxyRepoSuite) TestList() {
} }
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() { func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Protocol: "http"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Protocol: "socks5"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "") proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
s.Require().NoError(err) s.Require().NoError(err)
@@ -100,18 +100,18 @@ func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
} }
func (s *ProxyRepoSuite) TestListWithFilters_Status() { func (s *ProxyRepoSuite) TestListWithFilters_Status() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Status: service.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Status: service.StatusDisabled})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "") proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(proxies, 1) s.Require().Len(proxies, 1)
s.Require().Equal(model.StatusDisabled, proxies[0].Status) s.Require().Equal(service.StatusDisabled, proxies[0].Status)
} }
func (s *ProxyRepoSuite) TestListWithFilters_Search() { func (s *ProxyRepoSuite) TestListWithFilters_Search() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "production-proxy"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "dev-proxy"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod") proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
s.Require().NoError(err) s.Require().NoError(err)
@@ -122,8 +122,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Search() {
// --- ListActive --- // --- ListActive ---
func (s *ProxyRepoSuite) TestListActive() { func (s *ProxyRepoSuite) TestListActive() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "active1", Status: service.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled}) mustCreateProxy(s.T(), s.db, &proxyModel{Name: "inactive1", Status: service.StatusDisabled})
proxies, err := s.repo.ListActive(s.ctx) proxies, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive") s.Require().NoError(err, "ListActive")
@@ -134,7 +134,7 @@ func (s *ProxyRepoSuite) TestListActive() {
// --- ExistsByHostPortAuth --- // --- ExistsByHostPortAuth ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() { func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{ mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1", Name: "p1",
Protocol: "http", Protocol: "http",
Host: "1.2.3.4", Host: "1.2.3.4",
@@ -153,7 +153,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
} }
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() { func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{ mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p-noauth", Name: "p-noauth",
Protocol: "http", Protocol: "http",
Host: "5.6.7.8", Host: "5.6.7.8",
@@ -170,10 +170,10 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
// --- CountAccountsByProxyID --- // --- CountAccountsByProxyID ---
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() { func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"}) proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-count"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"}) // no proxy
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID) count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err, "CountAccountsByProxyID") s.Require().NoError(err, "CountAccountsByProxyID")
@@ -181,7 +181,7 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
} }
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() { func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"}) proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-zero"})
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID) count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err) s.Require().NoError(err)
@@ -191,12 +191,12 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
// --- GetAccountCountsForProxies --- // --- GetAccountCountsForProxies ---
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() { func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) p1 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"}) p2 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
counts, err := s.repo.GetAccountCountsForProxies(s.ctx) counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies") s.Require().NoError(err, "GetAccountCountsForProxies")
@@ -215,24 +215,24 @@ func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() { func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{ p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1", Name: "p1",
Status: model.StatusActive, Status: service.StatusActive,
CreatedAt: base.Add(-1 * time.Hour), CreatedAt: base.Add(-1 * time.Hour),
}) })
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{ p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p2", Name: "p2",
Status: model.StatusActive, Status: service.StatusActive,
CreatedAt: base, CreatedAt: base,
}) })
mustCreateProxy(s.T(), s.db, &model.Proxy{ mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p3-inactive", Name: "p3-inactive",
Status: model.StatusDisabled, Status: service.StatusDisabled,
}) })
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx) withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount") s.Require().NoError(err, "ListActiveWithAccountCount")
@@ -248,7 +248,7 @@ func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
// --- Combined original test --- // --- Combined original test ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() { func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{ p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1", Name: "p1",
Protocol: "http", Protocol: "http",
Host: "1.2.3.4", Host: "1.2.3.4",
@@ -258,7 +258,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
CreatedAt: time.Now().Add(-1 * time.Hour), CreatedAt: time.Now().Add(-1 * time.Hour),
UpdatedAt: time.Now().Add(-1 * time.Hour), UpdatedAt: time.Now().Add(-1 * time.Hour),
}) })
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{ p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p2", Name: "p2",
Protocol: "http", Protocol: "http",
Host: "5.6.7.8", Host: "5.6.7.8",
@@ -273,9 +273,9 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
s.Require().NoError(err, "ExistsByHostPortAuth") s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists, "expected proxy to exist") s.Require().True(exists, "expected proxy to exist")
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID) count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
s.Require().NoError(err, "CountAccountsByProxyID") s.Require().NoError(err, "CountAccountsByProxyID")

View File

@@ -15,6 +15,16 @@ const (
redeemRateLimitDuration = 24 * time.Hour redeemRateLimitDuration = 24 * time.Hour
) )
// redeemRateLimitKey generates the Redis key for redeem attempt rate limiting.
func redeemRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
}
// redeemLockKey generates the Redis key for redeem code locking.
func redeemLockKey(code string) string {
return redeemLockKeyPrefix + code
}
type redeemCache struct { type redeemCache struct {
rdb *redis.Client rdb *redis.Client
} }
@@ -24,12 +34,16 @@ func NewRedeemCache(rdb *redis.Client) service.RedeemCache {
} }
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) { func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) key := redeemRateLimitKey(userID)
return c.rdb.Get(ctx, key).Int() count, err := c.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
return count, err
} }
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error { func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID) key := redeemRateLimitKey(userID)
pipe := c.rdb.Pipeline() pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key) pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration) pipe.Expire(ctx, key, redeemRateLimitDuration)
@@ -38,11 +52,11 @@ func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID in
} }
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) { func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
key := redeemLockKeyPrefix + code key := redeemLockKey(code)
return c.rdb.SetNX(ctx, key, 1, ttl).Result() return c.rdb.SetNX(ctx, key, 1, ttl).Result()
} }
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error { func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
key := redeemLockKeyPrefix + code key := redeemLockKey(code)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }

View File

@@ -3,12 +3,10 @@
package repository package repository
import ( import (
"errors"
"fmt" "fmt"
"testing" "testing"
"time" "time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
@@ -25,9 +23,9 @@ func (s *RedeemCacheSuite) SetupTest() {
func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() { func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() {
missingUserID := int64(99999) missingUserID := int64(99999)
_, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID) count, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
require.Error(s.T(), err, "expected redis.Nil for missing rate-limit key") require.NoError(s.T(), err, "expected nil error for missing rate-limit key")
require.True(s.T(), errors.Is(err, redis.Nil)) require.Equal(s.T(), 0, count, "expected zero count for missing key")
} }
func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() { func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() {

View File

@@ -0,0 +1,77 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestRedeemRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "redeem:ratelimit:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "redeem:ratelimit:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "redeem:ratelimit:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "redeem:ratelimit:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := redeemRateLimitKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}
func TestRedeemLockKey(t *testing.T) {
tests := []struct {
name string
code string
expected string
}{
{
name: "normal_code",
code: "ABC123",
expected: "redeem:lock:ABC123",
},
{
name: "empty_code",
code: "",
expected: "redeem:lock:",
},
{
name: "code_with_special_chars",
code: "CODE-2024:test",
expected: "redeem:lock:CODE-2024:test",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := redeemLockKey(tc.code)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -4,10 +4,8 @@ import (
"context" "context"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -20,48 +18,61 @@ func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository {
return &redeemCodeRepository{db: db} return &redeemCodeRepository{db: db}
} }
func (r *redeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error { func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemCode) error {
return r.db.WithContext(ctx).Create(code).Error m := redeemCodeModelFromService(code)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyRedeemCodeModelToService(code, m)
}
return err
} }
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error { func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.RedeemCode) error {
return r.db.WithContext(ctx).Create(&codes).Error if len(codes) == 0 {
return nil
}
models := make([]redeemCodeModel, 0, len(codes))
for i := range codes {
m := redeemCodeModelFromService(&codes[i])
if m != nil {
models = append(models, *m)
}
}
return r.db.WithContext(ctx).Create(&models).Error
} }
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) { func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
var code model.RedeemCode var m redeemCodeModel
err := r.db.WithContext(ctx).First(&code, id).Error err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil) return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
} }
return &code, nil return redeemCodeModelToService(&m), nil
} }
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) { func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
var redeemCode model.RedeemCode var m redeemCodeModel
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error err := r.db.WithContext(ctx).Where("code = ?", code).First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil) return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
} }
return &redeemCode, nil return redeemCodeModelToService(&m), nil
} }
func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error { func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error return r.db.WithContext(ctx).Delete(&redeemCodeModel{}, id).Error
} }
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) { func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) { var codes []redeemCodeModel
var codes []model.RedeemCode
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.RedeemCode{}) db := r.db.WithContext(ctx).Model(&redeemCodeModel{})
// Apply filters
if codeType != "" { if codeType != "" {
db = db.Where("type = ?", codeType) db = db.Where("type = ?", codeType)
} }
@@ -81,29 +92,29 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() outCodes := make([]service.RedeemCode, 0, len(codes))
if int(total)%params.Limit() > 0 { for i := range codes {
pages++ outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
} }
return codes, &pagination.PaginationResult{ return outCodes, paginationResultFromTotal(total, params), nil
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
func (r *redeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error { func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
return r.db.WithContext(ctx).Save(code).Error m := redeemCodeModelFromService(code)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyRedeemCodeModelToService(code, m)
}
return err
} }
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error { func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now() now := time.Now()
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}). result := r.db.WithContext(ctx).Model(&redeemCodeModel{}).
Where("id = ? AND status = ?", id, model.StatusUnused). Where("id = ? AND status = ?", id, service.StatusUnused).
Updates(map[string]any{ Updates(map[string]any{
"status": model.StatusUsed, "status": service.StatusUsed,
"used_by": userID, "used_by": userID,
"used_at": now, "used_at": now,
}) })
@@ -116,22 +127,93 @@ func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error
return nil return nil
} }
// ListByUser returns all redeem codes used by a specific user func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
var codes []model.RedeemCode
if limit <= 0 { if limit <= 0 {
limit = 10 limit = 10
} }
var codes []redeemCodeModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("used_by = ?", userID). Where("used_by = ?", userID).
Order("used_at DESC"). Order("used_at DESC").
Limit(limit). Limit(limit).
Find(&codes).Error Find(&codes).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return codes, nil
outCodes := make([]service.RedeemCode, 0, len(codes))
for i := range codes {
outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
}
return outCodes, nil
}
type redeemCodeModel struct {
ID int64 `gorm:"primaryKey"`
Code string `gorm:"uniqueIndex;size:32;not null"`
Type string `gorm:"size:20;default:balance;not null"`
Value float64 `gorm:"type:decimal(20,8);not null"`
Status string `gorm:"size:20;default:unused;not null"`
UsedBy *int64 `gorm:"index"`
UsedAt *time.Time
Notes string `gorm:"type:text"`
CreatedAt time.Time `gorm:"not null"`
GroupID *int64 `gorm:"index"`
ValidityDays int `gorm:"default:30"`
User *userModel `gorm:"foreignKey:UsedBy"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (redeemCodeModel) TableName() string { return "redeem_codes" }
func redeemCodeModelToService(m *redeemCodeModel) *service.RedeemCode {
if m == nil {
return nil
}
return &service.RedeemCode{
ID: m.ID,
Code: m.Code,
Type: m.Type,
Value: m.Value,
Status: m.Status,
UsedBy: m.UsedBy,
UsedAt: m.UsedAt,
Notes: m.Notes,
CreatedAt: m.CreatedAt,
GroupID: m.GroupID,
ValidityDays: m.ValidityDays,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
}
}
func redeemCodeModelFromService(r *service.RedeemCode) *redeemCodeModel {
if r == nil {
return nil
}
return &redeemCodeModel{
ID: r.ID,
Code: r.Code,
Type: r.Type,
Value: r.Value,
Status: r.Status,
UsedBy: r.UsedBy,
UsedAt: r.UsedAt,
Notes: r.Notes,
CreatedAt: r.CreatedAt,
GroupID: r.GroupID,
ValidityDays: r.ValidityDays,
}
}
func applyRedeemCodeModelToService(code *service.RedeemCode, m *redeemCodeModel) {
if code == nil || m == nil {
return
}
code.ID = m.ID
code.CreatedAt = m.CreatedAt
} }

View File

@@ -7,7 +7,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@@ -34,11 +33,11 @@ func TestRedeemCodeRepoSuite(t *testing.T) {
// --- Create / CreateBatch / GetByID / GetByCode --- // --- Create / CreateBatch / GetByID / GetByCode ---
func (s *RedeemCodeRepoSuite) TestCreate() { func (s *RedeemCodeRepoSuite) TestCreate() {
code := &model.RedeemCode{ code := &service.RedeemCode{
Code: "TEST-CREATE", Code: "TEST-CREATE",
Type: model.RedeemTypeBalance, Type: service.RedeemTypeBalance,
Value: 100, Value: 100,
Status: model.StatusUnused, Status: service.StatusUnused,
} }
err := s.repo.Create(s.ctx, code) err := s.repo.Create(s.ctx, code)
@@ -51,9 +50,9 @@ func (s *RedeemCodeRepoSuite) TestCreate() {
} }
func (s *RedeemCodeRepoSuite) TestCreateBatch() { func (s *RedeemCodeRepoSuite) TestCreateBatch() {
codes := []model.RedeemCode{ codes := []service.RedeemCode{
{Code: "BATCH-1", Type: model.RedeemTypeBalance, Value: 10, Status: model.StatusUnused}, {Code: "BATCH-1", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused},
{Code: "BATCH-2", Type: model.RedeemTypeBalance, Value: 20, Status: model.StatusUnused}, {Code: "BATCH-2", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused},
} }
err := s.repo.CreateBatch(s.ctx, codes) err := s.repo.CreateBatch(s.ctx, codes)
@@ -74,7 +73,7 @@ func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
} }
func (s *RedeemCodeRepoSuite) TestGetByCode() { func (s *RedeemCodeRepoSuite) TestGetByCode() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "GET-BY-CODE", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "GET-BY-CODE", Type: service.RedeemTypeBalance})
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE") got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
s.Require().NoError(err, "GetByCode") s.Require().NoError(err, "GetByCode")
@@ -89,7 +88,7 @@ func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
// --- Delete --- // --- Delete ---
func (s *RedeemCodeRepoSuite) TestDelete() { func (s *RedeemCodeRepoSuite) TestDelete() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TO-DELETE", Type: model.RedeemTypeBalance}) code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TO-DELETE", Type: service.RedeemTypeBalance})
err := s.repo.Delete(s.ctx, code.ID) err := s.repo.Delete(s.ctx, code.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
@@ -101,8 +100,8 @@ func (s *RedeemCodeRepoSuite) TestDelete() {
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *RedeemCodeRepoSuite) TestList() { func (s *RedeemCodeRepoSuite) TestList() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-1", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-1", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-2", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-2", Type: service.RedeemTypeBalance})
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
@@ -111,28 +110,28 @@ func (s *RedeemCodeRepoSuite) TestList() {
} }
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() { func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-BAL", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-BAL", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-SUB", Type: model.RedeemTypeSubscription}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, "", "") codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(codes, 1) s.Require().Len(codes, 1)
s.Require().Equal(model.RedeemTypeSubscription, codes[0].Type) s.Require().Equal(service.RedeemTypeSubscription, codes[0].Type)
} }
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() { func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-UNUSED", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusUsed, "") codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUsed, "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(codes, 1) s.Require().Len(codes, 1)
s.Require().Equal(model.StatusUsed, codes[0].Status) s.Require().Equal(service.StatusUsed, codes[0].Status)
} }
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() { func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALPHA-CODE", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "BETA-CODE", Type: model.RedeemTypeBalance}) mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "BETA-CODE", Type: service.RedeemTypeBalance})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha") codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
s.Require().NoError(err) s.Require().NoError(err)
@@ -141,10 +140,10 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
} }
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() { func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "WITH-GROUP", Code: "WITH-GROUP",
Type: model.RedeemTypeSubscription, Type: service.RedeemTypeSubscription,
GroupID: &group.ID, GroupID: &group.ID,
}) })
@@ -158,7 +157,7 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
// --- Update --- // --- Update ---
func (s *RedeemCodeRepoSuite) TestUpdate() { func (s *RedeemCodeRepoSuite) TestUpdate() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "UPDATE-ME", Type: model.RedeemTypeBalance, Value: 10}) code := redeemCodeModelToService(mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "UPDATE-ME", Type: service.RedeemTypeBalance, Value: 10}))
code.Value = 50 code.Value = 50
err := s.repo.Update(s.ctx, code) err := s.repo.Update(s.ctx, code)
@@ -172,23 +171,23 @@ func (s *RedeemCodeRepoSuite) TestUpdate() {
// --- Use --- // --- Use ---
func (s *RedeemCodeRepoSuite) TestUse() { func (s *RedeemCodeRepoSuite) TestUse() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "use@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "use@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "USE-ME", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "USE-ME", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID) err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use") s.Require().NoError(err, "Use")
got, err := s.repo.GetByID(s.ctx, code.ID) got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(model.StatusUsed, got.Status) s.Require().Equal(service.StatusUsed, got.Status)
s.Require().NotNil(got.UsedBy) s.Require().NotNil(got.UsedBy)
s.Require().Equal(user.ID, *got.UsedBy) s.Require().Equal(user.ID, *got.UsedBy)
s.Require().NotNil(got.UsedAt) s.Require().NotNil(got.UsedAt)
} }
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() { func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "idem@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "idem@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "IDEM-CODE", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID) err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use first time") s.Require().NoError(err, "Use first time")
@@ -200,8 +199,8 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
} }
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "already@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "already@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALREADY-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed}) code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
err := s.repo.Use(s.ctx, code.ID, user.ID) err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code") s.Require().Error(err, "expected error for already used code")
@@ -211,22 +210,22 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
// --- ListByUser --- // --- ListByUser ---
func (s *RedeemCodeRepoSuite) TestListByUser() { func (s *RedeemCodeRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
// Create codes with explicit used_at for ordering // Create codes with explicit used_at for ordering
c1 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ c1 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "USER-1", Code: "USER-1",
Type: model.RedeemTypeBalance, Type: service.RedeemTypeBalance,
Status: model.StatusUsed, Status: service.StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
}) })
s.db.Model(c1).Update("used_at", base) s.db.Model(c1).Update("used_at", base)
c2 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ c2 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "USER-2", Code: "USER-2",
Type: model.RedeemTypeBalance, Type: service.RedeemTypeBalance,
Status: model.StatusUsed, Status: service.StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
}) })
s.db.Model(c2).Update("used_at", base.Add(1*time.Hour)) s.db.Model(c2).Update("used_at", base.Add(1*time.Hour))
@@ -240,13 +239,13 @@ func (s *RedeemCodeRepoSuite) TestListByUser() {
} }
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() { func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grp@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "grp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listby"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listby"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "WITH-GRP", Code: "WITH-GRP",
Type: model.RedeemTypeSubscription, Type: service.RedeemTypeSubscription,
Status: model.StatusUsed, Status: service.StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
GroupID: &group.ID, GroupID: &group.ID,
}) })
@@ -260,11 +259,11 @@ func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
} }
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() { func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deflimit@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "deflimit@test.com"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "DEF-LIM", Code: "DEF-LIM",
Type: model.RedeemTypeBalance, Type: service.RedeemTypeBalance,
Status: model.StatusUsed, Status: service.StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
}) })
s.db.Model(c).Update("used_at", time.Now()) s.db.Model(c).Update("used_at", time.Now())
@@ -278,16 +277,16 @@ func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
// --- Combined original test --- // --- Combined original test ---
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() { func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "rc@example.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "rc@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-rc"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-rc"})
codes := []model.RedeemCode{ codes := []service.RedeemCode{
{Code: "CODEA", Type: model.RedeemTypeBalance, Value: 1, Status: model.StatusUnused, CreatedAt: time.Now()}, {Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, CreatedAt: time.Now()},
{Code: "CODEB", Type: model.RedeemTypeSubscription, Value: 0, Status: model.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()}, {Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
} }
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch") s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, model.StatusUnused, "code") list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, service.StatusUnused, "code")
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total) s.Require().Equal(int64(1), page.Total)
s.Require().Len(list, 1) s.Require().Len(list, 1)
@@ -305,9 +304,9 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
s.Require().NoError(err, "GetByCode") s.Require().NoError(err, "GetByCode")
// Use fixed time instead of time.Sleep for deterministic ordering // Use fixed time instead of time.Sleep for deterministic ordering
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)) s.db.Model(&redeemCodeModel{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA") s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC)) s.db.Model(&redeemCodeModel{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
used, err := s.repo.ListByUser(s.ctx, user.ID, 10) used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser") s.Require().NoError(err, "ListByUser")

View File

@@ -6,33 +6,27 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
// SettingRepository 系统设置数据访问层
type settingRepository struct { type settingRepository struct {
db *gorm.DB db *gorm.DB
} }
// NewSettingRepository 创建系统设置仓库实例
func NewSettingRepository(db *gorm.DB) service.SettingRepository { func NewSettingRepository(db *gorm.DB) service.SettingRepository {
return &settingRepository{db: db} return &settingRepository{db: db}
} }
// Get 根据Key获取设置值 func (r *settingRepository) Get(ctx context.Context, key string) (*service.Setting, error) {
func (r *settingRepository) Get(ctx context.Context, key string) (*model.Setting, error) { var m settingModel
var setting model.Setting err := r.db.WithContext(ctx).Where("key = ?", key).First(&m).Error
err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil) return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil)
} }
return &setting, nil return settingModelToService(&m), nil
} }
// GetValue 获取设置值字符串
func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) { func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
setting, err := r.Get(ctx, key) setting, err := r.Get(ctx, key)
if err != nil { if err != nil {
@@ -41,9 +35,8 @@ func (r *settingRepository) GetValue(ctx context.Context, key string) (string, e
return setting.Value, nil return setting.Value, nil
} }
// Set 设置值(存在则更新,不存在则创建)
func (r *settingRepository) Set(ctx context.Context, key, value string) error { func (r *settingRepository) Set(ctx context.Context, key, value string) error {
setting := &model.Setting{ m := &settingModel{
Key: key, Key: key,
Value: value, Value: value,
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
@@ -52,12 +45,11 @@ func (r *settingRepository) Set(ctx context.Context, key, value string) error {
return r.db.WithContext(ctx).Clauses(clause.OnConflict{ return r.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}}, Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}), DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(setting).Error }).Create(m).Error
} }
// GetMultiple 批量获取设置
func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
var settings []model.Setting var settings []settingModel
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
if err != nil { if err != nil {
return nil, err return nil, err
@@ -70,11 +62,10 @@ func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map
return result, nil return result, nil
} }
// SetMultiple 批量设置值
func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error { func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for key, value := range settings { for key, value := range settings {
setting := &model.Setting{ m := &settingModel{
Key: key, Key: key,
Value: value, Value: value,
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
@@ -82,7 +73,7 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string
if err := tx.Clauses(clause.OnConflict{ if err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}}, Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}), DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(setting).Error; err != nil { }).Create(m).Error; err != nil {
return err return err
} }
} }
@@ -90,9 +81,8 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string
}) })
} }
// GetAll 获取所有设置
func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) { func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
var settings []model.Setting var settings []settingModel
err := r.db.WithContext(ctx).Find(&settings).Error err := r.db.WithContext(ctx).Find(&settings).Error
if err != nil { if err != nil {
return nil, err return nil, err
@@ -105,7 +95,27 @@ func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, erro
return result, nil return result, nil
} }
// Delete 删除设置
func (r *settingRepository) Delete(ctx context.Context, key string) error { func (r *settingRepository) Delete(ctx context.Context, key string) error {
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error return r.db.WithContext(ctx).Where("key = ?", key).Delete(&settingModel{}).Error
}
type settingModel struct {
ID int64 `gorm:"primaryKey"`
Key string `gorm:"uniqueIndex;size:100;not null"`
Value string `gorm:"type:text;not null"`
UpdatedAt time.Time `gorm:"not null"`
}
func (settingModel) TableName() string { return "settings" }
func settingModelToService(m *settingModel) *service.Setting {
if m == nil {
return nil
}
return &service.Setting{
ID: m.ID,
Key: m.Key,
Value: m.Value,
UpdatedAt: m.UpdatedAt,
}
} }

View File

@@ -6,7 +6,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
@@ -30,7 +29,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
TokenCount int64 `gorm:"column:token_count"` TokenCount int64 `gorm:"column:token_count"`
} }
db := r.db.WithContext(ctx).Model(&model.UsageLog{}). db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as request_count, COUNT(*) as request_count,
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
@@ -46,24 +45,29 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
return perfStats.RequestCount / 5, perfStats.TokenCount / 5 return perfStats.RequestCount / 5, perfStats.TokenCount / 5
} }
func (r *usageLogRepository) Create(ctx context.Context, log *model.UsageLog) error { func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error {
return r.db.WithContext(ctx).Create(log).Error m := usageLogModelFromService(log)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUsageLogModelToService(log, m)
}
return err
} }
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) { func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
var log model.UsageLog var log usageLogModel
err := r.db.WithContext(ctx).First(&log, id).Error err := r.db.WithContext(ctx).First(&log, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil) return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil)
} }
return &log, nil return usageLogModelToService(&log), nil
} }
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("user_id = ?", userID) db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
@@ -73,24 +77,14 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
} }
return logs, &pagination.PaginationResult{ func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
Total: total, var logs []usageLogModel
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("api_key_id = ?", apiKeyID) db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("api_key_id = ?", apiKeyID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
@@ -100,17 +94,7 @@ func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
// UserStats 用户使用统计 // UserStats 用户使用统计
@@ -125,7 +109,7 @@ type UserStats struct {
func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) { func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
var stats UserStats var stats UserStats
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
@@ -147,47 +131,47 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
today := timezone.Today() today := timezone.Today()
// 总用户数 // 总用户数
r.db.WithContext(ctx).Model(&model.User{}).Count(&stats.TotalUsers) r.db.WithContext(ctx).Model(&userModel{}).Count(&stats.TotalUsers)
// 今日新增用户数 // 今日新增用户数
r.db.WithContext(ctx).Model(&model.User{}). r.db.WithContext(ctx).Model(&userModel{}).
Where("created_at >= ?", today). Where("created_at >= ?", today).
Count(&stats.TodayNewUsers) Count(&stats.TodayNewUsers)
// 今日活跃用户数 (今日有请求的用户) // 今日活跃用户数 (今日有请求的用户)
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Distinct("user_id"). Distinct("user_id").
Where("created_at >= ?", today). Where("created_at >= ?", today).
Count(&stats.ActiveUsers) Count(&stats.ActiveUsers)
// 总 API Key 数 // 总 API Key 数
r.db.WithContext(ctx).Model(&model.ApiKey{}).Count(&stats.TotalApiKeys) r.db.WithContext(ctx).Model(&apiKeyModel{}).Count(&stats.TotalApiKeys)
// 活跃 API Key 数 // 活跃 API Key 数
r.db.WithContext(ctx).Model(&model.ApiKey{}). r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("status = ?", model.StatusActive). Where("status = ?", service.StatusActive).
Count(&stats.ActiveApiKeys) Count(&stats.ActiveApiKeys)
// 总账户数 // 总账户数
r.db.WithContext(ctx).Model(&model.Account{}).Count(&stats.TotalAccounts) r.db.WithContext(ctx).Model(&accountModel{}).Count(&stats.TotalAccounts)
// 正常账户数 (schedulable=true, status=active) // 正常账户数 (schedulable=true, status=active)
r.db.WithContext(ctx).Model(&model.Account{}). r.db.WithContext(ctx).Model(&accountModel{}).
Where("status = ? AND schedulable = ?", model.StatusActive, true). Where("status = ? AND schedulable = ?", service.StatusActive, true).
Count(&stats.NormalAccounts) Count(&stats.NormalAccounts)
// 异常账户数 (status=error) // 异常账户数 (status=error)
r.db.WithContext(ctx).Model(&model.Account{}). r.db.WithContext(ctx).Model(&accountModel{}).
Where("status = ?", model.StatusError). Where("status = ?", service.StatusError).
Count(&stats.ErrorAccounts) Count(&stats.ErrorAccounts)
// 限流账户数 // 限流账户数
r.db.WithContext(ctx).Model(&model.Account{}). r.db.WithContext(ctx).Model(&accountModel{}).
Where("rate_limited_at IS NOT NULL AND rate_limit_reset_at > ?", time.Now()). Where("rate_limited_at IS NOT NULL AND rate_limit_reset_at > ?", time.Now()).
Count(&stats.RateLimitAccounts) Count(&stats.RateLimitAccounts)
// 过载账户数 // 过载账户数
r.db.WithContext(ctx).Model(&model.Account{}). r.db.WithContext(ctx).Model(&accountModel{}).
Where("overload_until IS NOT NULL AND overload_until > ?", time.Now()). Where("overload_until IS NOT NULL AND overload_until > ?", time.Now()).
Count(&stats.OverloadAccounts) Count(&stats.OverloadAccounts)
@@ -202,7 +186,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
TotalActualCost float64 `gorm:"column:total_actual_cost"` TotalActualCost float64 `gorm:"column:total_actual_cost"`
AverageDurationMs float64 `gorm:"column:avg_duration_ms"` AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(input_tokens), 0) as total_input_tokens,
@@ -235,7 +219,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
TodayCost float64 `gorm:"column:today_cost"` TodayCost float64 `gorm:"column:today_cost"`
TodayActualCost float64 `gorm:"column:today_actual_cost"` TodayActualCost float64 `gorm:"column:today_actual_cost"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as today_requests, COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens, COALESCE(SUM(input_tokens), 0) as today_input_tokens,
@@ -263,11 +247,11 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
return &stats, nil return &stats, nil
} }
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("account_id = ?", accountID) db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("account_id = ?", accountID)
if err := db.Count(&total).Error; err != nil { if err := db.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
@@ -277,57 +261,47 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
} }
return logs, &pagination.PaginationResult{ func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
Total: total, var logs []usageLogModel
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime). Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
Order("id DESC"). Order("id DESC").
Find(&logs).Error Find(&logs).Error
return logs, nil, err return usageLogModelsToService(logs), nil, err
} }
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime). Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
Order("id DESC"). Order("id DESC").
Find(&logs).Error Find(&logs).Error
return logs, nil, err return usageLogModelsToService(logs), nil, err
} }
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Order("id DESC"). Order("id DESC").
Find(&logs).Error Find(&logs).Error
return logs, nil, err return usageLogModelsToService(logs), nil, err
} }
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime). Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
Order("id DESC"). Order("id DESC").
Find(&logs).Error Find(&logs).Error
return logs, nil, err return usageLogModelsToService(logs), nil, err
} }
func (r *usageLogRepository) Delete(ctx context.Context, id int64) error { func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error return r.db.WithContext(ctx).Delete(&usageLogModel{}, id).Error
} }
// GetAccountTodayStats 获取账号今日统计 // GetAccountTodayStats 获取账号今日统计
@@ -340,7 +314,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
Cost float64 `gorm:"column:cost"` Cost float64 `gorm:"column:cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
@@ -368,7 +342,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
Cost float64 `gorm:"column:cost"` Cost float64 `gorm:"column:cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
@@ -499,12 +473,12 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
today := timezone.Today() today := timezone.Today()
// API Key 统计 // API Key 统计
r.db.WithContext(ctx).Model(&model.ApiKey{}). r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("user_id = ?", userID). Where("user_id = ?", userID).
Count(&stats.TotalApiKeys) Count(&stats.TotalApiKeys)
r.db.WithContext(ctx).Model(&model.ApiKey{}). r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("user_id = ? AND status = ?", userID, model.StatusActive). Where("user_id = ? AND status = ?", userID, service.StatusActive).
Count(&stats.ActiveApiKeys) Count(&stats.ActiveApiKeys)
// 累计 Token 统计 // 累计 Token 统计
@@ -518,7 +492,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
TotalActualCost float64 `gorm:"column:total_actual_cost"` TotalActualCost float64 `gorm:"column:total_actual_cost"`
AverageDurationMs float64 `gorm:"column:avg_duration_ms"` AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(input_tokens), 0) as total_input_tokens,
@@ -552,7 +526,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
TodayCost float64 `gorm:"column:today_cost"` TodayCost float64 `gorm:"column:today_cost"`
TodayActualCost float64 `gorm:"column:today_actual_cost"` TodayActualCost float64 `gorm:"column:today_actual_cost"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as today_requests, COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens, COALESCE(SUM(input_tokens), 0) as today_input_tokens,
@@ -591,7 +565,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
dateFormat = "YYYY-MM-DD" dateFormat = "YYYY-MM-DD"
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
TO_CHAR(created_at, ?) as date, TO_CHAR(created_at, ?) as date,
COUNT(*) as requests, COUNT(*) as requests,
@@ -618,7 +592,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) { func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
var results []ModelStat var results []ModelStat
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
model, model,
COUNT(*) as requests, COUNT(*) as requests,
@@ -644,11 +618,11 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
type UsageLogFilters = usagestats.UsageLogFilters type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin) // ListWithFilters lists usage logs with optional filters (for admin)
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []usageLogModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}) db := r.db.WithContext(ctx).Model(&usageLogModel{})
// Apply filters // Apply filters
if filters.UserID > 0 { if filters.UserID > 0 {
@@ -675,17 +649,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
} }
// UsageStats represents usage statistics // UsageStats represents usage statistics
@@ -713,7 +677,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
UserID int64 `gorm:"column:user_id"` UserID int64 `gorm:"column:user_id"`
TotalCost float64 `gorm:"column:total_cost"` TotalCost float64 `gorm:"column:total_cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("user_id, COALESCE(SUM(actual_cost), 0) as total_cost"). Select("user_id, COALESCE(SUM(actual_cost), 0) as total_cost").
Where("user_id IN ?", userIDs). Where("user_id IN ?", userIDs).
Group("user_id"). Group("user_id").
@@ -733,7 +697,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
UserID int64 `gorm:"column:user_id"` UserID int64 `gorm:"column:user_id"`
TodayCost float64 `gorm:"column:today_cost"` TodayCost float64 `gorm:"column:today_cost"`
} }
err = r.db.WithContext(ctx).Model(&model.UsageLog{}). err = r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("user_id, COALESCE(SUM(actual_cost), 0) as today_cost"). Select("user_id, COALESCE(SUM(actual_cost), 0) as today_cost").
Where("user_id IN ? AND created_at >= ?", userIDs, today). Where("user_id IN ? AND created_at >= ?", userIDs, today).
Group("user_id"). Group("user_id").
@@ -773,7 +737,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
ApiKeyID int64 `gorm:"column:api_key_id"` ApiKeyID int64 `gorm:"column:api_key_id"`
TotalCost float64 `gorm:"column:total_cost"` TotalCost float64 `gorm:"column:total_cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost"). Select("api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost").
Where("api_key_id IN ?", apiKeyIDs). Where("api_key_id IN ?", apiKeyIDs).
Group("api_key_id"). Group("api_key_id").
@@ -793,7 +757,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
ApiKeyID int64 `gorm:"column:api_key_id"` ApiKeyID int64 `gorm:"column:api_key_id"`
TodayCost float64 `gorm:"column:today_cost"` TodayCost float64 `gorm:"column:today_cost"`
} }
err = r.db.WithContext(ctx).Model(&model.UsageLog{}). err = r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost"). Select("api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost").
Where("api_key_id IN ? AND created_at >= ?", apiKeyIDs, today). Where("api_key_id IN ? AND created_at >= ?", apiKeyIDs, today).
Group("api_key_id"). Group("api_key_id").
@@ -822,7 +786,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
dateFormat = "YYYY-MM-DD" dateFormat = "YYYY-MM-DD"
} }
db := r.db.WithContext(ctx).Model(&model.UsageLog{}). db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
TO_CHAR(created_at, ?) as date, TO_CHAR(created_at, ?) as date,
COUNT(*) as requests, COUNT(*) as requests,
@@ -854,7 +818,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
var results []ModelStat var results []ModelStat
db := r.db.WithContext(ctx).Model(&model.UsageLog{}). db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
model, model,
COUNT(*) as requests, COUNT(*) as requests,
@@ -896,7 +860,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
AverageDurationMs float64 `gorm:"column:avg_duration_ms"` AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
COUNT(*) as total_requests, COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(input_tokens), 0) as total_input_tokens,
@@ -950,7 +914,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
ActualCost float64 `gorm:"column:actual_cost"` ActualCost float64 `gorm:"column:actual_cost"`
} }
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(` Select(`
TO_CHAR(created_at, 'YYYY-MM-DD') as date, TO_CHAR(created_at, 'YYYY-MM-DD') as date,
COUNT(*) as requests, COUNT(*) as requests,
@@ -1011,7 +975,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var avgDuration struct { var avgDuration struct {
AvgDurationMs float64 `gorm:"column:avg_duration_ms"` AvgDurationMs float64 `gorm:"column:avg_duration_ms"`
} }
r.db.WithContext(ctx).Model(&model.UsageLog{}). r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms"). Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms").
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Scan(&avgDuration) Scan(&avgDuration)
@@ -1090,3 +1054,137 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Models: models, Models: models,
}, nil }, nil
} }
type usageLogModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
ApiKeyID int64 `gorm:"index;not null"`
AccountID int64 `gorm:"index;not null"`
RequestID string `gorm:"size:64"`
Model string `gorm:"size:100;index;not null"`
GroupID *int64 `gorm:"index"`
SubscriptionID *int64 `gorm:"index"`
InputTokens int `gorm:"default:0;not null"`
OutputTokens int `gorm:"default:0;not null"`
CacheCreationTokens int `gorm:"default:0;not null"`
CacheReadTokens int `gorm:"default:0;not null"`
CacheCreation5mTokens int `gorm:"default:0;not null"`
CacheCreation1hTokens int `gorm:"default:0;not null"`
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null"`
BillingType int8 `gorm:"type:smallint;default:0;not null"`
Stream bool `gorm:"default:false;not null"`
DurationMs *int
FirstTokenMs *int
CreatedAt time.Time `gorm:"index;not null"`
User *userModel `gorm:"foreignKey:UserID"`
ApiKey *apiKeyModel `gorm:"foreignKey:ApiKeyID"`
Account *accountModel `gorm:"foreignKey:AccountID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
Subscription *userSubscriptionModel `gorm:"foreignKey:SubscriptionID"`
}
func (usageLogModel) TableName() string { return "usage_logs" }
func usageLogModelToService(m *usageLogModel) *service.UsageLog {
if m == nil {
return nil
}
return &service.UsageLog{
ID: m.ID,
UserID: m.UserID,
ApiKeyID: m.ApiKeyID,
AccountID: m.AccountID,
RequestID: m.RequestID,
Model: m.Model,
GroupID: m.GroupID,
SubscriptionID: m.SubscriptionID,
InputTokens: m.InputTokens,
OutputTokens: m.OutputTokens,
CacheCreationTokens: m.CacheCreationTokens,
CacheReadTokens: m.CacheReadTokens,
CacheCreation5mTokens: m.CacheCreation5mTokens,
CacheCreation1hTokens: m.CacheCreation1hTokens,
InputCost: m.InputCost,
OutputCost: m.OutputCost,
CacheCreationCost: m.CacheCreationCost,
CacheReadCost: m.CacheReadCost,
TotalCost: m.TotalCost,
ActualCost: m.ActualCost,
RateMultiplier: m.RateMultiplier,
BillingType: m.BillingType,
Stream: m.Stream,
DurationMs: m.DurationMs,
FirstTokenMs: m.FirstTokenMs,
CreatedAt: m.CreatedAt,
User: userModelToService(m.User),
ApiKey: apiKeyModelToService(m.ApiKey),
Account: accountModelToService(m.Account),
Group: groupModelToService(m.Group),
Subscription: userSubscriptionModelToService(m.Subscription),
}
}
func usageLogModelsToService(models []usageLogModel) []service.UsageLog {
out := make([]service.UsageLog, 0, len(models))
for i := range models {
if s := usageLogModelToService(&models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func usageLogModelFromService(log *service.UsageLog) *usageLogModel {
if log == nil {
return nil
}
return &usageLogModel{
ID: log.ID,
UserID: log.UserID,
ApiKeyID: log.ApiKeyID,
AccountID: log.AccountID,
RequestID: log.RequestID,
Model: log.Model,
GroupID: log.GroupID,
SubscriptionID: log.SubscriptionID,
InputTokens: log.InputTokens,
OutputTokens: log.OutputTokens,
CacheCreationTokens: log.CacheCreationTokens,
CacheReadTokens: log.CacheReadTokens,
CacheCreation5mTokens: log.CacheCreation5mTokens,
CacheCreation1hTokens: log.CacheCreation1hTokens,
InputCost: log.InputCost,
OutputCost: log.OutputCost,
CacheCreationCost: log.CacheCreationCost,
CacheReadCost: log.CacheReadCost,
TotalCost: log.TotalCost,
ActualCost: log.ActualCost,
RateMultiplier: log.RateMultiplier,
BillingType: log.BillingType,
Stream: log.Stream,
DurationMs: log.DurationMs,
FirstTokenMs: log.FirstTokenMs,
CreatedAt: log.CreatedAt,
}
}
func applyUsageLogModelToService(log *service.UsageLog, m *usageLogModel) {
if log == nil || m == nil {
return
}
log.ID = m.ID
log.CreatedAt = m.CreatedAt
}

View File

@@ -7,10 +7,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -32,8 +32,8 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite)) suite.Run(t, new(UsageLogRepoSuite))
} }
func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKey, account *model.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *model.UsageLog { func (s *UsageLogRepoSuite) createUsageLog(user *userModel, apiKey *apiKeyModel, account *accountModel, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &model.UsageLog{ log := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
@@ -51,11 +51,11 @@ func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKe
// --- Create / GetByID --- // --- Create / GetByID ---
func (s *UsageLogRepoSuite) TestCreate() { func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-create"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-create"})
log := &model.UsageLog{ log := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
@@ -72,9 +72,9 @@ func (s *UsageLogRepoSuite) TestCreate() {
} }
func (s *UsageLogRepoSuite) TestGetByID() { func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbyid@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-getbyid"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -92,9 +92,9 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
// --- Delete --- // --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() { func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-delete"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -108,9 +108,9 @@ func (s *UsageLogRepoSuite) TestDelete() {
// --- ListByUser --- // --- ListByUser ---
func (s *UsageLogRepoSuite) TestListByUser() { func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyuser"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
@@ -124,9 +124,9 @@ func (s *UsageLogRepoSuite) TestListByUser() {
// --- ListByApiKey --- // --- ListByApiKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() { func (s *UsageLogRepoSuite) TestListByApiKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyapikey@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyapikey"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
@@ -140,9 +140,9 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
// --- ListByAccount --- // --- ListByAccount ---
func (s *UsageLogRepoSuite) TestListByAccount() { func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyaccount@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyaccount"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -155,9 +155,9 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
// --- GetUserStats --- // --- GetUserStats ---
func (s *UsageLogRepoSuite) TestGetUserStats() { func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userstats@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userstats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -175,9 +175,9 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
// --- ListWithFilters --- // --- ListWithFilters ---
func (s *UsageLogRepoSuite) TestListWithFilters() { func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filters@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filters"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -194,29 +194,29 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
now := time.Now() now := time.Now()
todayStart := timezone.Today() todayStart := timezone.Today()
userToday := mustCreateUser(s.T(), s.db, &model.User{ userToday := mustCreateUser(s.T(), s.db, &userModel{
Email: "today@example.com", Email: "today@example.com",
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)), CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
UpdatedAt: now, UpdatedAt: now,
}) })
userOld := mustCreateUser(s.T(), s.db, &model.User{ userOld := mustCreateUser(s.T(), s.db, &userModel{
Email: "old@example.com", Email: "old@example.com",
CreatedAt: todayStart.Add(-24 * time.Hour), CreatedAt: todayStart.Add(-24 * time.Hour),
UpdatedAt: todayStart.Add(-24 * time.Hour), UpdatedAt: todayStart.Add(-24 * time.Hour),
}) })
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-ul"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: model.StatusDisabled}) mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
resetAt := now.Add(10 * time.Minute) resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-normal", Schedulable: true}) accNormal := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-normal", Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-error", Status: model.StatusError, Schedulable: true}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-error", Status: service.StatusError, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true}) mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
d1, d2, d3 := 100, 200, 300 d1, d2, d3 := 100, 200, 300
logToday := &model.UsageLog{ logToday := &service.UsageLog{
UserID: userToday.ID, UserID: userToday.ID,
ApiKeyID: apiKey1.ID, ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
@@ -233,7 +233,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
} }
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday") s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
logOld := &model.UsageLog{ logOld := &service.UsageLog{
UserID: userOld.ID, UserID: userOld.ID,
ApiKeyID: apiKey1.ID, ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
@@ -247,7 +247,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
} }
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld") s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
logPerf := &model.UsageLog{ logPerf := &service.UsageLog{
UserID: userToday.ID, UserID: userToday.ID,
ApiKeyID: apiKey1.ID, ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID, AccountID: accNormal.ID,
@@ -293,9 +293,9 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
// --- GetUserDashboardStats --- // --- GetUserDashboardStats ---
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userdash@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userdash"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -308,9 +308,9 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
// --- GetAccountTodayStats --- // --- GetAccountTodayStats ---
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctoday@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-today"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -323,11 +323,11 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
// --- GetBatchUserUsageStats --- // --- GetBatchUserUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"}) apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batch"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
@@ -348,10 +348,10 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
// --- GetBatchApiKeyUsageStats --- // --- GetBatchApiKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() { func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batchkey@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batchkey"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
@@ -370,9 +370,9 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
// --- GetGlobalStats --- // --- GetGlobalStats ---
func (s *UsageLogRepoSuite) TestGetGlobalStats() { func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "global@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-global"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -395,9 +395,9 @@ func maxTime(a, b time.Time) time.Time {
// --- ListByUserAndTimeRange --- // --- ListByUserAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() { func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "timerange@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-timerange"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -414,9 +414,9 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
// --- ListByApiKeyAndTimeRange --- // --- ListByApiKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() { func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytimerange@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytimerange"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -433,9 +433,9 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
// --- ListByAccountAndTimeRange --- // --- ListByAccountAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() { func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctimerange@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-acctimerange"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -452,14 +452,14 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
// --- ListByModelAndTimeRange --- // --- ListByModelAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modeltimerange@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modeltimerange"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models // Create logs with different models
log1 := &model.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
@@ -472,7 +472,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
@@ -485,7 +485,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log2)) s.Require().NoError(s.repo.Create(s.ctx, log2))
log3 := &model.UsageLog{ log3 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
@@ -508,9 +508,9 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
// --- GetAccountWindowStats --- // --- GetAccountWindowStats ---
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() { func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "windowstats@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-windowstats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-windowstats"})
now := time.Now() now := time.Now()
windowStart := now.Add(-10 * time.Minute) windowStart := now.Add(-10 * time.Minute)
@@ -528,9 +528,9 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
// --- GetUserUsageTrendByUserID --- // --- GetUserUsageTrendByUserID ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() { func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrend"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -545,9 +545,9 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
} }
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrendhourly@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrendhourly"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -564,14 +564,14 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
// --- GetUserModelStats --- // --- GetUserModelStats ---
func (s *UsageLogRepoSuite) TestGetUserModelStats() { func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelstats@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelstats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models // Create logs with different models
log1 := &model.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
@@ -584,7 +584,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
@@ -611,9 +611,9 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
// --- GetUsageTrendWithFilters --- // --- GetUsageTrendWithFilters ---
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -639,9 +639,9 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
} }
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters-h@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters-h"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -658,13 +658,13 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
// --- GetModelStatsWithFilters --- // --- GetModelStatsWithFilters ---
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelfilters@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelfilters"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &model.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
@@ -677,7 +677,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
@@ -712,14 +712,14 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
// --- GetAccountUsageStats --- // --- GetAccountUsageStats ---
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accstats@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-accstats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
// Create logs on different days // Create logs on different days
log1 := &model.UsageLog{ log1 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
@@ -732,7 +732,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
} }
s.Require().NoError(s.repo.Create(s.ctx, log1)) s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{ log2 := &service.UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
@@ -758,7 +758,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
} }
func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() { func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-emptystats"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-emptystats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
startTime := base startTime := base
@@ -774,11 +774,11 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
// --- GetUserUsageTrend --- // --- GetUserUsageTrend ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrends"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base) s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base)
@@ -796,10 +796,10 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
// --- GetApiKeyUsageTrend --- // --- GetApiKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() { func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrend@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"}) apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"}) apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrends"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base) s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
@@ -815,9 +815,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
} }
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() { func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrendh@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrendh"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base) s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
@@ -834,9 +834,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
// --- ListWithFilters (additional filter tests) --- // --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() { func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterskey@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterskey"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -848,9 +848,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
} }
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() { func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterstime@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterstime"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -867,9 +867,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
} }
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() { func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterscombined@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"}) apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterscombined"}) account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)

View File

@@ -2,12 +2,13 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/lib/pq"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -19,48 +20,56 @@ func NewUserRepository(db *gorm.DB) service.UserRepository {
return &userRepository{db: db} return &userRepository{db: db}
} }
func (r *userRepository) Create(ctx context.Context, user *model.User) error { func (r *userRepository) Create(ctx context.Context, user *service.User) error {
err := r.db.WithContext(ctx).Create(user).Error m := userModelFromService(user)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUserModelToService(user, m)
}
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) { func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
var user model.User var m userModel
err := r.db.WithContext(ctx).First(&user, id).Error err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return userModelToService(&m), nil
} }
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) { func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
var user model.User var m userModel
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error err := r.db.WithContext(ctx).Where("email = ?", email).First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return userModelToService(&m), nil
} }
func (r *userRepository) Update(ctx context.Context, user *model.User) error { func (r *userRepository) Update(ctx context.Context, user *service.User) error {
err := r.db.WithContext(ctx).Save(user).Error m := userModelFromService(user)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyUserModelToService(user, m)
}
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
func (r *userRepository) Delete(ctx context.Context, id int64) error { func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error return r.db.WithContext(ctx).Delete(&userModel{}, id).Error
} }
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) { func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists users with optional filtering by status, role, and search query // ListWithFilters lists users with optional filtering by status, role, and search query
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) { func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
var users []model.User var users []userModel
var total int64 var total int64
db := r.db.WithContext(ctx).Model(&model.User{}) db := r.db.WithContext(ctx).Model(&userModel{})
// Apply filters // Apply filters
if status != "" { if status != "" {
@@ -89,17 +98,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// Batch load subscriptions for all users (avoid N+1) // Batch load subscriptions for all users (avoid N+1)
if len(users) > 0 { if len(users) > 0 {
userIDs := make([]int64, len(users)) userIDs := make([]int64, len(users))
userMap := make(map[int64]*model.User, len(users)) userMap := make(map[int64]*service.User, len(users))
outUsers := make([]service.User, 0, len(users))
for i := range users { for i := range users {
userIDs[i] = users[i].ID userIDs[i] = users[i].ID
userMap[users[i].ID] = &users[i] u := userModelToService(&users[i])
outUsers = append(outUsers, *u)
userMap[u.ID] = &outUsers[len(outUsers)-1]
} }
// Query active subscriptions with groups in one query // Query active subscriptions with groups in one query
var subscriptions []model.UserSubscription var subscriptions []userSubscriptionModel
if err := r.db.WithContext(ctx). if err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id IN ? AND status = ?", userIDs, model.SubscriptionStatusActive). Where("user_id IN ? AND status = ?", userIDs, service.SubscriptionStatusActive).
Find(&subscriptions).Error; err != nil { Find(&subscriptions).Error; err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -107,32 +119,29 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// Associate subscriptions with users // Associate subscriptions with users
for i := range subscriptions { for i := range subscriptions {
if user, ok := userMap[subscriptions[i].UserID]; ok { if user, ok := userMap[subscriptions[i].UserID]; ok {
user.Subscriptions = append(user.Subscriptions, subscriptions[i]) user.Subscriptions = append(user.Subscriptions, *userSubscriptionModelToService(&subscriptions[i]))
}
} }
} }
pages := int(total) / params.Limit() return outUsers, paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
} }
return users, &pagination.PaginationResult{ outUsers := make([]service.User, 0, len(users))
Total: total, for i := range users {
Page: params.Page, outUsers = append(outUsers, *userModelToService(&users[i]))
PageSize: params.Limit(), }
Pages: pages,
}, nil return outUsers, paginationResultFromTotal(total, params), nil
} }
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
Update("balance", gorm.Expr("balance + ?", amount)).Error Update("balance", gorm.Expr("balance + ?", amount)).Error
} }
// DeductBalance 扣减用户余额,仅当余额充足时执行 // DeductBalance 扣减用户余额,仅当余额充足时执行
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
result := r.db.WithContext(ctx).Model(&model.User{}). result := r.db.WithContext(ctx).Model(&userModel{}).
Where("id = ? AND balance >= ?", id, amount). Where("id = ? AND balance >= ?", id, amount).
Update("balance", gorm.Expr("balance - ?", amount)) Update("balance", gorm.Expr("balance - ?", amount))
if result.Error != nil { if result.Error != nil {
@@ -145,34 +154,104 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo
} }
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error { func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
} }
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error err := r.db.WithContext(ctx).Model(&userModel{}).Where("email = ?", email).Count(&count).Error
return count > 0, err return count > 0, err
} }
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID // RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数 // 使用 PostgreSQL 的 array_remove 函数
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.User{}). result := r.db.WithContext(ctx).Model(&userModel{}).
Where("? = ANY(allowed_groups)", groupID). Where("? = ANY(allowed_groups)", groupID).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID)) Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证) // GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) { func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
var user model.User var m userModel
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive). Where("role = ? AND status = ?", service.RoleAdmin, service.StatusActive).
Order("id ASC"). Order("id ASC").
First(&user).Error First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return userModelToService(&m), nil
}
type userModel struct {
ID int64 `gorm:"primaryKey"`
Email string `gorm:"uniqueIndex;size:255;not null"`
Username string `gorm:"size:100;default:''"`
Wechat string `gorm:"size:100;default:''"`
Notes string `gorm:"type:text;default:''"`
PasswordHash string `gorm:"size:255;not null"`
Role string `gorm:"size:20;default:user;not null"`
Balance float64 `gorm:"type:decimal(20,8);default:0;not null"`
Concurrency int `gorm:"default:5;not null"`
Status string `gorm:"size:20;default:active;not null"`
AllowedGroups pq.Int64Array `gorm:"type:bigint[]"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (userModel) TableName() string { return "users" }
func userModelToService(m *userModel) *service.User {
if m == nil {
return nil
}
return &service.User{
ID: m.ID,
Email: m.Email,
Username: m.Username,
Wechat: m.Wechat,
Notes: m.Notes,
PasswordHash: m.PasswordHash,
Role: m.Role,
Balance: m.Balance,
Concurrency: m.Concurrency,
Status: m.Status,
AllowedGroups: []int64(m.AllowedGroups),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func userModelFromService(u *service.User) *userModel {
if u == nil {
return nil
}
return &userModel{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: pq.Int64Array(u.AllowedGroups),
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func applyUserModelToService(dst *service.User, src *userModel) {
if dst == nil || src == nil {
return
}
dst.ID = src.ID
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
} }

View File

@@ -7,7 +7,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq" "github.com/lib/pq"
@@ -35,11 +34,12 @@ func TestUserRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByEmail / Update / Delete --- // --- Create / GetByID / GetByEmail / Update / Delete ---
func (s *UserRepoSuite) TestCreate() { func (s *UserRepoSuite) TestCreate() {
user := &model.User{ user := &service.User{
Email: "create@test.com", Email: "create@test.com",
Username: "testuser", Username: "testuser",
Role: model.RoleUser, PasswordHash: "test-password-hash",
Status: model.StatusActive, Role: service.RoleUser,
Status: service.StatusActive,
} }
err := s.repo.Create(s.ctx, user) err := s.repo.Create(s.ctx, user)
@@ -57,7 +57,7 @@ func (s *UserRepoSuite) TestGetByID_NotFound() {
} }
func (s *UserRepoSuite) TestGetByEmail() { func (s *UserRepoSuite) TestGetByEmail() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byemail@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "byemail@test.com"})
got, err := s.repo.GetByEmail(s.ctx, user.Email) got, err := s.repo.GetByEmail(s.ctx, user.Email)
s.Require().NoError(err, "GetByEmail") s.Require().NoError(err, "GetByEmail")
@@ -70,7 +70,7 @@ func (s *UserRepoSuite) TestGetByEmail_NotFound() {
} }
func (s *UserRepoSuite) TestUpdate() { func (s *UserRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com", Username: "original"}) user := userModelToService(mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com", Username: "original"}))
user.Username = "updated" user.Username = "updated"
err := s.repo.Update(s.ctx, user) err := s.repo.Update(s.ctx, user)
@@ -82,7 +82,7 @@ func (s *UserRepoSuite) TestUpdate() {
} }
func (s *UserRepoSuite) TestDelete() { func (s *UserRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
err := s.repo.Delete(s.ctx, user.ID) err := s.repo.Delete(s.ctx, user.ID)
s.Require().NoError(err, "Delete") s.Require().NoError(err, "Delete")
@@ -94,8 +94,8 @@ func (s *UserRepoSuite) TestDelete() {
// --- List / ListWithFilters --- // --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() { func (s *UserRepoSuite) TestList() {
mustCreateUser(s.T(), s.db, &model.User{Email: "list1@test.com"}) mustCreateUser(s.T(), s.db, &userModel{Email: "list1@test.com"})
mustCreateUser(s.T(), s.db, &model.User{Email: "list2@test.com"}) mustCreateUser(s.T(), s.db, &userModel{Email: "list2@test.com"})
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List") s.Require().NoError(err, "List")
@@ -104,28 +104,28 @@ func (s *UserRepoSuite) TestList() {
} }
func (s *UserRepoSuite) TestListWithFilters_Status() { func (s *UserRepoSuite) TestListWithFilters_Status() {
mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com", Status: model.StatusActive}) mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com", Status: service.StatusActive})
mustCreateUser(s.T(), s.db, &model.User{Email: "disabled@test.com", Status: model.StatusDisabled}) mustCreateUser(s.T(), s.db, &userModel{Email: "disabled@test.com", Status: service.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, "", "") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Equal(model.StatusActive, users[0].Status) s.Require().Equal(service.StatusActive, users[0].Status)
} }
func (s *UserRepoSuite) TestListWithFilters_Role() { func (s *UserRepoSuite) TestListWithFilters_Role() {
mustCreateUser(s.T(), s.db, &model.User{Email: "user@test.com", Role: model.RoleUser}) mustCreateUser(s.T(), s.db, &userModel{Email: "user@test.com", Role: service.RoleUser})
mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin}) mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.RoleAdmin, "") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "")
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(users, 1) s.Require().Len(users, 1)
s.Require().Equal(model.RoleAdmin, users[0].Role) s.Require().Equal(service.RoleAdmin, users[0].Role)
} }
func (s *UserRepoSuite) TestListWithFilters_Search() { func (s *UserRepoSuite) TestListWithFilters_Search() {
mustCreateUser(s.T(), s.db, &model.User{Email: "alice@test.com", Username: "Alice"}) mustCreateUser(s.T(), s.db, &userModel{Email: "alice@test.com", Username: "Alice"})
mustCreateUser(s.T(), s.db, &model.User{Email: "bob@test.com", Username: "Bob"}) mustCreateUser(s.T(), s.db, &userModel{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
s.Require().NoError(err) s.Require().NoError(err)
@@ -134,8 +134,8 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
} }
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() { func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com", Username: "JohnDoe"}) mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com", Username: "JohnDoe"})
mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com", Username: "JaneSmith"}) mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
s.Require().NoError(err) s.Require().NoError(err)
@@ -144,8 +144,8 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
} }
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() { func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
mustCreateUser(s.T(), s.db, &model.User{Email: "w1@test.com", Wechat: "wx_hello"}) mustCreateUser(s.T(), s.db, &userModel{Email: "w1@test.com", Wechat: "wx_hello"})
mustCreateUser(s.T(), s.db, &model.User{Email: "w2@test.com", Wechat: "wx_world"}) mustCreateUser(s.T(), s.db, &userModel{Email: "w2@test.com", Wechat: "wx_world"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello") users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
s.Require().NoError(err) s.Require().NoError(err)
@@ -154,19 +154,19 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
} }
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() { func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub@test.com", Status: model.StatusActive}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub@test.com", Status: service.StatusActive})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sub"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sub"})
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ _ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(1 * time.Hour), ExpiresAt: time.Now().Add(1 * time.Hour),
}) })
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ _ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-1 * time.Hour), ExpiresAt: time.Now().Add(-1 * time.Hour),
}) })
@@ -179,29 +179,29 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
} }
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() { func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "a@example.com", Email: "a@example.com",
Username: "Alice", Username: "Alice",
Wechat: "wx_a", Wechat: "wx_a",
Role: model.RoleUser, Role: service.RoleUser,
Status: model.StatusActive, Status: service.StatusActive,
Balance: 10, Balance: 10,
}) })
target := mustCreateUser(s.T(), s.db, &model.User{ target := mustCreateUser(s.T(), s.db, &userModel{
Email: "b@example.com", Email: "b@example.com",
Username: "Bob", Username: "Bob",
Wechat: "wx_b", Wechat: "wx_b",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
Balance: 1, Balance: 1,
}) })
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "c@example.com", Email: "c@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusDisabled, Status: service.StatusDisabled,
}) })
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, model.RoleAdmin, "b@") users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, service.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch") s.Require().Len(users, 1, "ListWithFilters len mismatch")
@@ -211,7 +211,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
// --- Balance operations --- // --- Balance operations ---
func (s *UserRepoSuite) TestUpdateBalance() { func (s *UserRepoSuite) TestUpdateBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "bal@test.com", Balance: 10}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "bal@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5) err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
s.Require().NoError(err, "UpdateBalance") s.Require().NoError(err, "UpdateBalance")
@@ -222,7 +222,7 @@ func (s *UserRepoSuite) TestUpdateBalance() {
} }
func (s *UserRepoSuite) TestUpdateBalance_Negative() { func (s *UserRepoSuite) TestUpdateBalance_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "balneg@test.com", Balance: 10}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "balneg@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, -3) err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
s.Require().NoError(err, "UpdateBalance with negative") s.Require().NoError(err, "UpdateBalance with negative")
@@ -233,7 +233,7 @@ func (s *UserRepoSuite) TestUpdateBalance_Negative() {
} }
func (s *UserRepoSuite) TestDeductBalance() { func (s *UserRepoSuite) TestDeductBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deduct@test.com", Balance: 10}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "deduct@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 5) err := s.repo.DeductBalance(s.ctx, user.ID, 5)
s.Require().NoError(err, "DeductBalance") s.Require().NoError(err, "DeductBalance")
@@ -244,7 +244,7 @@ func (s *UserRepoSuite) TestDeductBalance() {
} }
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() { func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "insuf@test.com", Balance: 5}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "insuf@test.com", Balance: 5})
err := s.repo.DeductBalance(s.ctx, user.ID, 999) err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance") s.Require().Error(err, "expected error for insufficient balance")
@@ -252,7 +252,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
} }
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exact@test.com", Balance: 10}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "exact@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 10) err := s.repo.DeductBalance(s.ctx, user.ID, 10)
s.Require().NoError(err, "DeductBalance exact amount") s.Require().NoError(err, "DeductBalance exact amount")
@@ -265,7 +265,7 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
// --- Concurrency --- // --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() { func (s *UserRepoSuite) TestUpdateConcurrency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "conc@test.com", Concurrency: 5}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "conc@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3) err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
s.Require().NoError(err, "UpdateConcurrency") s.Require().NoError(err, "UpdateConcurrency")
@@ -276,7 +276,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency() {
} }
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() { func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "concneg@test.com", Concurrency: 5}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "concneg@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2) err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
s.Require().NoError(err, "UpdateConcurrency negative") s.Require().NoError(err, "UpdateConcurrency negative")
@@ -289,7 +289,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
// --- ExistsByEmail --- // --- ExistsByEmail ---
func (s *UserRepoSuite) TestExistsByEmail() { func (s *UserRepoSuite) TestExistsByEmail() {
mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com") exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
s.Require().NoError(err, "ExistsByEmail") s.Require().NoError(err, "ExistsByEmail")
@@ -304,11 +304,11 @@ func (s *UserRepoSuite) TestExistsByEmail() {
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() { func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
groupID := int64(42) groupID := int64(42)
userA := mustCreateUser(s.T(), s.db, &model.User{ userA := mustCreateUser(s.T(), s.db, &userModel{
Email: "a1@example.com", Email: "a1@example.com",
AllowedGroups: pq.Int64Array{groupID, 7}, AllowedGroups: pq.Int64Array{groupID, 7},
}) })
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "a2@example.com", Email: "a2@example.com",
AllowedGroups: pq.Int64Array{7}, AllowedGroups: pq.Int64Array{7},
}) })
@@ -325,7 +325,7 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
} }
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() { func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "nomatch@test.com", Email: "nomatch@test.com",
AllowedGroups: pq.Int64Array{1, 2, 3}, AllowedGroups: pq.Int64Array{1, 2, 3},
}) })
@@ -338,15 +338,15 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
// --- GetFirstAdmin --- // --- GetFirstAdmin ---
func (s *UserRepoSuite) TestGetFirstAdmin() { func (s *UserRepoSuite) TestGetFirstAdmin() {
admin1 := mustCreateUser(s.T(), s.db, &model.User{ admin1 := mustCreateUser(s.T(), s.db, &userModel{
Email: "admin1@example.com", Email: "admin1@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
}) })
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "admin2@example.com", Email: "admin2@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
}) })
got, err := s.repo.GetFirstAdmin(s.ctx) got, err := s.repo.GetFirstAdmin(s.ctx)
@@ -355,10 +355,10 @@ func (s *UserRepoSuite) TestGetFirstAdmin() {
} }
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() { func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "user@example.com", Email: "user@example.com",
Role: model.RoleUser, Role: service.RoleUser,
Status: model.StatusActive, Status: service.StatusActive,
}) })
_, err := s.repo.GetFirstAdmin(s.ctx) _, err := s.repo.GetFirstAdmin(s.ctx)
@@ -366,15 +366,15 @@ func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
} }
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() { func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
mustCreateUser(s.T(), s.db, &model.User{ mustCreateUser(s.T(), s.db, &userModel{
Email: "disabled@example.com", Email: "disabled@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusDisabled, Status: service.StatusDisabled,
}) })
activeAdmin := mustCreateUser(s.T(), s.db, &model.User{ activeAdmin := mustCreateUser(s.T(), s.db, &userModel{
Email: "active@example.com", Email: "active@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
}) })
got, err := s.repo.GetFirstAdmin(s.ctx) got, err := s.repo.GetFirstAdmin(s.ctx)
@@ -385,26 +385,26 @@ func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
// --- Combined original test --- // --- Combined original test ---
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := mustCreateUser(s.T(), s.db, &model.User{ user1 := mustCreateUser(s.T(), s.db, &userModel{
Email: "a@example.com", Email: "a@example.com",
Username: "Alice", Username: "Alice",
Wechat: "wx_a", Wechat: "wx_a",
Role: model.RoleUser, Role: service.RoleUser,
Status: model.StatusActive, Status: service.StatusActive,
Balance: 10, Balance: 10,
}) })
user2 := mustCreateUser(s.T(), s.db, &model.User{ user2 := mustCreateUser(s.T(), s.db, &userModel{
Email: "b@example.com", Email: "b@example.com",
Username: "Bob", Username: "Bob",
Wechat: "wx_b", Wechat: "wx_b",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusActive, Status: service.StatusActive,
Balance: 1, Balance: 1,
}) })
_ = mustCreateUser(s.T(), s.db, &model.User{ _ = mustCreateUser(s.T(), s.db, &userModel{
Email: "c@example.com", Email: "c@example.com",
Role: model.RoleAdmin, Role: service.RoleAdmin,
Status: model.StatusDisabled, Status: service.StatusDisabled,
}) })
got, err := s.repo.GetByID(s.ctx, user1.ID) got, err := s.repo.GetByID(s.ctx, user1.ID)
@@ -441,7 +441,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch") s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch")
params := pagination.PaginationParams{Page: 1, PageSize: 10} params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, model.StatusActive, model.RoleAdmin, "b@") users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters") s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch") s.Require().Len(users, 1, "ListWithFilters len mismatch")

View File

@@ -4,111 +4,113 @@ import (
"context" "context"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm" "gorm.io/gorm"
) )
// UserSubscriptionRepository 用户订阅仓库
type userSubscriptionRepository struct { type userSubscriptionRepository struct {
db *gorm.DB db *gorm.DB
} }
// NewUserSubscriptionRepository 创建用户订阅仓库
func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository { func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
return &userSubscriptionRepository{db: db} return &userSubscriptionRepository{db: db}
} }
// Create 创建订阅 func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error { m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Create(sub).Error err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUserSubscriptionModelToService(sub, m)
}
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists) return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
} }
// GetByID 根据ID获取订阅 func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) { var m userSubscriptionModel
var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("User"). Preload("User").
Preload("Group"). Preload("Group").
Preload("AssignedByUser"). Preload("AssignedByUser").
First(&sub, id).Error First(&m, id).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return userSubscriptionModelToService(&m), nil
} }
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅 func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { var m userSubscriptionModel
var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ? AND group_id = ?", userID, groupID). Where("user_id = ? AND group_id = ?", userID, groupID).
First(&sub).Error First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return userSubscriptionModelToService(&m), nil
} }
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅 func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { var m userSubscriptionModel
var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?", Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?",
userID, groupID, model.SubscriptionStatusActive, time.Now()). userID, groupID, service.SubscriptionStatusActive, time.Now()).
First(&sub).Error First(&m).Error
if err != nil { if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return userSubscriptionModelToService(&m), nil
} }
// Update 更新订阅 func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
sub.UpdatedAt = time.Now() sub.UpdatedAt = time.Now()
return r.db.WithContext(ctx).Save(sub).Error m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyUserSubscriptionModelToService(sub, m)
}
return err
} }
// Delete 删除订阅
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error { func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error return r.db.WithContext(ctx).Delete(&userSubscriptionModel{}, id).Error
} }
// ListByUserID 获取用户的所有订阅 func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { var subs []userSubscriptionModel
var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ?", userID). Where("user_id = ?", userID).
Order("created_at DESC"). Order("created_at DESC").
Find(&subs).Error Find(&subs).Error
return subs, err if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
} }
// ListActiveByUserID 获取用户的所有有效订阅 func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { var subs []userSubscriptionModel
var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ? AND status = ? AND expires_at > ?", Where("user_id = ? AND status = ? AND expires_at > ?",
userID, model.SubscriptionStatusActive, time.Now()). userID, service.SubscriptionStatusActive, time.Now()).
Order("created_at DESC"). Order("created_at DESC").
Find(&subs).Error Find(&subs).Error
return subs, err if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
} }
// ListByGroupID 获取分组的所有订阅(分页) func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) { var subs []userSubscriptionModel
var subs []model.UserSubscription
var total int64 var total int64
query := r.db.WithContext(ctx).Model(&model.UserSubscription{}).Where("group_id = ?", groupID) query := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).Where("group_id = ?", groupID)
if err := query.Count(&total).Error; err != nil { if err := query.Count(&total).Error; err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -124,26 +126,14 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
} }
return subs, &pagination.PaginationResult{ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
Total: total, var subs []userSubscriptionModel
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
}
// List 获取所有订阅(分页,支持筛选)
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription
var total int64 var total int64
query := r.db.WithContext(ctx).Model(&model.UserSubscription{}) query := r.db.WithContext(ctx).Model(&userSubscriptionModel{})
if userID != nil { if userID != nil {
query = query.Where("user_id = ?", *userID) query = query.Where("user_id = ?", *userID)
} }
@@ -170,22 +160,87 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return nil, nil, err return nil, nil, err
} }
pages := int(total) / params.Limit() return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
if int(total)%params.Limit() > 0 {
pages++
} }
return subs, &pagination.PaginationResult{ func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
Total: total, var count int64
Page: params.Page, err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
PageSize: params.Limit(), Where("user_id = ? AND group_id = ?", userID, groupID).
Pages: pages, Count(&count).Error
}, nil return count > 0, err
}
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"expires_at": newExpiresAt,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"status": status,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"notes": notes,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_window_start": start,
"weekly_window_start": start,
"monthly_window_start": start,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_usage_usd": 0,
"daily_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"weekly_usage_usd": 0,
"weekly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"monthly_usage_usd": 0,
"monthly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
} }
// IncrementUsage 增加使用量
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD), "daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
@@ -195,131 +250,150 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
}).Error }).Error
} }
// ResetDailyUsage 重置日使用量
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_usage_usd": 0,
"daily_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// ResetWeeklyUsage 重置周使用量
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"weekly_usage_usd": 0,
"weekly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// ResetMonthlyUsage 重置月使用量
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"monthly_usage_usd": 0,
"monthly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// ActivateWindows 激活所有窗口(首次使用时)
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_window_start": activateTime,
"weekly_window_start": activateTime,
"monthly_window_start": activateTime,
"updated_at": time.Now(),
}).Error
}
// UpdateStatus 更新订阅状态
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"status": status,
"updated_at": time.Now(),
}).Error
}
// ExtendExpiry 延长订阅过期时间
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"expires_at": newExpiresAt,
"updated_at": time.Now(),
}).Error
}
// UpdateNotes 更新订阅备注
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"notes": notes,
"updated_at": time.Now(),
}).Error
}
// ListExpired 获取所有已过期但状态仍为active的订阅
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Find(&subs).Error
return subs, err
}
// BatchUpdateExpiredStatus 批量更新过期订阅状态
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}). result := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
Updates(map[string]any{ Updates(map[string]any{
"status": model.SubscriptionStatusExpired, "status": service.SubscriptionStatusExpired,
"updated_at": time.Now(), "updated_at": time.Now(),
}) })
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅 // Extra repository helpers (currently used only by integration tests).
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64 func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). var subs []userSubscriptionModel
Where("user_id = ? AND group_id = ?", userID, groupID). err := r.db.WithContext(ctx).
Count(&count).Error Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
return count > 0, err Find(&subs).Error
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
} }
// CountByGroupID 获取分组的订阅数量
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ?", groupID). Where("group_id = ?", groupID).
Count(&count).Error Count(&count).Error
return count, err return count, err
} }
// CountActiveByGroupID 获取分组的有效订阅数量
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ? AND status = ? AND expires_at > ?", Where("group_id = ? AND status = ? AND expires_at > ?",
groupID, model.SubscriptionStatusActive, time.Now()). groupID, service.SubscriptionStatusActive, time.Now()).
Count(&count).Error Count(&count).Error
return count, err return count, err
} }
// DeleteByGroupID 删除分组相关的所有订阅记录
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{}) result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&userSubscriptionModel{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
type userSubscriptionModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
GroupID int64 `gorm:"index;not null"`
StartsAt time.Time `gorm:"not null"`
ExpiresAt time.Time `gorm:"not null"`
Status string `gorm:"size:20;default:active;not null"`
DailyWindowStart *time.Time
WeeklyWindowStart *time.Time
MonthlyWindowStart *time.Time
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
AssignedBy *int64 `gorm:"index"`
AssignedAt time.Time `gorm:"not null"`
Notes string `gorm:"type:text"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
AssignedByUser *userModel `gorm:"foreignKey:AssignedBy"`
}
func (userSubscriptionModel) TableName() string { return "user_subscriptions" }
func userSubscriptionModelToService(m *userSubscriptionModel) *service.UserSubscription {
if m == nil {
return nil
}
return &service.UserSubscription{
ID: m.ID,
UserID: m.UserID,
GroupID: m.GroupID,
StartsAt: m.StartsAt,
ExpiresAt: m.ExpiresAt,
Status: m.Status,
DailyWindowStart: m.DailyWindowStart,
WeeklyWindowStart: m.WeeklyWindowStart,
MonthlyWindowStart: m.MonthlyWindowStart,
DailyUsageUSD: m.DailyUsageUSD,
WeeklyUsageUSD: m.WeeklyUsageUSD,
MonthlyUsageUSD: m.MonthlyUsageUSD,
AssignedBy: m.AssignedBy,
AssignedAt: m.AssignedAt,
Notes: m.Notes,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
AssignedByUser: userModelToService(m.AssignedByUser),
}
}
func userSubscriptionModelsToService(models []userSubscriptionModel) []service.UserSubscription {
out := make([]service.UserSubscription, 0, len(models))
for i := range models {
if s := userSubscriptionModelToService(&models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func userSubscriptionModelFromService(s *service.UserSubscription) *userSubscriptionModel {
if s == nil {
return nil
}
return &userSubscriptionModel{
ID: s.ID,
UserID: s.UserID,
GroupID: s.GroupID,
StartsAt: s.StartsAt,
ExpiresAt: s.ExpiresAt,
Status: s.Status,
DailyWindowStart: s.DailyWindowStart,
WeeklyWindowStart: s.WeeklyWindowStart,
MonthlyWindowStart: s.MonthlyWindowStart,
DailyUsageUSD: s.DailyUsageUSD,
WeeklyUsageUSD: s.WeeklyUsageUSD,
MonthlyUsageUSD: s.MonthlyUsageUSD,
AssignedBy: s.AssignedBy,
AssignedAt: s.AssignedAt,
Notes: s.Notes,
CreatedAt: s.CreatedAt,
UpdatedAt: s.UpdatedAt,
}
}
func applyUserSubscriptionModelToService(sub *service.UserSubscription, m *userSubscriptionModel) {
if sub == nil || m == nil {
return
}
sub.ID = m.ID
sub.CreatedAt = m.CreatedAt
sub.UpdatedAt = m.UpdatedAt
}

View File

@@ -7,8 +7,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -33,13 +33,13 @@ func TestUserSubscriptionRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete --- // --- Create / GetByID / Update / Delete ---
func (s *UserSubscriptionRepoSuite) TestCreate() { func (s *UserSubscriptionRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub-create@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub-create@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-create"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-create"})
sub := &model.UserSubscription{ sub := &service.UserSubscription{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
} }
@@ -54,14 +54,14 @@ func (s *UserSubscriptionRepoSuite) TestCreate() {
} }
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() { func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "preload@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "preload@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
admin := mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin}) admin := mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
AssignedBy: &admin.ID, AssignedBy: &admin.ID,
}) })
@@ -82,14 +82,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
} }
func (s *UserSubscriptionRepoSuite) TestUpdate() { func (s *UserSubscriptionRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-update"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-update"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := userSubscriptionModelToService(mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) }))
sub.Notes = "updated notes" sub.Notes = "updated notes"
err := s.repo.Update(s.ctx, sub) err := s.repo.Update(s.ctx, sub)
@@ -101,12 +101,12 @@ func (s *UserSubscriptionRepoSuite) TestUpdate() {
} }
func (s *UserSubscriptionRepoSuite) TestDelete() { func (s *UserSubscriptionRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delete"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delete"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
@@ -120,12 +120,12 @@ func (s *UserSubscriptionRepoSuite) TestDelete() {
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID --- // --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() { func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byuser@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "byuser@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-byuser"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-byuser"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
@@ -141,14 +141,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
} }
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() { func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-active"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-active"})
// Create active subscription (future expiry) // Create active subscription (future expiry)
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour), ExpiresAt: time.Now().Add(2 * time.Hour),
}) })
@@ -158,14 +158,14 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
} }
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() { func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "expired@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "expired@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-expired"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-expired"})
// Create expired subscription (past expiry but active status) // Create expired subscription (past expiry but active status)
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour), ExpiresAt: time.Now().Add(-2 * time.Hour),
}) })
@@ -176,20 +176,20 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnor
// --- ListByUserID / ListActiveByUserID --- // --- ListByUserID / ListActiveByUserID ---
func (s *UserSubscriptionRepoSuite) TestListByUserID() { func (s *UserSubscriptionRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list1"}) g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list2"}) g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g1.ID, GroupID: g1.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g2.ID, GroupID: g2.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
@@ -202,46 +202,46 @@ func (s *UserSubscriptionRepoSuite) TestListByUserID() {
} }
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() { func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listactive@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listactive@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act1"}) g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act2"}) g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g1.ID, GroupID: g1.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g2.ID, GroupID: g2.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID) subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
s.Require().NoError(err, "ListActiveByUserID") s.Require().NoError(err, "ListActiveByUserID")
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusActive, subs[0].Status) s.Require().Equal(service.SubscriptionStatusActive, subs[0].Status)
} }
// --- ListByGroupID --- // --- ListByGroupID ---
func (s *UserSubscriptionRepoSuite) TestListByGroupID() { func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listgrp"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID, UserID: user1.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID, UserID: user2.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
@@ -258,13 +258,13 @@ func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
// --- List with filters --- // --- List with filters ---
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "list@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "list@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
@@ -275,20 +275,20 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
} }
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-filter"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-filter"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID, UserID: user1.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID, UserID: user2.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
@@ -299,20 +299,20 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
} }
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grpfilter@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "grpfilter@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f1"}) g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f2"}) g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g1.ID, GroupID: g1.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: g2.ID, GroupID: g2.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
@@ -323,37 +323,37 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
} }
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() { func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "statfilter@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "statfilter@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-stat"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-stat"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, model.SubscriptionStatusExpired) subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Len(subs, 1) s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusExpired, subs[0].Status) s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
} }
// --- Usage tracking --- // --- Usage tracking ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() { func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usage@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "usage@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-usage"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-usage"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
@@ -368,12 +368,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
} }
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() { func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accum@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "accum@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-accum"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-accum"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
@@ -386,12 +386,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
} }
func (s *UserSubscriptionRepoSuite) TestActivateWindows() { func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "activate@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "activate@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-activate"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-activate"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
@@ -408,12 +408,12 @@ func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
} }
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() { func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetd@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetd@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetd"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetd"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
DailyUsageUSD: 10.0, DailyUsageUSD: 10.0,
WeeklyUsageUSD: 20.0, WeeklyUsageUSD: 20.0,
@@ -431,12 +431,12 @@ func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
} }
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() { func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetw@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetw@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetw"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetw"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
WeeklyUsageUSD: 15.0, WeeklyUsageUSD: 15.0,
MonthlyUsageUSD: 30.0, MonthlyUsageUSD: 30.0,
@@ -454,12 +454,12 @@ func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
} }
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() { func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetm@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetm@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetm"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetm"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
MonthlyUsageUSD: 100.0, MonthlyUsageUSD: 100.0,
}) })
@@ -477,30 +477,30 @@ func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
// --- UpdateStatus / ExtendExpiry / UpdateNotes --- // --- UpdateStatus / ExtendExpiry / UpdateNotes ---
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() { func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "status@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "status@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-status"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-status"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
err := s.repo.UpdateStatus(s.ctx, sub.ID, model.SubscriptionStatusExpired) err := s.repo.UpdateStatus(s.ctx, sub.ID, service.SubscriptionStatusExpired)
s.Require().NoError(err, "UpdateStatus") s.Require().NoError(err, "UpdateStatus")
got, err := s.repo.GetByID(s.ctx, sub.ID) got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(model.SubscriptionStatusExpired, got.Status) s.Require().Equal(service.SubscriptionStatusExpired, got.Status)
} }
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() { func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "extend@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "extend@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-extend"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-extend"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
@@ -514,12 +514,12 @@ func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
} }
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() { func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "notes@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "notes@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-notes"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-notes"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
@@ -534,19 +534,19 @@ func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
// --- ListExpired / BatchUpdateExpiredStatus --- // --- ListExpired / BatchUpdateExpiredStatus ---
func (s *UserSubscriptionRepoSuite) TestListExpired() { func (s *UserSubscriptionRepoSuite) TestListExpired() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listexp@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "listexp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listexp"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listexp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
@@ -556,19 +556,19 @@ func (s *UserSubscriptionRepoSuite) TestListExpired() {
} }
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() { func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batch@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "batch@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-batch"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-batch"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
@@ -577,22 +577,22 @@ func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
s.Require().Equal(int64(1), affected) s.Require().Equal(int64(1), affected)
gotActive, _ := s.repo.GetByID(s.ctx, active.ID) gotActive, _ := s.repo.GetByID(s.ctx, active.ID)
s.Require().Equal(model.SubscriptionStatusActive, gotActive.Status) s.Require().Equal(service.SubscriptionStatusActive, gotActive.Status)
gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID) gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().Equal(model.SubscriptionStatusExpired, gotExpired.Status) s.Require().Equal(service.SubscriptionStatusExpired, gotExpired.Status)
} }
// --- ExistsByUserIDAndGroupID --- // --- ExistsByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() { func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-exists"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-exists"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
@@ -608,20 +608,20 @@ func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
// --- CountByGroupID / CountActiveByGroupID --- // --- CountByGroupID / CountActiveByGroupID ---
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() { func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID, UserID: user1.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID, UserID: user2.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
@@ -631,20 +631,20 @@ func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
} }
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() { func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact1@test.com"}) user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact2@test.com"}) user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-cntact"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-cntact"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID, UserID: user1.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID, UserID: user2.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time
}) })
@@ -656,19 +656,19 @@ func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
// --- DeleteByGroupID --- // --- DeleteByGroupID ---
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() { func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delgrp@test.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "delgrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delgrp"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour), ExpiresAt: time.Now().Add(24 * time.Hour),
}) })
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusExpired, Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour), ExpiresAt: time.Now().Add(-24 * time.Hour),
}) })
@@ -683,19 +683,19 @@ func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
// --- Combined original test --- // --- Combined original test ---
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() { func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "subr@example.com"}) user := mustCreateUser(s.T(), s.db, &userModel{Email: "subr@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-subr"}) group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-subr"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour), ExpiresAt: time.Now().Add(2 * time.Hour),
}) })
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID, UserID: user.ID,
GroupID: group.ID, GroupID: group.ID,
Status: model.SubscriptionStatusActive, Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour), ExpiresAt: time.Now().Add(-2 * time.Hour),
}) })
@@ -729,5 +729,5 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().Equal(int64(1), affected, "expected 1 affected row") s.Require().Equal(int64(1), affected, "expected 1 affected row")
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID) updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().NoError(err, "GetByID expired") s.Require().NoError(err, "GetByID expired")
s.Require().Equal(model.SubscriptionStatusExpired, updated.Status, "expected status expired") s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
} }

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,6 @@ import (
"errors" "errors"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -84,7 +83,11 @@ func validateAdminApiKey(
return false return false
} }
c.Set(string(ContextKeyUser), admin) c.Set(string(ContextKeyUser), AuthSubject{
UserID: admin.ID,
Concurrency: admin.Concurrency,
})
c.Set(string(ContextKeyUserRole), admin.Role)
c.Set("auth_method", "admin_api_key") c.Set("auth_method", "admin_api_key")
return true return true
} }
@@ -121,12 +124,16 @@ func validateJWTForAdmin(
} }
// 检查管理员权限 // 检查管理员权限
if user.Role != model.RoleAdmin { if !user.IsAdmin() {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required") AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return false return false
} }
c.Set(string(ContextKeyUser), user) c.Set(string(ContextKeyUser), AuthSubject{
UserID: user.ID,
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
c.Set("auth_method", "jwt") c.Set("auth_method", "jwt")
return true return true

View File

@@ -1,7 +1,7 @@
package middleware package middleware
import ( import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -10,15 +10,14 @@ import (
// 必须在JWTAuth中间件之后使用 // 必须在JWTAuth中间件之后使用
func AdminOnly() gin.HandlerFunc { func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 从上下文获取用户 role, ok := GetUserRoleFromContext(c)
user, exists := GetUserFromContext(c) if !ok {
if !exists {
AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context") AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context")
return return
} }
// 检查是否为管理员 // 检查是否为管理员
if user.Role != model.RoleAdmin { if role != service.RoleAdmin {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required") AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return return
} }

View File

@@ -5,11 +5,9 @@ import (
"log" "log"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm"
) )
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件 // NewApiKeyAuthMiddleware 创建 API Key 认证中间件
@@ -61,7 +59,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 从数据库验证API key // 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, service.ErrApiKeyNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key") AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return return
} }
@@ -136,28 +134,32 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 将API key和用户信息存入上下文 // 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey) c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), apiKey.User) c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next() c.Next()
} }
} }
// GetApiKeyFromContext 从上下文中获取API key // GetApiKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*model.ApiKey, bool) { func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey)) value, exists := c.Get(string(ContextKeyApiKey))
if !exists { if !exists {
return nil, false return nil, false
} }
apiKey, ok := value.(*model.ApiKey) apiKey, ok := value.(*service.ApiKey)
return apiKey, ok return apiKey, ok
} }
// GetSubscriptionFromContext 从上下文中获取订阅信息 // GetSubscriptionFromContext 从上下文中获取订阅信息
func GetSubscriptionFromContext(c *gin.Context) (*model.UserSubscription, bool) { func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool) {
value, exists := c.Get(string(ContextKeySubscription)) value, exists := c.Get(string(ContextKeySubscription))
if !exists { if !exists {
return nil, false return nil, false
} }
subscription, ok := value.(*model.UserSubscription) subscription, ok := value.(*service.UserSubscription)
return subscription, ok return subscription, ok
} }

View File

@@ -0,0 +1,28 @@
package middleware
import "github.com/gin-gonic/gin"
// AuthSubject is the minimal authenticated identity stored in gin context.
// Decision: {UserID int64, Concurrency int}
type AuthSubject struct {
UserID int64
Concurrency int
}
func GetAuthSubjectFromContext(c *gin.Context) (AuthSubject, bool) {
value, exists := c.Get(string(ContextKeyUser))
if !exists {
return AuthSubject{}, false
}
subject, ok := value.(AuthSubject)
return subject, ok
}
func GetUserRoleFromContext(c *gin.Context) (string, bool) {
value, exists := c.Get(string(ContextKeyUserRole))
if !exists {
return "", false
}
role, ok := value.(string)
return role, ok
}

View File

@@ -4,7 +4,6 @@ import (
"errors" "errors"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -62,19 +61,14 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
return return
} }
// 将用户信息存入上下文 c.Set(string(ContextKeyUser), AuthSubject{
c.Set(string(ContextKeyUser), user) UserID: user.ID,
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
c.Next() c.Next()
} }
} }
// GetUserFromContext 从上下文中获取用户 // Deprecated: prefer GetAuthSubjectFromContext in auth_subject.go.
func GetUserFromContext(c *gin.Context) (*model.User, bool) {
value, exists := c.Get(string(ContextKeyUser))
if !exists {
return nil, false
}
user, ok := value.(*model.User)
return user, ok
}

View File

@@ -8,6 +8,8 @@ type ContextKey string
const ( const (
// ContextKeyUser 用户上下文键 // ContextKeyUser 用户上下文键
ContextKeyUser ContextKey = "user" ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色string
ContextKeyUserRole ContextKey = "user_role"
// ContextKeyApiKey API密钥上下文键 // ContextKeyApiKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key" ContextKeyApiKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键 // ContextKeySubscription 订阅上下文键

View File

@@ -0,0 +1,324 @@
package service
import "time"
type Account struct {
ID int64
Name string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
Status string
ErrorMessage string
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
Schedulable bool
RateLimitedAt *time.Time
RateLimitResetAt *time.Time
OverloadUntil *time.Time
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string
Proxy *Proxy
AccountGroups []AccountGroup
GroupIDs []int64
Groups []*Group
}
func (a *Account) IsActive() bool {
return a.Status == StatusActive
}
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
}
now := time.Now()
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false
}
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
return true
}
func (a *Account) IsRateLimited() bool {
if a.RateLimitResetAt == nil {
return false
}
return time.Now().Before(*a.RateLimitResetAt)
}
func (a *Account) IsOverloaded() bool {
if a.OverloadUntil == nil {
return false
}
return time.Now().Before(*a.OverloadUntil)
}
func (a *Account) IsOAuth() bool {
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
}
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}
func (a *Account) GetCredential(key string) string {
if a.Credentials == nil {
return ""
}
if v, ok := a.Credentials[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["model_mapping"]
if !ok || raw == nil {
return nil
}
if m, ok := raw.(map[string]any); ok {
result := make(map[string]string)
for k, v := range m {
if s, ok := v.(string); ok {
result[k] = s
}
}
if len(result) > 0 {
return result
}
}
return nil
}
func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
return true
}
_, exists := mapping[requestedModel]
return exists
}
func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
return requestedModel
}
if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel
}
return requestedModel
}
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey {
return ""
}
baseURL := a.GetCredential("base_url")
if baseURL == "" {
return "https://api.anthropic.com"
}
return baseURL
}
func (a *Account) GetExtraString(key string) string {
if a.Extra == nil {
return ""
}
if v, ok := a.Extra[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
func (a *Account) GetCustomErrorCodes() []int {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["custom_error_codes"]
if !ok || raw == nil {
return nil
}
if arr, ok := raw.([]any); ok {
result := make([]int, 0, len(arr))
for _, v := range arr {
if f, ok := v.(float64); ok {
result = append(result, int(f))
}
}
return result
}
return nil
}
func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
if !a.IsCustomErrorCodesEnabled() {
return true
}
codes := a.GetCustomErrorCodes()
if len(codes) == 0 {
return true
}
for _, code := range codes {
if code == statusCode {
return true
}
}
return false
}
func (a *Account) IsInterceptWarmupEnabled() bool {
if a.Credentials == nil {
return false
}
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
func (a *Account) IsOpenAI() bool {
return a.Platform == PlatformOpenAI
}
func (a *Account) IsAnthropic() bool {
return a.Platform == PlatformAnthropic
}
func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth
}
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
}
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeApiKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL
}
}
return "https://api.openai.com"
}
func (a *Account) GetOpenAIAccessToken() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("access_token")
}
func (a *Account) GetOpenAIRefreshToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("refresh_token")
}
func (a *Account) GetOpenAIIDToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("id_token")
}
func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIApiKey() {
return ""
}
return a.GetCredential("api_key")
}
func (a *Account) GetOpenAIUserAgent() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("user_agent")
}
func (a *Account) GetChatGPTAccountID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_account_id")
}
func (a *Account) GetChatGPTUserID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_user_id")
}
func (a *Account) GetOpenAIOrganizationID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("organization_id")
}
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if !a.IsOpenAIOAuth() {
return nil
}
expiresAtStr := a.GetCredential("expires_at")
if expiresAtStr == "" {
return nil
}
t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil {
if v, ok := a.Credentials["expires_at"].(float64); ok {
tt := time.Unix(int64(v), 0)
return &tt
}
return nil
}
return &t
}
func (a *Account) IsOpenAITokenExpired() bool {
expiresAt := a.GetOpenAITokenExpiresAt()
if expiresAt == nil {
return false
}
return time.Now().Add(60 * time.Second).After(*expiresAt)
}

View File

@@ -0,0 +1,13 @@
package service
import "time"
type AccountGroup struct {
AccountID int64
GroupID int64
Priority int
CreatedAt time.Time
Account *Account
Group *Group
}

View File

@@ -6,7 +6,6 @@ import (
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
@@ -15,29 +14,29 @@ var (
) )
type AccountRepository interface { type AccountRepository interface {
Create(ctx context.Context, account *model.Account) error Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, id int64) (*model.Account, error) GetByID(ctx context.Context, id int64) (*Account, error)
// GetByCRSAccountID finds an account previously synced from CRS. // GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found. // Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
Update(ctx context.Context, account *model.Account) error Update(ctx context.Context, account *Account) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]model.Account, error) ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error)
UpdateLastUsed(ctx context.Context, id int64) error UpdateLastUsed(ctx context.Context, id int64) error
SetError(ctx context.Context, id int64, errorMsg string) error SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error SetSchedulable(ctx context.Context, id int64, schedulable bool) error
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]model.Account, error) ListSchedulable(ctx context.Context) ([]Account, error)
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error
@@ -99,7 +98,7 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
} }
// Create 创建账号 // Create 创建账号
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*model.Account, error) { func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
// 验证分组是否存在(如果指定了分组) // 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 { if len(req.GroupIDs) > 0 {
for _, groupID := range req.GroupIDs { for _, groupID := range req.GroupIDs {
@@ -111,7 +110,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
} }
// 创建账号 // 创建账号
account := &model.Account{ account := &Account{
Name: req.Name, Name: req.Name,
Platform: req.Platform, Platform: req.Platform,
Type: req.Type, Type: req.Type,
@@ -120,7 +119,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
Concurrency: req.Concurrency, Concurrency: req.Concurrency,
Priority: req.Priority, Priority: req.Priority,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.accountRepo.Create(ctx, account); err != nil { if err := s.accountRepo.Create(ctx, account); err != nil {
@@ -138,7 +137,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
} }
// GetByID 根据ID获取账号 // GetByID 根据ID获取账号
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) { func (s *AccountService) GetByID(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account: %w", err) return nil, fmt.Errorf("get account: %w", err)
@@ -147,7 +146,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
} }
// List 获取账号列表 // List 获取账号列表
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
accounts, pagination, err := s.accountRepo.List(ctx, params) accounts, pagination, err := s.accountRepo.List(ctx, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list accounts: %w", err) return nil, nil, fmt.Errorf("list accounts: %w", err)
@@ -156,7 +155,7 @@ func (s *AccountService) List(ctx context.Context, params pagination.PaginationP
} }
// ListByPlatform 根据平台获取账号列表 // ListByPlatform 根据平台获取账号列表
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) { func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
accounts, err := s.accountRepo.ListByPlatform(ctx, platform) accounts, err := s.accountRepo.ListByPlatform(ctx, platform)
if err != nil { if err != nil {
return nil, fmt.Errorf("list accounts by platform: %w", err) return nil, fmt.Errorf("list accounts by platform: %w", err)
@@ -165,7 +164,7 @@ func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([
} }
// ListByGroup 根据分组获取账号列表 // ListByGroup 根据分组获取账号列表
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) { func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
accounts, err := s.accountRepo.ListByGroup(ctx, groupID) accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("list accounts by group: %w", err) return nil, fmt.Errorf("list accounts by group: %w", err)
@@ -174,7 +173,7 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode
} }
// Update 更新账号 // Update 更新账号
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) { func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account: %w", err) return nil, fmt.Errorf("get account: %w", err)
@@ -290,13 +289,13 @@ func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
// 根据平台执行不同的测试逻辑 // 根据平台执行不同的测试逻辑
switch account.Platform { switch account.Platform {
case model.PlatformAnthropic: case PlatformAnthropic:
// TODO: 测试Anthropic API凭证 // TODO: 测试Anthropic API凭证
return nil return nil
case model.PlatformOpenAI: case PlatformOpenAI:
// TODO: 测试OpenAI API凭证 // TODO: 测试OpenAI API凭证
return nil return nil
case model.PlatformGemini: case PlatformGemini:
// TODO: 测试Gemini API凭证 // TODO: 测试Gemini API凭证
return nil return nil
default: default:

View File

@@ -11,11 +11,11 @@ import (
"io" "io"
"log" "log"
"net/http" "net/http"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -23,6 +23,10 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
// sseDataPrefix matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
const ( const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages" testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
testOpenAIAPIURL = "https://api.openai.com/v1/responses" testOpenAIAPIURL = "https://api.openai.com/v1/responses"
@@ -141,7 +145,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
} }
// testClaudeAccountConnection tests an Anthropic Claude account's connection // testClaudeAccountConnection tests an Anthropic Claude account's connection
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *model.Account, modelID string) error { func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context() ctx := c.Request.Context()
// Determine the model to use // Determine the model to use
@@ -268,7 +272,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
} }
// testOpenAIAccountConnection tests an OpenAI account's connection // testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *model.Account, modelID string) error { func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context() ctx := c.Request.Context()
// Default to openai.DefaultTestModel for OpenAI testing // Default to openai.DefaultTestModel for OpenAI testing
@@ -667,11 +671,11 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
} }
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "data: ") { if line == "" || !sseDataPrefix.MatchString(line) {
continue continue
} }
jsonStr := strings.TrimPrefix(line, "data: ") jsonStr := sseDataPrefix.ReplaceAllString(line, "")
if jsonStr == "[DONE]" { if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil return nil
@@ -721,11 +725,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
} }
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "data: ") { if line == "" || !sseDataPrefix.MatchString(line) {
continue continue
} }
jsonStr := strings.TrimPrefix(line, "data: ") jsonStr := sseDataPrefix.ReplaceAllString(line, "")
if jsonStr == "[DONE]" { if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil return nil

View File

@@ -7,24 +7,23 @@ import (
"sync" "sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
) )
type UsageLogRepository interface { type UsageLogRepository interface {
Create(ctx context.Context, log *model.UsageLog) error Create(ctx context.Context, log *UsageLog) error
GetByID(ctx context.Context, id int64) (*model.UsageLog, error) GetByID(ctx context.Context, id int64) (*UsageLog, error)
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
@@ -44,7 +43,7 @@ type UsageLogRepository interface {
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)
// Admin usage listing/stats // Admin usage listing/stats
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error)
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
// Account stats // Account stats
@@ -163,7 +162,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
} }
// Setup Token账号根据session_window推算没有profile scope无法调用usage API // Setup Token账号根据session_window推算没有profile scope无法调用usage API
if account.Type == model.AccountTypeSetupToken { if account.Type == AccountTypeSetupToken {
usage := s.estimateSetupTokenUsage(account) usage := s.estimateSetupTokenUsage(account)
// 添加窗口统计 // 添加窗口统计
s.addWindowStats(ctx, account, usage) s.addWindowStats(ctx, account, usage)
@@ -175,7 +174,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
} }
// addWindowStats 为usage数据添加窗口期统计 // addWindowStats 为usage数据添加窗口期统计
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model.Account, usage *UsageInfo) { func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
if usage.FiveHour == nil { if usage.FiveHour == nil {
return return
} }
@@ -225,7 +224,7 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
} }
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量 // fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) { func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
if accessToken == "" { if accessToken == "" {
return nil, fmt.Errorf("no access token available") return nil, fmt.Errorf("no access token available")
@@ -320,7 +319,7 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
} }
// estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量 // estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量
func (s *AccountUsageService) estimateSetupTokenUsage(account *model.Account) *UsageInfo { func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageInfo {
info := &UsageInfo{} info := &UsageInfo{}
// 如果有session_window信息 // 如果有session_window信息

View File

@@ -7,62 +7,61 @@ import (
"log" "log"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
// AdminService interface defines admin management operations // AdminService interface defines admin management operations
type AdminService interface { type AdminService interface {
// User management // User management
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*model.User, error) GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management // Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]model.Group, error) GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*model.Group, error) GetGroup(ctx context.Context, id int64) (*Group, error)
CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
// Account management // Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*model.Account, error) GetAccount(ctx context.Context, id int64) (*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
DeleteAccount(ctx context.Context, id int64) error DeleteAccount(ctx context.Context, id int64) error
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*model.Account, error) ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
// Proxy management // Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
GetAllProxies(ctx context.Context) ([]model.Proxy, error) GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*model.Proxy, error) GetProxy(ctx context.Context, id int64) (*Proxy, error)
CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
DeleteProxy(ctx context.Context, id int64) error DeleteProxy(ctx context.Context, id int64) error
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
// Redeem code management // Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
DeleteRedeemCode(ctx context.Context, id int64) error DeleteRedeemCode(ctx context.Context, id int64) error
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
} }
// Input types for admin operations // Input types for admin operations
@@ -252,7 +251,7 @@ func NewAdminService(
} }
// User management implementations // User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) { func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search) users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
if err != nil { if err != nil {
@@ -261,20 +260,21 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, st
return users, result.Total, nil return users, result.Total, nil
} }
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User, error) { func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
return s.userRepo.GetByID(ctx, id) return s.userRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) { func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
user := &model.User{ user := &User{
Email: input.Email, Email: input.Email,
Username: input.Username, Username: input.Username,
Wechat: input.Wechat, Wechat: input.Wechat,
Notes: input.Notes, Notes: input.Notes,
Role: "user", // Always create as regular user, never admin Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance, Balance: input.Balance,
Concurrency: input.Concurrency, Concurrency: input.Concurrency,
Status: model.StatusActive, Status: StatusActive,
AllowedGroups: input.AllowedGroups,
} }
if err := user.SetPassword(input.Password); err != nil { if err := user.SetPassword(input.Password); err != nil {
return nil, err return nil, err
@@ -285,7 +285,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
return user, nil return user, nil
} }
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) { func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id) user, err := s.userRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -335,16 +335,16 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
concurrencyDiff := user.Concurrency - oldConcurrency concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 { if concurrencyDiff != 0 {
code, err := model.GenerateRedeemCode() code, err := GenerateRedeemCode()
if err != nil { if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err) log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil return user, nil
} }
adjustmentRecord := &model.RedeemCode{ adjustmentRecord := &RedeemCode{
Code: code, Code: code,
Type: model.AdjustmentTypeAdminConcurrency, Type: AdjustmentTypeAdminConcurrency,
Value: float64(concurrencyDiff), Value: float64(concurrencyDiff),
Status: model.StatusUsed, Status: StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
} }
now := time.Now() now := time.Now()
@@ -369,7 +369,7 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
return s.userRepo.Delete(ctx, id) return s.userRepo.Delete(ctx, id)
} }
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) { func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -406,17 +406,17 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
balanceDiff := user.Balance - oldBalance balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 { if balanceDiff != 0 {
code, err := model.GenerateRedeemCode() code, err := GenerateRedeemCode()
if err != nil { if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err) log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil return user, nil
} }
adjustmentRecord := &model.RedeemCode{ adjustmentRecord := &RedeemCode{
Code: code, Code: code,
Type: model.AdjustmentTypeAdminBalance, Type: AdjustmentTypeAdminBalance,
Value: balanceDiff, Value: balanceDiff,
Status: model.StatusUsed, Status: StatusUsed,
UsedBy: &user.ID, UsedBy: &user.ID,
Notes: notes, Notes: notes,
} }
@@ -431,7 +431,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil return user, nil
} }
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) { func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
@@ -452,7 +452,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
} }
// Group management implementations // Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) { func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
if err != nil { if err != nil {
@@ -461,36 +461,36 @@ func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, p
return groups, result.Total, nil return groups, result.Total, nil
} }
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]model.Group, error) { func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) {
return s.groupRepo.ListActive(ctx) return s.groupRepo.ListActive(ctx)
} }
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error) { func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) {
return s.groupRepo.ListActiveByPlatform(ctx, platform) return s.groupRepo.ListActiveByPlatform(ctx, platform)
} }
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*model.Group, error) { func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) {
return s.groupRepo.GetByID(ctx, id) return s.groupRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error) { func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
platform := input.Platform platform := input.Platform
if platform == "" { if platform == "" {
platform = model.PlatformAnthropic platform = PlatformAnthropic
} }
subscriptionType := input.SubscriptionType subscriptionType := input.SubscriptionType
if subscriptionType == "" { if subscriptionType == "" {
subscriptionType = model.SubscriptionTypeStandard subscriptionType = SubscriptionTypeStandard
} }
group := &model.Group{ group := &Group{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
Platform: platform, Platform: platform,
RateMultiplier: input.RateMultiplier, RateMultiplier: input.RateMultiplier,
IsExclusive: input.IsExclusive, IsExclusive: input.IsExclusive,
Status: model.StatusActive, Status: StatusActive,
SubscriptionType: subscriptionType, SubscriptionType: subscriptionType,
DailyLimitUSD: input.DailyLimitUSD, DailyLimitUSD: input.DailyLimitUSD,
WeeklyLimitUSD: input.WeeklyLimitUSD, WeeklyLimitUSD: input.WeeklyLimitUSD,
@@ -502,7 +502,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return group, nil return group, nil
} }
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error) { func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -571,7 +571,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
return nil return nil
} }
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) { func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil { if err != nil {
@@ -581,7 +581,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
} }
// Account management implementations // Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) { func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
if err != nil { if err != nil {
@@ -590,21 +590,21 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
return accounts, result.Total, nil return accounts, result.Total, nil
} }
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*model.Account, error) { func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) {
return s.accountRepo.GetByID(ctx, id) return s.accountRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error) { func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
account := &model.Account{ account := &Account{
Name: input.Name, Name: input.Name,
Platform: input.Platform, Platform: input.Platform,
Type: input.Type, Type: input.Type,
Credentials: model.JSONB(input.Credentials), Credentials: input.Credentials,
Extra: model.JSONB(input.Extra), Extra: input.Extra,
ProxyID: input.ProxyID, ProxyID: input.ProxyID,
Concurrency: input.Concurrency, Concurrency: input.Concurrency,
Priority: input.Priority, Priority: input.Priority,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.accountRepo.Create(ctx, account); err != nil { if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err return nil, err
@@ -618,7 +618,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
return account, nil return account, nil
} }
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error) { func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -631,10 +631,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Type = input.Type account.Type = input.Type
} }
if len(input.Credentials) > 0 { if len(input.Credentials) > 0 {
account.Credentials = model.JSONB(input.Credentials) account.Credentials = input.Credentials
} }
if len(input.Extra) > 0 { if len(input.Extra) > 0 {
account.Extra = model.JSONB(input.Extra) account.Extra = input.Extra
} }
if input.ProxyID != nil { if input.ProxyID != nil {
account.ProxyID = input.ProxyID account.ProxyID = input.ProxyID
@@ -730,7 +730,7 @@ func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id) return s.accountRepo.Delete(ctx, id)
} }
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) { func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -739,12 +739,12 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int
return account, nil return account, nil
} }
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*model.Account, error) { func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
account.Status = model.StatusActive account.Status = StatusActive
account.ErrorMessage = "" account.ErrorMessage = ""
if err := s.accountRepo.Update(ctx, account); err != nil { if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, err return nil, err
@@ -752,7 +752,7 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*mo
return account, nil return account, nil
} }
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) { func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err return nil, err
} }
@@ -760,7 +760,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
} }
// Proxy management implementations // Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) { func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil { if err != nil {
@@ -769,27 +769,27 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return proxies, result.Total, nil return proxies, result.Total, nil
} }
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]model.Proxy, error) { func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
return s.proxyRepo.ListActive(ctx) return s.proxyRepo.ListActive(ctx)
} }
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) { func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
return s.proxyRepo.ListActiveWithAccountCount(ctx) return s.proxyRepo.ListActiveWithAccountCount(ctx)
} }
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*model.Proxy, error) { func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
return s.proxyRepo.GetByID(ctx, id) return s.proxyRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error) { func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) {
proxy := &model.Proxy{ proxy := &Proxy{
Name: input.Name, Name: input.Name,
Protocol: input.Protocol, Protocol: input.Protocol,
Host: input.Host, Host: input.Host,
Port: input.Port, Port: input.Port,
Username: input.Username, Username: input.Username,
Password: input.Password, Password: input.Password,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.proxyRepo.Create(ctx, proxy); err != nil { if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err return nil, err
@@ -797,7 +797,7 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
return proxy, nil return proxy, nil
} }
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error) { func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -835,9 +835,9 @@ func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
return s.proxyRepo.Delete(ctx, id) return s.proxyRepo.Delete(ctx, id)
} }
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error) { func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
// Return mock data for now - would need a dedicated repository method // Return mock data for now - would need a dedicated repository method
return []model.Account{}, 0, nil return []Account{}, 0, nil
} }
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) { func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
@@ -845,7 +845,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
} }
// Redeem code management implementations // Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) { func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize} params := pagination.PaginationParams{Page: page, PageSize: pageSize}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil { if err != nil {
@@ -854,13 +854,13 @@ func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize i
return codes, result.Total, nil return codes, result.Total, nil
} }
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) { func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
return s.redeemCodeRepo.GetByID(ctx, id) return s.redeemCodeRepo.GetByID(ctx, id)
} }
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error) { func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) {
// 如果是订阅类型,验证必须有 GroupID // 如果是订阅类型,验证必须有 GroupID
if input.Type == model.RedeemTypeSubscription { if input.Type == RedeemTypeSubscription {
if input.GroupID == nil { if input.GroupID == nil {
return nil, errors.New("group_id is required for subscription type") return nil, errors.New("group_id is required for subscription type")
} }
@@ -874,20 +874,20 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
} }
} }
codes := make([]model.RedeemCode, 0, input.Count) codes := make([]RedeemCode, 0, input.Count)
for i := 0; i < input.Count; i++ { for i := 0; i < input.Count; i++ {
codeValue, err := model.GenerateRedeemCode() codeValue, err := GenerateRedeemCode()
if err != nil { if err != nil {
return nil, err return nil, err
} }
code := model.RedeemCode{ code := RedeemCode{
Code: codeValue, Code: codeValue,
Type: input.Type, Type: input.Type,
Value: input.Value, Value: input.Value,
Status: model.StatusUnused, Status: StatusUnused,
} }
// 订阅类型专用字段 // 订阅类型专用字段
if input.Type == model.RedeemTypeSubscription { if input.Type == RedeemTypeSubscription {
code.GroupID = input.GroupID code.GroupID = input.GroupID
code.ValidityDays = input.ValidityDays code.ValidityDays = input.ValidityDays
if code.ValidityDays <= 0 { if code.ValidityDays <= 0 {
@@ -916,12 +916,12 @@ func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int
return deleted, nil return deleted, nil
} }
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) { func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
code, err := s.redeemCodeRepo.GetByID(ctx, id) code, err := s.redeemCodeRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
code.Status = model.StatusExpired code.Status = StatusExpired
if err := s.redeemCodeRepo.Update(ctx, code); err != nil { if err := s.redeemCodeRepo.Update(ctx, code); err != nil {
return nil, err return nil, err
} }

View File

@@ -0,0 +1,20 @@
package service
import "time"
type ApiKey struct {
ID int64
UserID int64
Key string
Name string
GroupID *int64
Status string
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
}
func (k *ApiKey) IsActive() bool {
return k.Status == StatusActive
}

View File

@@ -4,16 +4,13 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/redis/go-redis/v9"
) )
var ( var (
@@ -30,17 +27,17 @@ const (
) )
type ApiKeyRepository interface { type ApiKeyRepository interface {
Create(ctx context.Context, key *model.ApiKey) error Create(ctx context.Context, key *ApiKey) error
GetByID(ctx context.Context, id int64) (*model.ApiKey, error) GetByID(ctx context.Context, id int64) (*ApiKey, error)
GetByKey(ctx context.Context, key string) (*model.ApiKey, error) GetByKey(ctx context.Context, key string) (*ApiKey, error)
Update(ctx context.Context, key *model.ApiKey) error Update(ctx context.Context, key *ApiKey) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
CountByUserID(ctx context.Context, userID int64) (int64, error) CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error) ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error) CountByGroupID(ctx context.Context, groupID int64) (int64, error)
} }
@@ -144,7 +141,7 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
} }
count, err := s.cache.GetCreateAttemptCount(ctx, userID) count, err := s.cache.GetCreateAttemptCount(ctx, userID)
if err != nil && !errors.Is(err, redis.Nil) { if err != nil {
// Redis 出错时不阻止用户操作 // Redis 出错时不阻止用户操作
return nil return nil
} }
@@ -168,7 +165,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组 // canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅 // 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑 // 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User, group *model.Group) bool { func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅 // 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID) _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
@@ -179,7 +176,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User,
} }
// Create 创建API Key // Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*model.ApiKey, error) { func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
// 验证用户存在 // 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
@@ -235,12 +232,12 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
// 创建API Key记录 // 创建API Key记录
apiKey := &model.ApiKey{ apiKey := &ApiKey{
UserID: userID, UserID: userID,
Key: key, Key: key,
Name: req.Name, Name: req.Name,
GroupID: req.GroupID, GroupID: req.GroupID,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil { if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
@@ -251,7 +248,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
} }
// List 获取用户的API Key列表 // List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err) return nil, nil, fmt.Errorf("list api keys: %w", err)
@@ -260,7 +257,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
} }
// GetByID 根据ID获取API Key // GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
@@ -269,7 +266,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, e
} }
// GetByKey 根据Key字符串获取API Key用于认证 // GetByKey 根据Key字符串获取API Key用于认证
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) { func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
// 尝试从Redis缓存获取 // 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key) cacheKey := fmt.Sprintf("apikey:%s", key)
@@ -289,7 +286,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
} }
// Update 更新API Key // Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) { func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
@@ -364,7 +361,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
} }
// ValidateKey 验证API Key是否有效用于认证中间件 // ValidateKey 验证API Key是否有效用于认证中间件
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.ApiKey, *model.User, error) { func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
// 获取API Key // 获取API Key
apiKey, err := s.GetByKey(ctx, key) apiKey, err := s.GetByKey(ctx, key)
if err != nil { if err != nil {
@@ -408,7 +405,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组: // 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的 // - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的 // - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]model.Group, error) { func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
@@ -434,7 +431,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
} }
// 过滤出用户有权限的分组 // 过滤出用户有权限的分组
availableGroups := make([]model.Group, 0) availableGroups := make([]Group, 0)
for _, group := range allGroups { for _, group := range allGroups {
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) { if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
availableGroups = append(availableGroups, group) availableGroups = append(availableGroups, group)
@@ -445,7 +442,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
} }
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据) // canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.Group, subscribedGroupIDs map[int64]bool) bool { func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅 // 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() { if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID] return subscribedGroupIDs[group.ID]
@@ -454,7 +451,7 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.
return user.CanBindGroup(group.ID, group.IsExclusive) return user.CanBindGroup(group.ID, group.IsExclusive)
} }
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) { func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit) keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
if err != nil { if err != nil {
return nil, fmt.Errorf("search api keys: %w", err) return nil, fmt.Errorf("search api keys: %w", err)

View File

@@ -9,7 +9,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@@ -64,12 +63,12 @@ func NewAuthService(
} }
// Register 用户注册返回token和用户 // Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) { func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "") return s.RegisterWithVerification(ctx, email, password, "")
} }
// RegisterWithVerification 用户注册支持邮件验证返回token和用户 // RegisterWithVerification 用户注册支持邮件验证返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *model.User, error) { func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
// 检查是否开放注册 // 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled return "", nil, ErrRegDisabled
@@ -113,13 +112,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
} }
// 创建用户 // 创建用户
user := &model.User{ user := &User{
Email: email, Email: email,
PasswordHash: hashedPassword, PasswordHash: hashedPassword,
Role: model.RoleUser, Role: RoleUser,
Balance: defaultBalance, Balance: defaultBalance,
Concurrency: defaultConcurrency, Concurrency: defaultConcurrency,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.userRepo.Create(ctx, user); err != nil { if err := s.userRepo.Create(ctx, user); err != nil {
@@ -251,7 +250,7 @@ func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool {
} }
// Login 用户登录返回JWT token // Login 用户登录返回JWT token
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *model.User, error) { func (s *AuthService) Login(ctx context.Context, email, password string) (string, *User, error) {
// 查找用户 // 查找用户
user, err := s.userRepo.GetByEmail(ctx, email) user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil { if err != nil {
@@ -307,7 +306,7 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
} }
// GenerateToken 生成JWT token // GenerateToken 生成JWT token
func (s *AuthService) GenerateToken(user *model.User) (string, error) { func (s *AuthService) GenerateToken(user *User) (string, error) {
now := time.Now() now := time.Now()
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour) expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)

View File

@@ -7,7 +7,6 @@ import (
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
) )
// 错误定义 // 错误定义
@@ -224,7 +223,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility 检查用户是否有资格发起请求 // CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0 // 余额模式:检查缓存余额 > 0
// 订阅模式检查缓存用量未超过限额Group限额从参数传入 // 订阅模式检查缓存用量未超过限额Group限额从参数传入
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *model.User, apiKey *model.ApiKey, group *model.Group, subscription *model.UserSubscription) error { func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
// 判断计费模式 // 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
@@ -252,7 +251,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI
} }
// checkSubscriptionEligibility 检查订阅模式资格 // checkSubscriptionEligibility 检查订阅模式资格
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *model.Group, subscription *model.UserSubscription) error { func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *Group, subscription *UserSubscription) error {
// 获取订阅缓存数据 // 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID) subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil { if err != nil {
@@ -262,7 +261,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
} }
// 检查订阅状态 // 检查订阅状态
if subData.Status != model.SubscriptionStatusActive { if subData.Status != SubscriptionStatusActive {
return ErrSubscriptionInvalid return ErrSubscriptionInvalid
} }
@@ -288,7 +287,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
} }
// checkSubscriptionLimitsFallback 降级检查订阅限额 // checkSubscriptionLimitsFallback 降级检查订阅限额
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *model.UserSubscription, group *model.Group) error { func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
if subscription == nil { if subscription == nil {
return ErrSubscriptionInvalid return ErrSubscriptionInvalid
} }

View File

@@ -12,8 +12,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
) )
type CRSSyncService struct { type CRSSyncService struct {
@@ -217,7 +215,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
), ),
} }
var proxies []model.Proxy var proxies []Proxy
if input.SyncProxies { if input.SyncProxies {
proxies, _ = s.proxyRepo.ListActive(ctx) proxies, _ = s.proxyRepo.ListActive(ctx)
} }
@@ -234,7 +232,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
if targetType == "" { if targetType == "" {
targetType = "oauth" targetType = "oauth"
} }
if targetType != model.AccountTypeOAuth && targetType != model.AccountTypeSetupToken { if targetType != AccountTypeOAuth && targetType != AccountTypeSetupToken {
item.Action = "skipped" item.Action = "skipped"
item.Error = "unsupported authType: " + targetType item.Error = "unsupported authType: " + targetType
result.Skipped++ result.Skipped++
@@ -305,12 +303,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
if existing == nil { if existing == nil {
account := &model.Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic, Platform: PlatformAnthropic,
Type: targetType, Type: targetType,
Credentials: model.JSONB(credentials), Credentials: credentials,
Extra: model.JSONB(extra), Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
Concurrency: concurrency, Concurrency: concurrency,
Priority: priority, Priority: priority,
@@ -325,7 +323,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue continue
} }
// 🔄 Refresh OAuth token after creation // 🔄 Refresh OAuth token after creation
if targetType == model.AccountTypeOAuth { if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account) _ = s.accountRepo.Update(ctx, account)
@@ -338,11 +336,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
// Update existing // Update existing
existing.Extra = mergeJSONB(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic existing.Platform = PlatformAnthropic
existing.Type = targetType existing.Type = targetType
existing.Credentials = mergeJSONB(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
} }
@@ -360,7 +358,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
// 🔄 Refresh OAuth token after update // 🔄 Refresh OAuth token after update
if targetType == model.AccountTypeOAuth { if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing) _ = s.accountRepo.Update(ctx, existing)
@@ -422,12 +420,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
if existing == nil { if existing == nil {
account := &model.Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic, Platform: PlatformAnthropic,
Type: model.AccountTypeApiKey, Type: AccountTypeApiKey,
Credentials: model.JSONB(credentials), Credentials: credentials,
Extra: model.JSONB(extra), Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
Concurrency: concurrency, Concurrency: concurrency,
Priority: priority, Priority: priority,
@@ -447,11 +445,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue continue
} }
existing.Extra = mergeJSONB(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic existing.Platform = PlatformAnthropic
existing.Type = model.AccountTypeApiKey existing.Type = AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
} }
@@ -545,12 +543,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
if existing == nil { if existing == nil {
account := &model.Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI, Platform: PlatformOpenAI,
Type: model.AccountTypeOAuth, Type: AccountTypeOAuth,
Credentials: model.JSONB(credentials), Credentials: credentials,
Extra: model.JSONB(extra), Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
Concurrency: concurrency, Concurrency: concurrency,
Priority: priority, Priority: priority,
@@ -575,11 +573,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue continue
} }
existing.Extra = mergeJSONB(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI existing.Platform = PlatformOpenAI
existing.Type = model.AccountTypeOAuth existing.Type = AccountTypeOAuth
existing.Credentials = mergeJSONB(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
} }
@@ -666,12 +664,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
} }
if existing == nil { if existing == nil {
account := &model.Account{ account := &Account{
Name: defaultName(src.Name, src.ID), Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI, Platform: PlatformOpenAI,
Type: model.AccountTypeApiKey, Type: AccountTypeApiKey,
Credentials: model.JSONB(credentials), Credentials: credentials,
Extra: model.JSONB(extra), Extra: extra,
ProxyID: proxyID, ProxyID: proxyID,
Concurrency: concurrency, Concurrency: concurrency,
Priority: priority, Priority: priority,
@@ -691,11 +689,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue continue
} }
existing.Extra = mergeJSONB(existing.Extra, extra) existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID) existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI existing.Platform = PlatformOpenAI
existing.Type = model.AccountTypeApiKey existing.Type = AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials) existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil { if proxyID != nil {
existing.ProxyID = proxyID existing.ProxyID = proxyID
} }
@@ -939,9 +937,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
return result, nil return result, nil
} }
// mergeJSONB merges two JSONB maps without removing keys that are absent in updates. func mergeMap(existing map[string]any, updates map[string]any) map[string]any {
func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB { out := make(map[string]any, len(existing)+len(updates))
out := make(model.JSONB)
for k, v := range existing { for k, v := range existing {
out[k] = v out[k] = v
} }
@@ -951,7 +948,7 @@ func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB {
return out return out
} }
func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]model.Proxy, src *crsProxy, defaultName string) (*int64, error) { func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]Proxy, src *crsProxy, defaultName string) (*int64, error) {
if !enabled || src == nil { if !enabled || src == nil {
return nil, nil return nil, nil
} }
@@ -987,14 +984,14 @@ func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cac
} }
// Create new proxy // Create new proxy
proxy := &model.Proxy{ proxy := &Proxy{
Name: defaultProxyName(defaultName, protocol, host, port), Name: defaultProxyName(defaultName, protocol, host, port),
Protocol: protocol, Protocol: protocol,
Host: host, Host: host,
Port: port, Port: port,
Username: username, Username: username,
Password: password, Password: password,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.proxyRepo.Create(ctx, proxy); err != nil { if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err return nil, err
@@ -1153,8 +1150,8 @@ func crsExportAccounts(ctx context.Context, client *http.Client, baseURL, adminT
// refreshOAuthToken attempts to refresh OAuth token for a synced account // refreshOAuthToken attempts to refresh OAuth token for a synced account
// Returns updated credentials or nil if refresh failed/not applicable // Returns updated credentials or nil if refresh failed/not applicable
func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.Account) model.JSONB { func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account) map[string]any {
if account.Type != model.AccountTypeOAuth { if account.Type != AccountTypeOAuth {
return nil return nil
} }
@@ -1162,7 +1159,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
var err error var err error
switch account.Platform { switch account.Platform {
case model.PlatformAnthropic: case PlatformAnthropic:
if s.oauthService == nil { if s.oauthService == nil {
return nil return nil
} }
@@ -1187,7 +1184,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
newCredentials["scope"] = tokenInfo.Scope newCredentials["scope"] = tokenInfo.Scope
} }
} }
case model.PlatformOpenAI: case PlatformOpenAI:
if s.openaiOAuthService == nil { if s.openaiOAuthService == nil {
return nil return nil
} }
@@ -1227,5 +1224,5 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
return nil return nil
} }
return model.JSONB(newCredentials) return newCredentials
} }

View File

@@ -0,0 +1,96 @@
package service
// Status constants
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// Role constants
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// Platform constants
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
)
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeApiKey = "apikey" // API Key类型账号
)
// Redeem type constants
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
)
// Admin adjustment type constants
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)
// Group subscription type constants
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
// Subscription status constants
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)
// Setting keys
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
)
// Admin API Key prefix (distinct from user "sk-" keys)
const AdminApiKeyPrefix = "admin-"

View File

@@ -11,7 +11,6 @@ import (
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
) )
var ( var (
@@ -69,13 +68,13 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
// GetSmtpConfig 从数据库获取SMTP配置 // GetSmtpConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) { func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
keys := []string{ keys := []string{
model.SettingKeySmtpHost, SettingKeySmtpHost,
model.SettingKeySmtpPort, SettingKeySmtpPort,
model.SettingKeySmtpUsername, SettingKeySmtpUsername,
model.SettingKeySmtpPassword, SettingKeySmtpPassword,
model.SettingKeySmtpFrom, SettingKeySmtpFrom,
model.SettingKeySmtpFromName, SettingKeySmtpFromName,
model.SettingKeySmtpUseTLS, SettingKeySmtpUseTLS,
} }
settings, err := s.settingRepo.GetMultiple(ctx, keys) settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -83,27 +82,27 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err) return nil, fmt.Errorf("get smtp settings: %w", err)
} }
host := settings[model.SettingKeySmtpHost] host := settings[SettingKeySmtpHost]
if host == "" { if host == "" {
return nil, ErrEmailNotConfigured return nil, ErrEmailNotConfigured
} }
port := 587 // 默认端口 port := 587 // 默认端口
if portStr := settings[model.SettingKeySmtpPort]; portStr != "" { if portStr := settings[SettingKeySmtpPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil { if p, err := strconv.Atoi(portStr); err == nil {
port = p port = p
} }
} }
useTLS := settings[model.SettingKeySmtpUseTLS] == "true" useTLS := settings[SettingKeySmtpUseTLS] == "true"
return &SmtpConfig{ return &SmtpConfig{
Host: host, Host: host,
Port: port, Port: port,
Username: settings[model.SettingKeySmtpUsername], Username: settings[SettingKeySmtpUsername],
Password: settings[model.SettingKeySmtpPassword], Password: settings[SettingKeySmtpPassword],
From: settings[model.SettingKeySmtpFrom], From: settings[SettingKeySmtpFrom],
FromName: settings[model.SettingKeySmtpFromName], FromName: settings[SettingKeySmtpFromName],
UseTLS: useTLS, UseTLS: useTLS,
}, nil }, nil
} }

View File

@@ -17,7 +17,6 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
@@ -31,6 +30,10 @@ const (
stickySessionTTL = time.Hour // 粘性会话TTL stickySessionTTL = time.Hour // 粘性会话TTL
) )
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataRe = regexp.MustCompile(`^data:\s*`)
// allowedHeaders 白名单headers参考CRS项目 // allowedHeaders 白名单headers参考CRS项目
var allowedHeaders = map[string]bool{ var allowedHeaders = map[string]bool{
"accept": true, "accept": true,
@@ -265,12 +268,12 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
} }
// SelectAccount 选择账号(粘性会话+优先级) // SelectAccount 选择账号(粘性会话+优先级)
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) { func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "") return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
} }
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射) // SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) { func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
// 1. 查询粘性会话 // 1. 查询粘性会话
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
@@ -289,19 +292,19 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
} }
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台) // 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
var accounts []model.Account var accounts []Account
var err error var err error
if groupID != nil { if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformAnthropic) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic)
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformAnthropic) accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err) return nil, fmt.Errorf("query accounts failed: %w", err)
} }
// 3. 按优先级+最久未用选择(考虑模型支持) // 3. 按优先级+最久未用选择(考虑模型支持)
var selected *model.Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
// 检查模型支持 // 检查模型支持
@@ -350,12 +353,12 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
} }
// GetAccessToken 获取账号凭证 // GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) { func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type { switch account.Type {
case model.AccountTypeOAuth, model.AccountTypeSetupToken: case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow // Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account) return s.getOAuthToken(ctx, account)
case model.AccountTypeApiKey: case AccountTypeApiKey:
apiKey := account.GetCredential("api_key") apiKey := account.GetCredential("api_key")
if apiKey == "" { if apiKey == "" {
return "", "", errors.New("api_key not found in credentials") return "", "", errors.New("api_key not found in credentials")
@@ -366,7 +369,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco
} }
} }
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) { func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
if accessToken == "" { if accessToken == "" {
return "", "", errors.New("access_token not found in credentials") return "", "", errors.New("access_token not found in credentials")
@@ -381,10 +384,7 @@ const (
retryDelay = 3 * time.Second // 重试等待时间 retryDelay = 3 * time.Second // 重试等待时间
) )
// shouldRetryUpstreamError 判断是否应该重试上游错误 func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
// OAuth/Setup Token 账号:仅 403 重试
// API Key 账号:未配置的错误码重试
func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, statusCode int) bool {
// OAuth/Setup Token 账号:仅 403 重试 // OAuth/Setup Token 账号:仅 403 重试
if account.IsOAuth() { if account.IsOAuth() {
return statusCode == 403 return statusCode == 403
@@ -395,7 +395,7 @@ func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, status
} }
// Forward 转发请求到Claude API // Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) { func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now() startTime := time.Now()
// 解析请求获取model和stream // 解析请求获取model和stream
@@ -421,7 +421,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
// 应用模型映射仅对apikey类型账号 // 应用模型映射仅对apikey类型账号
originalModel := req.Model originalModel := req.Model
if account.Type == model.AccountTypeApiKey { if account.Type == AccountTypeApiKey {
mappedModel := account.GetMappedModel(req.Model) mappedModel := account.GetMappedModel(req.Model)
if mappedModel != req.Model { if mappedModel != req.Model {
// 替换请求体中的模型名 // 替换请求体中的模型名
@@ -513,10 +513,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
}, nil }, nil
} }
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) { func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
// 确定目标URL // 确定目标URL
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == model.AccountTypeApiKey { if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages" targetURL = baseURL + "/v1/messages"
} }
@@ -640,7 +640,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
return claude.DefaultBetaHeader return claude.DefaultBetaHeader
} }
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) { func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
// 处理上游错误,标记账号状态 // 处理上游错误,标记账号状态
@@ -695,7 +695,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
// handleRetryExhaustedError 处理重试耗尽后的错误 // handleRetryExhaustedError 处理重试耗尽后的错误
// OAuth 403标记账号异常 // OAuth 403标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号 // API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) { func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
statusCode := resp.StatusCode statusCode := resp.StatusCode
@@ -726,7 +726,7 @@ type streamingResult struct {
firstTokenMs *int firstTokenMs *int
} }
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) { func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
// 更新5h窗口状态 // 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -758,8 +758,12 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射替换响应中的model字段 // 如果有模型映射替换响应中的model字段
if needModelReplace && strings.HasPrefix(line, "data: ") { if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel) line = s.replaceModelInSSELine(line, mappedModel, originalModel)
} }
@@ -769,15 +773,18 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
} }
flusher.Flush() flusher.Flush()
// 解析usage数据
if strings.HasPrefix(line, "data: ") {
data := line[6:]
// 记录首字时间:第一个有效的 content_block_delta 或 message_start // 记录首字时间:第一个有效的 content_block_delta 或 message_start
if firstTokenMs == nil && data != "" && data != "[DONE]" { if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds()) ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms firstTokenMs = &ms
} }
s.parseSSEUsage(data, usage) s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
} }
} }
@@ -790,7 +797,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// replaceModelInSSELine 替换SSE数据行中的model字段 // replaceModelInSSELine 替换SSE数据行中的model字段
func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
data := line[6:] // 去掉 "data: " 前缀 if !sseDataRe.MatchString(line) {
return line
}
data := sseDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" { if data == "" || data == "[DONE]" {
return line return line
} }
@@ -865,7 +875,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
} }
} }
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*ClaudeUsage, error) { func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
// 更新5h窗口状态 // 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -924,10 +934,10 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct { type RecordUsageInput struct {
Result *ForwardResult Result *ForwardResult
ApiKey *model.ApiKey ApiKey *ApiKey
User *model.User User *User
Account *model.Account Account *Account
Subscription *model.UserSubscription // 可选:订阅信息 Subscription *UserSubscription // 可选:订阅信息
} }
// RecordUsage 记录使用量并扣费(或更新订阅用量) // RecordUsage 记录使用量并扣费(或更新订阅用量)
@@ -961,14 +971,14 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// 判断计费方式:订阅模式 vs 余额模式 // 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := model.BillingTypeBalance billingType := BillingTypeBalance
if isSubscriptionBilling { if isSubscriptionBilling {
billingType = model.BillingTypeSubscription billingType = BillingTypeSubscription
} }
// 创建使用日志 // 创建使用日志
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
usageLog := &model.UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,
@@ -1047,9 +1057,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// ForwardCountTokens 转发 count_tokens 请求到上游 API // ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应 // 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *model.Account, body []byte) error { func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
// 应用模型映射(仅对 apikey 类型账号) // 应用模型映射(仅对 apikey 类型账号)
if account.Type == model.AccountTypeApiKey { if account.Type == AccountTypeApiKey {
var req struct { var req struct {
Model string `json:"model"` Model string `json:"model"`
} }
@@ -1122,10 +1132,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// buildCountTokensRequest 构建 count_tokens 上游请求 // buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) { func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
// 确定目标 URL // 确定目标 URL
targetURL := claudeAPICountTokensURL targetURL := claudeAPICountTokensURL
if account.Type == model.AccountTypeApiKey { if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL() baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens" targetURL = baseURL + "/v1/messages/count_tokens"
} }

View File

@@ -0,0 +1,48 @@
package service
import "time"
type Group struct {
ID int64
Name string
Description string
Platform string
RateMultiplier float64
IsExclusive bool
Status string
SubscriptionType string
DailyLimitUSD *float64
WeeklyLimitUSD *float64
MonthlyLimitUSD *float64
CreatedAt time.Time
UpdatedAt time.Time
AccountGroups []AccountGroup
AccountCount int64
}
func (g *Group) IsActive() bool {
return g.Status == StatusActive
}
func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
func (g *Group) HasWeeklyLimit() bool {
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
}
func (g *Group) HasMonthlyLimit() bool {
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
}

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
@@ -15,16 +14,16 @@ var (
) )
type GroupRepository interface { type GroupRepository interface {
Create(ctx context.Context, group *model.Group) error Create(ctx context.Context, group *Group) error
GetByID(ctx context.Context, id int64) (*model.Group, error) GetByID(ctx context.Context, id int64) (*Group, error)
Update(ctx context.Context, group *model.Group) error Update(ctx context.Context, group *Group) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
DeleteCascade(ctx context.Context, id int64) ([]int64, error) DeleteCascade(ctx context.Context, id int64) ([]int64, error)
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Group, error) ListActive(ctx context.Context) ([]Group, error)
ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
ExistsByName(ctx context.Context, name string) (bool, error) ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error) GetAccountCount(ctx context.Context, groupID int64) (int64, error)
@@ -61,7 +60,7 @@ func NewGroupService(groupRepo GroupRepository) *GroupService {
} }
// Create 创建分组 // Create 创建分组
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*model.Group, error) { func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) {
// 检查名称是否已存在 // 检查名称是否已存在
exists, err := s.groupRepo.ExistsByName(ctx, req.Name) exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
if err != nil { if err != nil {
@@ -72,12 +71,14 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
} }
// 创建分组 // 创建分组
group := &model.Group{ group := &Group{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
Platform: PlatformAnthropic,
RateMultiplier: req.RateMultiplier, RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive, IsExclusive: req.IsExclusive,
Status: model.StatusActive, Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
@@ -88,7 +89,7 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
} }
// GetByID 根据ID获取分组 // GetByID 根据ID获取分组
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) { func (s *GroupService) GetByID(ctx context.Context, id int64) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
@@ -97,7 +98,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err
} }
// List 获取分组列表 // List 获取分组列表
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
groups, pagination, err := s.groupRepo.List(ctx, params) groups, pagination, err := s.groupRepo.List(ctx, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list groups: %w", err) return nil, nil, fmt.Errorf("list groups: %w", err)
@@ -106,7 +107,7 @@ func (s *GroupService) List(ctx context.Context, params pagination.PaginationPar
} }
// ListActive 获取活跃分组列表 // ListActive 获取活跃分组列表
func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) { func (s *GroupService) ListActive(ctx context.Context) ([]Group, error) {
groups, err := s.groupRepo.ListActive(ctx) groups, err := s.groupRepo.ListActive(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("list active groups: %w", err) return nil, fmt.Errorf("list active groups: %w", err)
@@ -115,7 +116,7 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
} }
// Update 更新分组 // Update 更新分组
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) { func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)

View File

@@ -6,7 +6,6 @@ import (
"log" "log"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
) )
@@ -274,7 +273,7 @@ func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, pr
} }
// RefreshAccountToken refreshes token for an account // RefreshAccountToken refreshes token for an account
func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*TokenInfo, error) { func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
refreshToken := account.GetCredential("refresh_token") refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" { if refreshToken == "" {
return nil, fmt.Errorf("no refresh token available") return nil, fmt.Errorf("no refresh token available")

View File

@@ -11,12 +11,12 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -28,6 +28,10 @@ const (
openaiStickySessionTTL = time.Hour // 粘性会话TTL openaiStickySessionTTL = time.Hour // 粘性会话TTL
) )
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
// OpenAI allowed headers whitelist (for non-OAuth accounts) // OpenAI allowed headers whitelist (for non-OAuth accounts)
var openaiAllowedHeaders = map[string]bool{ var openaiAllowedHeaders = map[string]bool{
"accept-language": true, "accept-language": true,
@@ -119,12 +123,12 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
} }
// SelectAccount selects an OpenAI account with sticky session support // SelectAccount selects an OpenAI account with sticky session support
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) { func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "") return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
} }
// SelectAccountForModel selects an account supporting the requested model // SelectAccountForModel selects an account supporting the requested model
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) { func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
// 1. Check sticky session // 1. Check sticky session
if sessionHash != "" { if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
@@ -139,19 +143,19 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
} }
// 2. Get schedulable OpenAI accounts // 2. Get schedulable OpenAI accounts
var accounts []model.Account var accounts []Account
var err error var err error
if groupID != nil { if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformOpenAI) accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
} else { } else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformOpenAI) accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err) return nil, fmt.Errorf("query accounts failed: %w", err)
} }
// 3. Select by priority + LRU // 3. Select by priority + LRU
var selected *model.Account var selected *Account
for i := range accounts { for i := range accounts {
acc := &accounts[i] acc := &accounts[i]
// Check model support // Check model support
@@ -198,15 +202,15 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
} }
// GetAccessToken gets the access token for an OpenAI account // GetAccessToken gets the access token for an OpenAI account
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) { func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type { switch account.Type {
case model.AccountTypeOAuth: case AccountTypeOAuth:
accessToken := account.GetOpenAIAccessToken() accessToken := account.GetOpenAIAccessToken()
if accessToken == "" { if accessToken == "" {
return "", "", errors.New("access_token not found in credentials") return "", "", errors.New("access_token not found in credentials")
} }
return accessToken, "oauth", nil return accessToken, "oauth", nil
case model.AccountTypeApiKey: case AccountTypeApiKey:
apiKey := account.GetOpenAIApiKey() apiKey := account.GetOpenAIApiKey()
if apiKey == "" { if apiKey == "" {
return "", "", errors.New("api_key not found in credentials") return "", "", errors.New("api_key not found in credentials")
@@ -218,7 +222,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *mode
} }
// Forward forwards request to OpenAI API // Forward forwards request to OpenAI API
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*OpenAIForwardResult, error) { func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
startTime := time.Now() startTime := time.Now()
// Parse request body once (avoid multiple parse/serialize cycles) // Parse request body once (avoid multiple parse/serialize cycles)
@@ -243,7 +247,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
// For OAuth accounts using ChatGPT internal API, add store: false // For OAuth accounts using ChatGPT internal API, add store: false
if account.Type == model.AccountTypeOAuth { if account.Type == AccountTypeOAuth {
reqBody["store"] = false reqBody["store"] = false
bodyModified = true bodyModified = true
} }
@@ -305,7 +309,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
// Extract and save Codex usage snapshot from response headers (for OAuth accounts) // Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == model.AccountTypeOAuth { if account.Type == AccountTypeOAuth {
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil { if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
} }
@@ -321,14 +325,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}, nil }, nil
} }
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token string, isStream bool) (*http.Request, error) { func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
// Determine target URL based on account type // Determine target URL based on account type
var targetURL string var targetURL string
switch account.Type { switch account.Type {
case model.AccountTypeOAuth: case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API // OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL targetURL = chatgptCodexURL
case model.AccountTypeApiKey: case AccountTypeApiKey:
// API Key accounts use Platform API or custom base URL // API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL() baseURL := account.GetOpenAIBaseURL()
if baseURL != "" { if baseURL != "" {
@@ -349,7 +353,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
req.Header.Set("authorization", "Bearer "+token) req.Header.Set("authorization", "Bearer "+token)
// Set headers specific to OAuth accounts (ChatGPT internal API) // Set headers specific to OAuth accounts (ChatGPT internal API)
if account.Type == model.AccountTypeOAuth { if account.Type == AccountTypeOAuth {
// Required: set Host for ChatGPT API (must use req.Host, not Header.Set) // Required: set Host for ChatGPT API (must use req.Host, not Header.Set)
req.Host = "chatgpt.com" req.Host = "chatgpt.com"
// Required: set chatgpt-account-id header // Required: set chatgpt-account-id header
@@ -389,7 +393,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
return req, nil return req, nil
} }
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*OpenAIForwardResult, error) { func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
// Check custom error codes // Check custom error codes
@@ -445,7 +449,7 @@ type openaiStreamingResult struct {
firstTokenMs *int firstTokenMs *int
} }
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
// Set SSE response headers // Set SSE response headers
c.Header("Content-Type", "text/event-stream") c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache") c.Header("Cache-Control", "no-cache")
@@ -473,8 +477,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
// Replace model in response if needed // Replace model in response if needed
if needModelReplace && strings.HasPrefix(line, "data: ") { if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel) line = s.replaceModelInSSELine(line, mappedModel, originalModel)
} }
@@ -484,15 +492,18 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
} }
flusher.Flush() flusher.Flush()
// Parse usage data
if strings.HasPrefix(line, "data: ") {
data := line[6:]
// Record first token time // Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" { if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds()) ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms firstTokenMs = &ms
} }
s.parseSSEUsage(data, usage) s.parseSSEUsage(data, usage)
} else {
// Forward non-data lines as-is
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
} }
} }
@@ -504,7 +515,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
} }
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
data := line[6:] if !openaiSSEDataRe.MatchString(line) {
return line
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" { if data == "" || data == "[DONE]" {
return line return line
} }
@@ -561,7 +575,7 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
} }
} }
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*OpenAIUsage, error) { func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -627,10 +641,10 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage // OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct { type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult Result *OpenAIForwardResult
ApiKey *model.ApiKey ApiKey *ApiKey
User *model.User User *User
Account *model.Account Account *Account
Subscription *model.UserSubscription Subscription *UserSubscription
} }
// RecordUsage records usage and deducts balance // RecordUsage records usage and deducts balance
@@ -669,14 +683,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Determine billing type // Determine billing type
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := model.BillingTypeBalance billingType := BillingTypeBalance
if isSubscriptionBilling { if isSubscriptionBilling {
billingType = model.BillingTypeSubscription billingType = BillingTypeSubscription
} }
// Create usage log // Create usage log
durationMs := int(result.Duration.Milliseconds()) durationMs := int(result.Duration.Milliseconds())
usageLog := &model.UsageLog{ usageLog := &UsageLog{
UserID: user.ID, UserID: user.ID,
ApiKeyID: apiKey.ID, ApiKeyID: apiKey.ID,
AccountID: account.ID, AccountID: account.ID,

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
) )
@@ -200,7 +199,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
} }
// RefreshAccountToken refreshes token for an OpenAI account // RefreshAccountToken refreshes token for an OpenAI account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*OpenAITokenInfo, error) { func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if !account.IsOpenAI() { if !account.IsOpenAI() {
return nil, fmt.Errorf("account is not an OpenAI account") return nil, fmt.Errorf("account is not an OpenAI account")
} }

View File

@@ -0,0 +1,35 @@
package service
import (
"fmt"
"time"
)
type Proxy struct {
ID int64
Name string
Protocol string
Host string
Port int
Username string
Password string
Status string
CreatedAt time.Time
UpdatedAt time.Time
}
func (p *Proxy) IsActive() bool {
return p.Status == StatusActive
}
func (p *Proxy) URL() string {
if p.Username != "" && p.Password != "" {
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
}
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
}
type ProxyWithAccountCount struct {
Proxy
AccountCount int64
}

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
@@ -14,15 +13,15 @@ var (
) )
type ProxyRepository interface { type ProxyRepository interface {
Create(ctx context.Context, proxy *model.Proxy) error Create(ctx context.Context, proxy *Proxy) error
GetByID(ctx context.Context, id int64) (*model.Proxy, error) GetByID(ctx context.Context, id int64) (*Proxy, error)
Update(ctx context.Context, proxy *model.Proxy) error Update(ctx context.Context, proxy *Proxy) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Proxy, error) ListActive(ctx context.Context) ([]Proxy, error)
ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
@@ -62,16 +61,16 @@ func NewProxyService(proxyRepo ProxyRepository) *ProxyService {
} }
// Create 创建代理 // Create 创建代理
func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*model.Proxy, error) { func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*Proxy, error) {
// 创建代理 // 创建代理
proxy := &model.Proxy{ proxy := &Proxy{
Name: req.Name, Name: req.Name,
Protocol: req.Protocol, Protocol: req.Protocol,
Host: req.Host, Host: req.Host,
Port: req.Port, Port: req.Port,
Username: req.Username, Username: req.Username,
Password: req.Password, Password: req.Password,
Status: model.StatusActive, Status: StatusActive,
} }
if err := s.proxyRepo.Create(ctx, proxy); err != nil { if err := s.proxyRepo.Create(ctx, proxy); err != nil {
@@ -82,7 +81,7 @@ func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*mod
} }
// GetByID 根据ID获取代理 // GetByID 根据ID获取代理
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { func (s *ProxyService) GetByID(ctx context.Context, id int64) (*Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get proxy: %w", err) return nil, fmt.Errorf("get proxy: %w", err)
@@ -91,7 +90,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err
} }
// List 获取代理列表 // List 获取代理列表
func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) {
proxies, pagination, err := s.proxyRepo.List(ctx, params) proxies, pagination, err := s.proxyRepo.List(ctx, params)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("list proxies: %w", err) return nil, nil, fmt.Errorf("list proxies: %w", err)
@@ -100,7 +99,7 @@ func (s *ProxyService) List(ctx context.Context, params pagination.PaginationPar
} }
// ListActive 获取活跃代理列表 // ListActive 获取活跃代理列表
func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) { func (s *ProxyService) ListActive(ctx context.Context) ([]Proxy, error) {
proxies, err := s.proxyRepo.ListActive(ctx) proxies, err := s.proxyRepo.ListActive(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("list active proxies: %w", err) return nil, fmt.Errorf("list active proxies: %w", err)
@@ -109,7 +108,7 @@ func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
} }
// Update 更新代理 // Update 更新代理
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) { func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
return nil, fmt.Errorf("get proxy: %w", err) return nil, fmt.Errorf("get proxy: %w", err)

Some files were not shown because too many files have changed in this diff Show More