diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index e72e5f6e..827526e5 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -48,8 +48,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { turnstileService := service.NewTurnstileService(settingService, turnstileVerifier) emailQueueService := service.ProvideEmailQueueService(emailService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) - authHandler := handler.NewAuthHandler(authService) userService := service.NewUserService(userRepository) + authHandler := handler.NewAuthHandler(authService, userService) userHandler := handler.NewUserHandler(userService) apiKeyRepository := repository.NewApiKeyRepository(db) groupRepository := repository.NewGroupRepository(db) diff --git a/backend/go.mod b/backend/go.mod index faf196b7..d8e04646 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -22,12 +22,14 @@ require ( golang.org/x/net v0.47.0 golang.org/x/term v0.37.0 gopkg.in/yaml.v3 v3.0.1 + gorm.io/datatypes v1.2.0 gorm.io/driver/postgres v1.5.4 gorm.io/gorm v1.25.5 ) require ( dario.cat/mergo v1.0.2 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/andybalholm/brotli v1.2.0 // indirect @@ -57,6 +59,7 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/subcommands v1.2.0 // indirect @@ -64,8 +67,8 @@ require ( github.com/hashicorp/hcl v1.0.0 // indirect github.com/icholy/digest v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.4 // indirect + github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect + github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect @@ -132,4 +135,5 @@ require ( google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect + gorm.io/driver/mysql v1.5.2 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index ac083b5d..b07e9bca 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -1,5 +1,7 @@ dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= @@ -77,10 +79,17 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= +github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -104,10 +113,10 @@ github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo= github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8= -github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= +github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= +github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= @@ -135,8 +144,12 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= +github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= +github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE= +github.com/microsoft/go-mssqldb v0.17.0/go.mod h1:OkoNGhGEs8EZqchVTtochlXruEhEOaO4S0d2sB5aeGQ= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= @@ -319,8 +332,17 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco= +gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04= +gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs= +gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8= gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= +gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= +gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= +gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0= +gorm.io/driver/sqlserver v1.4.1/go.mod h1:DJ4P+MeZbc5rvY58PnmN1Lnyvb5gw5NPzGshHDnJLig= +gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 8ecb4326..0522bc3c 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -3,7 +3,7 @@ package admin import ( "strconv" - "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -102,7 +102,7 @@ type BulkUpdateAccountsRequest struct { // AccountWithConcurrency extends Account with real-time concurrency info type AccountWithConcurrency struct { - *model.Account + *dto.Account CurrentConcurrency int `json:"current_concurrency"` } @@ -137,7 +137,7 @@ func (h *AccountHandler) List(c *gin.Context) { result := make([]AccountWithConcurrency, len(accounts)) for i := range accounts { result[i] = AccountWithConcurrency{ - Account: &accounts[i], + Account: dto.AccountFromService(&accounts[i]), CurrentConcurrency: concurrencyCounts[accounts[i].ID], } } @@ -160,7 +160,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) { return } - response.Success(c, account) + response.Success(c, dto.AccountFromService(account)) } // Create handles creating a new account @@ -188,7 +188,7 @@ func (h *AccountHandler) Create(c *gin.Context) { return } - response.Success(c, account) + response.Success(c, dto.AccountFromService(account)) } // Update handles updating an account @@ -222,7 +222,7 @@ func (h *AccountHandler) Update(c *gin.Context) { return } - response.Success(c, account) + response.Success(c, dto.AccountFromService(account)) } // Delete handles deleting an account @@ -425,7 +425,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) { return } - response.Success(c, account) + response.Success(c, dto.AccountFromService(account)) } // BatchCreate handles batch creating accounts @@ -801,7 +801,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) { return } - response.Success(c, account) + response.Success(c, dto.AccountFromService(account)) } // GetAvailableModels handles getting available models for an account diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 3fdba7c5..a7dc6c4e 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -1,11 +1,12 @@ package admin import ( + "strconv" + "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/service" - "strconv" - "time" "github.com/gin-gonic/gin" ) diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 26d0715f..968d5db2 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -3,7 +3,7 @@ package admin import ( "strconv" - "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -69,7 +69,11 @@ func (h *GroupHandler) List(c *gin.Context) { return } - response.Paginated(c, groups, total, page, pageSize) + outGroups := make([]dto.Group, 0, len(groups)) + for i := range groups { + outGroups = append(outGroups, *dto.GroupFromService(&groups[i])) + } + response.Paginated(c, outGroups, total, page, pageSize) } // GetAll handles getting all active groups without pagination @@ -77,7 +81,7 @@ func (h *GroupHandler) List(c *gin.Context) { func (h *GroupHandler) GetAll(c *gin.Context) { platform := c.Query("platform") - var groups []model.Group + var groups []service.Group var err error if platform != "" { @@ -91,7 +95,11 @@ func (h *GroupHandler) GetAll(c *gin.Context) { return } - response.Success(c, groups) + outGroups := make([]dto.Group, 0, len(groups)) + for i := range groups { + outGroups = append(outGroups, *dto.GroupFromService(&groups[i])) + } + response.Success(c, outGroups) } // GetByID handles getting a group by ID @@ -109,7 +117,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) { return } - response.Success(c, group) + response.Success(c, dto.GroupFromService(group)) } // Create handles creating a new group @@ -137,7 +145,7 @@ func (h *GroupHandler) Create(c *gin.Context) { return } - response.Success(c, group) + response.Success(c, dto.GroupFromService(group)) } // Update handles updating a group @@ -172,7 +180,7 @@ func (h *GroupHandler) Update(c *gin.Context) { return } - response.Success(c, group) + response.Success(c, dto.GroupFromService(group)) } // Delete handles deleting a group @@ -229,5 +237,9 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) { return } - response.Paginated(c, keys, total, page, pageSize) + outKeys := make([]dto.ApiKey, 0, len(keys)) + for i := range keys { + outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i])) + } + response.Paginated(c, outKeys, total, page, pageSize) } diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 3cdd6a9d..14b569de 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1,7 +1,7 @@ package admin import ( - "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -31,7 +31,28 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { return } - response.Success(c, settings) + response.Success(c, dto.SystemSettings{ + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + SmtpHost: settings.SmtpHost, + SmtpPort: settings.SmtpPort, + SmtpUsername: settings.SmtpUsername, + SmtpPassword: settings.SmtpPassword, + SmtpFrom: settings.SmtpFrom, + SmtpFromName: settings.SmtpFromName, + SmtpUseTLS: settings.SmtpUseTLS, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + TurnstileSecretKey: settings.TurnstileSecretKey, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + ApiBaseUrl: settings.ApiBaseUrl, + ContactInfo: settings.ContactInfo, + DocUrl: settings.DocUrl, + DefaultConcurrency: settings.DefaultConcurrency, + DefaultBalance: settings.DefaultBalance, + }) } // UpdateSettingsRequest 更新设置请求 @@ -87,7 +108,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { req.SmtpPort = 587 } - settings := &model.SystemSettings{ + settings := &service.SystemSettings{ RegistrationEnabled: req.RegistrationEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled, SmtpHost: req.SmtpHost, @@ -122,7 +143,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { return } - response.Success(c, updatedSettings) + response.Success(c, dto.SystemSettings{ + RegistrationEnabled: updatedSettings.RegistrationEnabled, + EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, + SmtpHost: updatedSettings.SmtpHost, + SmtpPort: updatedSettings.SmtpPort, + SmtpUsername: updatedSettings.SmtpUsername, + SmtpPassword: updatedSettings.SmtpPassword, + SmtpFrom: updatedSettings.SmtpFrom, + SmtpFromName: updatedSettings.SmtpFromName, + SmtpUseTLS: updatedSettings.SmtpUseTLS, + TurnstileEnabled: updatedSettings.TurnstileEnabled, + TurnstileSiteKey: updatedSettings.TurnstileSiteKey, + TurnstileSecretKey: updatedSettings.TurnstileSecretKey, + SiteName: updatedSettings.SiteName, + SiteLogo: updatedSettings.SiteLogo, + SiteSubtitle: updatedSettings.SiteSubtitle, + ApiBaseUrl: updatedSettings.ApiBaseUrl, + ContactInfo: updatedSettings.ContactInfo, + DocUrl: updatedSettings.DocUrl, + DefaultConcurrency: updatedSettings.DefaultConcurrency, + DefaultBalance: updatedSettings.DefaultBalance, + }) } // TestSmtpRequest 测试SMTP连接请求 diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index d101a6e6..c929b75f 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -3,9 +3,10 @@ package admin import ( "strconv" - "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -82,7 +83,11 @@ func (h *SubscriptionHandler) List(c *gin.Context) { return } - response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination)) + out := make([]dto.UserSubscription, 0, len(subscriptions)) + for i := range subscriptions { + out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i])) + } + response.PaginatedWithResult(c, out, toResponsePagination(pagination)) } // GetByID handles getting a subscription by ID @@ -100,7 +105,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) { return } - response.Success(c, subscription) + response.Success(c, dto.UserSubscriptionFromService(subscription)) } // GetProgress handles getting subscription usage progress @@ -145,7 +150,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) { return } - response.Success(c, subscription) + response.Success(c, dto.UserSubscriptionFromService(subscription)) } // BulkAssign handles bulk assigning subscriptions to multiple users @@ -196,7 +201,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) { return } - response.Success(c, subscription) + response.Success(c, dto.UserSubscriptionFromService(subscription)) } // Revoke handles revoking a subscription @@ -234,7 +239,11 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) { return } - response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination)) + out := make([]dto.UserSubscription, 0, len(subscriptions)) + for i := range subscriptions { + out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i])) + } + response.PaginatedWithResult(c, out, toResponsePagination(pagination)) } // ListByUser handles listing subscriptions for a specific user @@ -252,15 +261,18 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) { return } - response.Success(c, subscriptions) + out := make([]dto.UserSubscription, 0, len(subscriptions)) + for i := range subscriptions { + out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i])) + } + response.Success(c, out) } // Helper function to get admin ID from context func getAdminIDFromContext(c *gin.Context) int64 { - if user, exists := c.Get("user"); exists { - if u, ok := user.(*model.User); ok && u != nil { - return u.ID - } + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + return 0 } - return 0 + return subject.UserID } diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index 8592e38b..790f4ac2 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -3,9 +3,10 @@ package handler import ( "strconv" - "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -40,42 +41,34 @@ type UpdateAPIKeyRequest struct { // List handles listing user's API keys with pagination // GET /api/v1/api-keys func (h *APIKeyHandler) List(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } page, pageSize := response.ParsePagination(c) params := pagination.PaginationParams{Page: page, PageSize: pageSize} - keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params) + keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params) if err != nil { response.ErrorFrom(c, err) return } - response.Paginated(c, keys, result.Total, page, pageSize) + out := make([]dto.ApiKey, 0, len(keys)) + for i := range keys { + out = append(out, *dto.ApiKeyFromService(&keys[i])) + } + response.Paginated(c, out, result.Total, page, pageSize) } // GetByID handles getting a single API key // GET /api/v1/api-keys/:id func (h *APIKeyHandler) GetByID(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } @@ -92,26 +85,20 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) { } // 验证所有权 - if key.UserID != user.ID { + if key.UserID != subject.UserID { response.Forbidden(c, "Not authorized to access this key") return } - response.Success(c, key) + response.Success(c, dto.ApiKeyFromService(key)) } // Create handles creating a new API key // POST /api/v1/api-keys func (h *APIKeyHandler) Create(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } @@ -126,27 +113,21 @@ func (h *APIKeyHandler) Create(c *gin.Context) { GroupID: req.GroupID, CustomKey: req.CustomKey, } - key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq) + key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, key) + response.Success(c, dto.ApiKeyFromService(key)) } // Update handles updating an API key // PUT /api/v1/api-keys/:id func (h *APIKeyHandler) Update(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } @@ -171,27 +152,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) { svcReq.Status = &req.Status } - key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq) + key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, key) + response.Success(c, dto.ApiKeyFromService(key)) } // Delete handles deleting an API key // DELETE /api/v1/api-keys/:id func (h *APIKeyHandler) Delete(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } @@ -201,7 +176,7 @@ func (h *APIKeyHandler) Delete(c *gin.Context) { return } - err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID) + err = h.apiKeyService.Delete(c.Request.Context(), keyID, subject.UserID) if err != nil { response.ErrorFrom(c, err) return @@ -213,23 +188,21 @@ func (h *APIKeyHandler) Delete(c *gin.Context) { // GetAvailableGroups 获取用户可以绑定的分组列表 // GET /api/v1/groups/available func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { response.Unauthorized(c, "User not authenticated") return } - user, ok := userValue.(*model.User) - if !ok { - response.InternalError(c, "Invalid user context") - return - } - - groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID) + groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, groups) + out := make([]dto.Group, 0, len(groups)) + for i := range groups { + out = append(out, *dto.GroupFromService(&groups[i])) + } + response.Success(c, out) } diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index efb7584d..799d63d8 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -1,8 +1,9 @@ package handler import ( - "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -11,12 +12,14 @@ import ( // AuthHandler handles authentication-related requests type AuthHandler struct { authService *service.AuthService + userService *service.UserService } // NewAuthHandler creates a new AuthHandler -func NewAuthHandler(authService *service.AuthService) *AuthHandler { +func NewAuthHandler(authService *service.AuthService, userService *service.UserService) *AuthHandler { return &AuthHandler{ authService: authService, + userService: userService, } } @@ -49,9 +52,9 @@ type LoginRequest struct { // AuthResponse 认证响应格式(匹配前端期望) type AuthResponse struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - User *model.User `json:"user"` + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + User *dto.User `json:"user"` } // Register handles user registration @@ -80,7 +83,7 @@ func (h *AuthHandler) Register(c *gin.Context) { response.Success(c, AuthResponse{ AccessToken: token, TokenType: "Bearer", - User: user, + User: dto.UserFromService(user), }) } @@ -135,24 +138,24 @@ func (h *AuthHandler) Login(c *gin.Context) { response.Success(c, AuthResponse{ AccessToken: token, TokenType: "Bearer", - User: user, + User: dto.UserFromService(user), }) } // GetCurrentUser handles getting current authenticated user // GET /api/v1/auth/me func (h *AuthHandler) GetCurrentUser(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { response.Unauthorized(c, "User not authenticated") return } - user, ok := userValue.(*model.User) - if !ok { - response.InternalError(c, "Invalid user context") + user, err := h.userService.GetByID(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) return } - response.Success(c, user) + response.Success(c, dto.UserFromService(user)) } diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go new file mode 100644 index 00000000..7cdfa203 --- /dev/null +++ b/backend/internal/handler/dto/mappers.go @@ -0,0 +1,294 @@ +package dto + +import "github.com/Wei-Shaw/sub2api/internal/service" + +func UserFromServiceShallow(u *service.User) *User { + if u == nil { + return nil + } + return &User{ + ID: u.ID, + Email: u.Email, + Username: u.Username, + Wechat: u.Wechat, + Notes: u.Notes, + Role: u.Role, + Balance: u.Balance, + Concurrency: u.Concurrency, + Status: u.Status, + AllowedGroups: u.AllowedGroups, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, + } +} + +func UserFromService(u *service.User) *User { + if u == nil { + return nil + } + out := UserFromServiceShallow(u) + if len(u.ApiKeys) > 0 { + out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys)) + for i := range u.ApiKeys { + k := u.ApiKeys[i] + out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k)) + } + } + if len(u.Subscriptions) > 0 { + out.Subscriptions = make([]UserSubscription, 0, len(u.Subscriptions)) + for i := range u.Subscriptions { + s := u.Subscriptions[i] + out.Subscriptions = append(out.Subscriptions, *UserSubscriptionFromService(&s)) + } + } + return out +} + +func ApiKeyFromService(k *service.ApiKey) *ApiKey { + if k == nil { + return nil + } + return &ApiKey{ + ID: k.ID, + UserID: k.UserID, + Key: k.Key, + Name: k.Name, + GroupID: k.GroupID, + Status: k.Status, + CreatedAt: k.CreatedAt, + UpdatedAt: k.UpdatedAt, + User: UserFromServiceShallow(k.User), + Group: GroupFromServiceShallow(k.Group), + } +} + +func GroupFromServiceShallow(g *service.Group) *Group { + if g == nil { + return nil + } + return &Group{ + ID: g.ID, + Name: g.Name, + Description: g.Description, + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUSD, + WeeklyLimitUSD: g.WeeklyLimitUSD, + MonthlyLimitUSD: g.MonthlyLimitUSD, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, + AccountCount: g.AccountCount, + } +} + +func GroupFromService(g *service.Group) *Group { + if g == nil { + return nil + } + out := GroupFromServiceShallow(g) + if len(g.AccountGroups) > 0 { + out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) + for i := range g.AccountGroups { + ag := g.AccountGroups[i] + out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag)) + } + } + return out +} + +func AccountFromServiceShallow(a *service.Account) *Account { + if a == nil { + return nil + } + return &Account{ + ID: a.ID, + Name: a.Name, + Platform: a.Platform, + Type: a.Type, + Credentials: a.Credentials, + Extra: a.Extra, + ProxyID: a.ProxyID, + Concurrency: a.Concurrency, + Priority: a.Priority, + Status: a.Status, + ErrorMessage: a.ErrorMessage, + LastUsedAt: a.LastUsedAt, + CreatedAt: a.CreatedAt, + UpdatedAt: a.UpdatedAt, + Schedulable: a.Schedulable, + RateLimitedAt: a.RateLimitedAt, + RateLimitResetAt: a.RateLimitResetAt, + OverloadUntil: a.OverloadUntil, + SessionWindowStart: a.SessionWindowStart, + SessionWindowEnd: a.SessionWindowEnd, + SessionWindowStatus: a.SessionWindowStatus, + GroupIDs: a.GroupIDs, + } +} + +func AccountFromService(a *service.Account) *Account { + if a == nil { + return nil + } + out := AccountFromServiceShallow(a) + out.Proxy = ProxyFromService(a.Proxy) + if len(a.AccountGroups) > 0 { + out.AccountGroups = make([]AccountGroup, 0, len(a.AccountGroups)) + for i := range a.AccountGroups { + ag := a.AccountGroups[i] + out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag)) + } + } + if len(a.Groups) > 0 { + out.Groups = make([]*Group, 0, len(a.Groups)) + for _, g := range a.Groups { + out.Groups = append(out.Groups, GroupFromServiceShallow(g)) + } + } + return out +} + +func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup { + if ag == nil { + return nil + } + return &AccountGroup{ + AccountID: ag.AccountID, + GroupID: ag.GroupID, + Priority: ag.Priority, + CreatedAt: ag.CreatedAt, + Account: AccountFromServiceShallow(ag.Account), + Group: GroupFromServiceShallow(ag.Group), + } +} + +func ProxyFromService(p *service.Proxy) *Proxy { + if p == nil { + return nil + } + return &Proxy{ + ID: p.ID, + Name: p.Name, + Protocol: p.Protocol, + Host: p.Host, + Port: p.Port, + Username: p.Username, + Password: p.Password, + Status: p.Status, + CreatedAt: p.CreatedAt, + UpdatedAt: p.UpdatedAt, + } +} + +func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWithAccountCount { + if p == nil { + return nil + } + return &ProxyWithAccountCount{ + Proxy: *ProxyFromService(&p.Proxy), + AccountCount: p.AccountCount, + } +} + +func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode { + if rc == nil { + return nil + } + return &RedeemCode{ + ID: rc.ID, + Code: rc.Code, + Type: rc.Type, + Value: rc.Value, + Status: rc.Status, + UsedBy: rc.UsedBy, + UsedAt: rc.UsedAt, + Notes: rc.Notes, + CreatedAt: rc.CreatedAt, + GroupID: rc.GroupID, + ValidityDays: rc.ValidityDays, + User: UserFromServiceShallow(rc.User), + Group: GroupFromServiceShallow(rc.Group), + } +} + +func UsageLogFromService(l *service.UsageLog) *UsageLog { + if l == nil { + return nil + } + return &UsageLog{ + ID: l.ID, + UserID: l.UserID, + ApiKeyID: l.ApiKeyID, + AccountID: l.AccountID, + RequestID: l.RequestID, + Model: l.Model, + GroupID: l.GroupID, + SubscriptionID: l.SubscriptionID, + InputTokens: l.InputTokens, + OutputTokens: l.OutputTokens, + CacheCreationTokens: l.CacheCreationTokens, + CacheReadTokens: l.CacheReadTokens, + CacheCreation5mTokens: l.CacheCreation5mTokens, + CacheCreation1hTokens: l.CacheCreation1hTokens, + InputCost: l.InputCost, + OutputCost: l.OutputCost, + CacheCreationCost: l.CacheCreationCost, + CacheReadCost: l.CacheReadCost, + TotalCost: l.TotalCost, + ActualCost: l.ActualCost, + RateMultiplier: l.RateMultiplier, + BillingType: l.BillingType, + Stream: l.Stream, + DurationMs: l.DurationMs, + FirstTokenMs: l.FirstTokenMs, + CreatedAt: l.CreatedAt, + User: UserFromServiceShallow(l.User), + ApiKey: ApiKeyFromService(l.ApiKey), + Account: AccountFromService(l.Account), + Group: GroupFromServiceShallow(l.Group), + Subscription: UserSubscriptionFromService(l.Subscription), + } +} + +func SettingFromService(s *service.Setting) *Setting { + if s == nil { + return nil + } + return &Setting{ + ID: s.ID, + Key: s.Key, + Value: s.Value, + UpdatedAt: s.UpdatedAt, + } +} + +func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscription { + if sub == nil { + return nil + } + return &UserSubscription{ + ID: sub.ID, + UserID: sub.UserID, + GroupID: sub.GroupID, + StartsAt: sub.StartsAt, + ExpiresAt: sub.ExpiresAt, + Status: sub.Status, + DailyWindowStart: sub.DailyWindowStart, + WeeklyWindowStart: sub.WeeklyWindowStart, + MonthlyWindowStart: sub.MonthlyWindowStart, + DailyUsageUSD: sub.DailyUsageUSD, + WeeklyUsageUSD: sub.WeeklyUsageUSD, + MonthlyUsageUSD: sub.MonthlyUsageUSD, + AssignedBy: sub.AssignedBy, + AssignedAt: sub.AssignedAt, + Notes: sub.Notes, + CreatedAt: sub.CreatedAt, + UpdatedAt: sub.UpdatedAt, + User: UserFromServiceShallow(sub.User), + Group: GroupFromServiceShallow(sub.Group), + AssignedByUser: UserFromServiceShallow(sub.AssignedByUser), + } +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go new file mode 100644 index 00000000..96e59e3f --- /dev/null +++ b/backend/internal/handler/dto/settings.go @@ -0,0 +1,43 @@ +package dto + +// SystemSettings represents the admin settings API response payload. +type SystemSettings struct { + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + + SmtpHost string `json:"smtp_host"` + SmtpPort int `json:"smtp_port"` + SmtpUsername string `json:"smtp_username"` + SmtpPassword string `json:"smtp_password,omitempty"` + SmtpFrom string `json:"smtp_from_email"` + SmtpFromName string `json:"smtp_from_name"` + SmtpUseTLS bool `json:"smtp_use_tls"` + + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key"` + TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` + + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + ApiBaseUrl string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocUrl string `json:"doc_url"` + + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` +} + +type PublicSettings struct { + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + ApiBaseUrl string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocUrl string `json:"doc_url"` + Version string `json:"version"` +} diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go new file mode 100644 index 00000000..2b59e38d --- /dev/null +++ b/backend/internal/handler/dto/types.go @@ -0,0 +1,212 @@ +package dto + +import "time" + +type User struct { + ID int64 `json:"id"` + Email string `json:"email"` + Username string `json:"username"` + Wechat string `json:"wechat"` + Notes string `json:"notes"` + Role string `json:"role"` + Balance float64 `json:"balance"` + Concurrency int `json:"concurrency"` + Status string `json:"status"` + AllowedGroups []int64 `json:"allowed_groups"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + + ApiKeys []ApiKey `json:"api_keys,omitempty"` + Subscriptions []UserSubscription `json:"subscriptions,omitempty"` +} + +type ApiKey struct { + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + Key string `json:"key"` + Name string `json:"name"` + GroupID *int64 `json:"group_id"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + + User *User `json:"user,omitempty"` + Group *Group `json:"group,omitempty"` +} + +type Group struct { + ID int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Platform string `json:"platform"` + RateMultiplier float64 `json:"rate_multiplier"` + IsExclusive bool `json:"is_exclusive"` + Status string `json:"status"` + + SubscriptionType string `json:"subscription_type"` + DailyLimitUSD *float64 `json:"daily_limit_usd"` + WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` + MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` + + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + + AccountGroups []AccountGroup `json:"account_groups,omitempty"` + AccountCount int64 `json:"account_count,omitempty"` +} + +type Account struct { + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` + Type string `json:"type"` + Credentials map[string]any `json:"credentials"` + Extra map[string]any `json:"extra"` + ProxyID *int64 `json:"proxy_id"` + Concurrency int `json:"concurrency"` + Priority int `json:"priority"` + Status string `json:"status"` + ErrorMessage string `json:"error_message"` + LastUsedAt *time.Time `json:"last_used_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + + Schedulable bool `json:"schedulable"` + + RateLimitedAt *time.Time `json:"rate_limited_at"` + RateLimitResetAt *time.Time `json:"rate_limit_reset_at"` + OverloadUntil *time.Time `json:"overload_until"` + + SessionWindowStart *time.Time `json:"session_window_start"` + SessionWindowEnd *time.Time `json:"session_window_end"` + SessionWindowStatus string `json:"session_window_status"` + + Proxy *Proxy `json:"proxy,omitempty"` + AccountGroups []AccountGroup `json:"account_groups,omitempty"` + + GroupIDs []int64 `json:"group_ids,omitempty"` + Groups []*Group `json:"groups,omitempty"` +} + +type AccountGroup struct { + AccountID int64 `json:"account_id"` + GroupID int64 `json:"group_id"` + Priority int `json:"priority"` + CreatedAt time.Time `json:"created_at"` + + Account *Account `json:"account,omitempty"` + Group *Group `json:"group,omitempty"` +} + +type Proxy struct { + ID int64 `json:"id"` + Name string `json:"name"` + Protocol string `json:"protocol"` + Host string `json:"host"` + Port int `json:"port"` + Username string `json:"username"` + Password string `json:"-"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +type ProxyWithAccountCount struct { + Proxy + AccountCount int64 `json:"account_count"` +} + +type RedeemCode struct { + ID int64 `json:"id"` + Code string `json:"code"` + Type string `json:"type"` + Value float64 `json:"value"` + Status string `json:"status"` + UsedBy *int64 `json:"used_by"` + UsedAt *time.Time `json:"used_at"` + Notes string `json:"notes"` + CreatedAt time.Time `json:"created_at"` + + GroupID *int64 `json:"group_id"` + ValidityDays int `json:"validity_days"` + + User *User `json:"user,omitempty"` + Group *Group `json:"group,omitempty"` +} + +type UsageLog struct { + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + ApiKeyID int64 `json:"api_key_id"` + AccountID int64 `json:"account_id"` + RequestID string `json:"request_id"` + Model string `json:"model"` + + GroupID *int64 `json:"group_id"` + SubscriptionID *int64 `json:"subscription_id"` + + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationTokens int `json:"cache_creation_tokens"` + CacheReadTokens int `json:"cache_read_tokens"` + + CacheCreation5mTokens int `json:"cache_creation_5m_tokens"` + CacheCreation1hTokens int `json:"cache_creation_1h_tokens"` + + InputCost float64 `json:"input_cost"` + OutputCost float64 `json:"output_cost"` + CacheCreationCost float64 `json:"cache_creation_cost"` + CacheReadCost float64 `json:"cache_read_cost"` + TotalCost float64 `json:"total_cost"` + ActualCost float64 `json:"actual_cost"` + RateMultiplier float64 `json:"rate_multiplier"` + + BillingType int8 `json:"billing_type"` + Stream bool `json:"stream"` + DurationMs *int `json:"duration_ms"` + FirstTokenMs *int `json:"first_token_ms"` + + CreatedAt time.Time `json:"created_at"` + + User *User `json:"user,omitempty"` + ApiKey *ApiKey `json:"api_key,omitempty"` + Account *Account `json:"account,omitempty"` + Group *Group `json:"group,omitempty"` + Subscription *UserSubscription `json:"subscription,omitempty"` +} + +type Setting struct { + ID int64 `json:"id"` + Key string `json:"key"` + Value string `json:"value"` + UpdatedAt time.Time `json:"updated_at"` +} + +type UserSubscription struct { + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + GroupID int64 `json:"group_id"` + + StartsAt time.Time `json:"starts_at"` + ExpiresAt time.Time `json:"expires_at"` + Status string `json:"status"` + + DailyWindowStart *time.Time `json:"daily_window_start"` + WeeklyWindowStart *time.Time `json:"weekly_window_start"` + MonthlyWindowStart *time.Time `json:"monthly_window_start"` + + DailyUsageUSD float64 `json:"daily_usage_usd"` + WeeklyUsageUSD float64 `json:"weekly_usage_usd"` + MonthlyUsageUSD float64 `json:"monthly_usage_usd"` + + AssignedBy *int64 `json:"assigned_by"` + AssignedAt time.Time `json:"assigned_at"` + Notes string `json:"notes"` + + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + + User *User `json:"user,omitempty"` + Group *Group `json:"group,omitempty"` + AssignedByUser *User `json:"assigned_by_user,omitempty"` +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index bfb8b6fd..b281aef2 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -10,7 +10,6 @@ import ( "strings" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -47,7 +46,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - user, ok := middleware2.GetUserFromContext(c) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") return @@ -82,8 +81,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { subscription, _ := middleware2.GetSubscriptionFromContext(c) // 0. 检查wait队列是否已满 - maxWait := service.CalculateMaxWait(user.Concurrency) - canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait) + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) if err != nil { log.Printf("Increment wait count failed: %v", err) // On error, allow request to proceed @@ -92,10 +91,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } // 确保在函数退出时减少wait计数 - defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID) + defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) // 1. 首先获取用户并发槽位 - userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, req.Stream, &streamStarted) + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted) if err != nil { log.Printf("User concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "user", streamStarted) @@ -106,7 +105,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 2. 【新增】Wait后二次检查余额/订阅 - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { log.Printf("Billing eligibility check failed after wait: %v", err) h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted) return @@ -133,7 +132,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 3. 获取账号并发槽位 - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, req.Stream, &streamStarted) + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted) if err != nil { log.Printf("Account concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "account", streamStarted) @@ -158,7 +157,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, ApiKey: apiKey, - User: user, + User: apiKey.User, Account: account, Subscription: subscription, }); err != nil { @@ -198,7 +197,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) { return } - user, ok := middleware2.GetUserFromContext(c) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") return @@ -223,7 +222,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) { } // 余额模式:返回钱包余额 - latestUser, err := h.userService.GetByID(c.Request.Context(), user.ID) + latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID) if err != nil { h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info") return @@ -241,7 +240,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) { // 逻辑: // 1. 如果日/周/月任一限额达到100%,返回0 // 2. 否则返回所有已配置周期中剩余额度的最小值 -func (h *GatewayHandler) calculateSubscriptionRemaining(group *model.Group, sub *model.UserSubscription) float64 { +func (h *GatewayHandler) calculateSubscriptionRemaining(group *service.Group, sub *service.UserSubscription) float64 { var remainingValues []float64 // 检查日限额 @@ -334,7 +333,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } - user, ok := middleware2.GetUserFromContext(c) + _, ok = middleware2.GetAuthSubjectFromContext(c) if !ok { h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") return @@ -366,7 +365,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 校验 billing eligibility(订阅/余额) // 【注意】不计算并发,但需要校验订阅/余额 - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error()) return } diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 381108e9..5cbe462d 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -6,7 +6,6 @@ import ( "net/http" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -69,11 +68,11 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64 // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // For streaming requests, sends ping events during the wait. // streamStarted is updated if streaming response has begun. -func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) { +func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { ctx := c.Request.Context() // Try to acquire immediately - result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency) + result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency) if err != nil { return nil, err } @@ -83,17 +82,17 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model. } // Need to wait - handle streaming ping if needed - return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted) + return h.waitForSlotWithPing(c, "user", userID, maxConcurrency, isStream, streamStarted) } // AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary. // For streaming requests, sends ping events during the wait. // streamStarted is updated if streaming response has begun. -func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) { +func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { ctx := c.Request.Context() // Try to acquire immediately - result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency) + result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) if err != nil { return nil, err } @@ -103,7 +102,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account * } // Need to wait - handle streaming ping if needed - return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted) + return h.waitForSlotWithPing(c, "account", accountID, maxConcurrency, isStream, streamStarted) } // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index d1956eca..b082d727 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -46,7 +46,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } - user, ok := middleware2.GetUserFromContext(c) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") return @@ -94,8 +94,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { subscription, _ := middleware2.GetSubscriptionFromContext(c) // 0. Check if wait queue is full - maxWait := service.CalculateMaxWait(user.Concurrency) - canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait) + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) if err != nil { log.Printf("Increment wait count failed: %v", err) // On error, allow request to proceed @@ -104,10 +104,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } // Ensure wait count is decremented when function exits - defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID) + defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) // 1. First acquire user concurrency slot - userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, reqStream, &streamStarted) + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) if err != nil { log.Printf("User concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "user", streamStarted) @@ -118,7 +118,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } // 2. Re-check billing eligibility after wait - if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil { + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { log.Printf("Billing eligibility check failed after wait: %v", err) h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted) return @@ -138,7 +138,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) // 3. Acquire account concurrency slot - accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, reqStream, &streamStarted) + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted) if err != nil { log.Printf("Account concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "account", streamStarted) @@ -163,7 +163,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, ApiKey: apiKey, - User: user, + User: apiKey.User, Account: account, Subscription: subscription, }); err != nil { diff --git a/backend/internal/handler/redeem_handler.go b/backend/internal/handler/redeem_handler.go index 765d2e26..1b63f418 100644 --- a/backend/internal/handler/redeem_handler.go +++ b/backend/internal/handler/redeem_handler.go @@ -1,8 +1,9 @@ package handler import ( - "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -37,15 +38,9 @@ type RedeemResponse struct { // Redeem handles redeeming a code // POST /api/v1/redeem func (h *RedeemHandler) Redeem(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } @@ -55,38 +50,36 @@ func (h *RedeemHandler) Redeem(c *gin.Context) { return } - result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code) + result, err := h.redeemService.Redeem(c.Request.Context(), subject.UserID, req.Code) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, result) + response.Success(c, dto.RedeemCodeFromService(result)) } // GetHistory returns the user's redemption history // GET /api/v1/redeem/history func (h *RedeemHandler) GetHistory(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } // Default limit is 25 limit := 25 - codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit) + codes, err := h.redeemService.GetUserHistory(c.Request.Context(), subject.UserID, limit) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, codes) + out := make([]dto.RedeemCode, 0, len(codes)) + for i := range codes { + out = append(out, *dto.RedeemCodeFromService(&codes[i])) + } + response.Success(c, out) } diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index d9804865..90165288 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -1,6 +1,7 @@ package handler import ( + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -30,6 +31,17 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { return } - settings.Version = h.version - response.Success(c, settings) + response.Success(c, dto.PublicSettings{ + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + ApiBaseUrl: settings.ApiBaseUrl, + ContactInfo: settings.ContactInfo, + DocUrl: settings.DocUrl, + Version: h.version, + }) } diff --git a/backend/internal/handler/subscription_handler.go b/backend/internal/handler/subscription_handler.go index fd67e529..b40df833 100644 --- a/backend/internal/handler/subscription_handler.go +++ b/backend/internal/handler/subscription_handler.go @@ -1,8 +1,9 @@ package handler import ( - "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -25,7 +26,7 @@ type SubscriptionSummaryItem struct { // SubscriptionProgressInfo represents subscription with progress info type SubscriptionProgressInfo struct { - Subscription *model.UserSubscription `json:"subscription"` + Subscription *dto.UserSubscription `json:"subscription"` Progress *service.SubscriptionProgress `json:"progress"` } @@ -44,68 +45,58 @@ func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *S // List handles listing current user's subscriptions // GET /api/v1/subscriptions func (h *SubscriptionHandler) List(c *gin.Context) { - user, exists := c.Get("user") - if !exists { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { response.Unauthorized(c, "User not found in context") return } - u, ok := user.(*model.User) - if !ok { - response.InternalError(c, "Invalid user in context") - return - } - - subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID) + subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), subject.UserID) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, subscriptions) + out := make([]dto.UserSubscription, 0, len(subscriptions)) + for i := range subscriptions { + out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i])) + } + response.Success(c, out) } // GetActive handles getting current user's active subscriptions // GET /api/v1/subscriptions/active func (h *SubscriptionHandler) GetActive(c *gin.Context) { - user, exists := c.Get("user") - if !exists { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { response.Unauthorized(c, "User not found in context") return } - u, ok := user.(*model.User) - if !ok { - response.InternalError(c, "Invalid user in context") - return - } - - subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) + subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, subscriptions) + out := make([]dto.UserSubscription, 0, len(subscriptions)) + for i := range subscriptions { + out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i])) + } + response.Success(c, out) } // GetProgress handles getting subscription progress for current user // GET /api/v1/subscriptions/progress func (h *SubscriptionHandler) GetProgress(c *gin.Context) { - user, exists := c.Get("user") - if !exists { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { response.Unauthorized(c, "User not found in context") return } - u, ok := user.(*model.User) - if !ok { - response.InternalError(c, "Invalid user in context") - return - } - // Get all active subscriptions with progress - subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) + subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID) if err != nil { response.ErrorFrom(c, err) return @@ -120,7 +111,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) { continue } result = append(result, SubscriptionProgressInfo{ - Subscription: sub, + Subscription: dto.UserSubscriptionFromService(sub), Progress: progress, }) } @@ -131,20 +122,14 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) { // GetSummary handles getting a summary of current user's subscription status // GET /api/v1/subscriptions/summary func (h *SubscriptionHandler) GetSummary(c *gin.Context) { - user, exists := c.Get("user") - if !exists { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { response.Unauthorized(c, "User not found in context") return } - u, ok := user.(*model.User) - if !ok { - response.InternalError(c, "Invalid user in context") - return - } - // Get all active subscriptions - subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) + subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index d73df209..dd8340e7 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -4,10 +4,11 @@ import ( "strconv" "time" - "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -30,15 +31,9 @@ func NewUsageHandler(usageService *service.UsageService, apiKeyService *service. // List handles listing usage records with pagination // GET /api/v1/usage func (h *UsageHandler) List(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } @@ -58,7 +53,7 @@ func (h *UsageHandler) List(c *gin.Context) { response.ErrorFrom(c, err) return } - if apiKey.UserID != user.ID { + if apiKey.UserID != subject.UserID { response.Forbidden(c, "Not authorized to access this API key's usage records") return } @@ -67,35 +62,33 @@ func (h *UsageHandler) List(c *gin.Context) { } params := pagination.PaginationParams{Page: page, PageSize: pageSize} - var records []model.UsageLog + var records []service.UsageLog var result *pagination.PaginationResult var err error if apiKeyID > 0 { records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params) } else { - records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params) + records, result, err = h.usageService.ListByUser(c.Request.Context(), subject.UserID, params) } if err != nil { response.ErrorFrom(c, err) return } - response.Paginated(c, records, result.Total, page, pageSize) + out := make([]dto.UsageLog, 0, len(records)) + for i := range records { + out = append(out, *dto.UsageLogFromService(&records[i])) + } + response.Paginated(c, out, result.Total, page, pageSize) } // GetByID handles getting a single usage record // GET /api/v1/usage/:id func (h *UsageHandler) GetByID(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } @@ -112,26 +105,20 @@ func (h *UsageHandler) GetByID(c *gin.Context) { } // 验证所有权 - if record.UserID != user.ID { + if record.UserID != subject.UserID { response.Forbidden(c, "Not authorized to access this record") return } - response.Success(c, record) + response.Success(c, dto.UsageLogFromService(record)) } // Stats handles getting usage statistics // GET /api/v1/usage/stats func (h *UsageHandler) Stats(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } @@ -149,7 +136,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { response.NotFound(c, "API key not found") return } - if apiKey.UserID != user.ID { + if apiKey.UserID != subject.UserID { response.Forbidden(c, "Not authorized to access this API key's statistics") return } @@ -201,7 +188,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { if apiKeyID > 0 { stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime) } else { - stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime) + stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime) } if err != nil { response.ErrorFrom(c, err) @@ -245,19 +232,13 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) { // DashboardStats handles getting user dashboard statistics // GET /api/v1/usage/dashboard/stats func (h *UsageHandler) DashboardStats(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { response.Unauthorized(c, "User not authenticated") return } - user, ok := userValue.(*model.User) - if !ok { - response.InternalError(c, "Invalid user context") - return - } - - stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID) + stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID) if err != nil { response.ErrorFrom(c, err) return @@ -269,22 +250,16 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) { // DashboardTrend handles getting user usage trend data // GET /api/v1/usage/dashboard/trend func (h *UsageHandler) DashboardTrend(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } startTime, endTime := parseUserTimeRange(c) granularity := c.DefaultQuery("granularity", "day") - trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity) + trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), subject.UserID, startTime, endTime, granularity) if err != nil { response.ErrorFrom(c, err) return @@ -301,21 +276,15 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) { // DashboardModels handles getting user model usage statistics // GET /api/v1/usage/dashboard/models func (h *UsageHandler) DashboardModels(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } startTime, endTime := parseUserTimeRange(c) - stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime) + stats, err := h.usageService.GetUserModelStats(c.Request.Context(), subject.UserID, startTime, endTime) if err != nil { response.ErrorFrom(c, err) return @@ -336,15 +305,9 @@ type BatchApiKeysUsageRequest struct { // DashboardApiKeysUsage handles getting usage stats for user's own API keys // POST /api/v1/usage/dashboard/api-keys-usage func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } @@ -360,7 +323,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) { } // Verify ownership of all requested API keys - userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000}) + userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, pagination.PaginationParams{Page: 1, PageSize: 1000}) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 4c5498f0..f4639b1f 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -1,8 +1,9 @@ package handler import ( - "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -35,19 +36,13 @@ type UpdateProfileRequest struct { // GetProfile handles getting user profile // GET /api/v1/users/me func (h *UserHandler) GetProfile(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { response.Unauthorized(c, "User not authenticated") return } - user, ok := userValue.(*model.User) - if !ok { - response.InternalError(c, "Invalid user context") - return - } - - userData, err := h.userService.GetByID(c.Request.Context(), user.ID) + userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID) if err != nil { response.ErrorFrom(c, err) return @@ -56,21 +51,15 @@ func (h *UserHandler) GetProfile(c *gin.Context) { // 清空notes字段,普通用户不应看到备注 userData.Notes = "" - response.Success(c, userData) + response.Success(c, dto.UserFromService(userData)) } // ChangePassword handles changing user password // POST /api/v1/users/me/password func (h *UserHandler) ChangePassword(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } @@ -84,7 +73,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { CurrentPassword: req.OldPassword, NewPassword: req.NewPassword, } - err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq) + err := h.userService.ChangePassword(c.Request.Context(), subject.UserID, svcReq) if err != nil { response.ErrorFrom(c, err) return @@ -96,15 +85,9 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { // UpdateProfile handles updating user profile // PUT /api/v1/users/me func (h *UserHandler) UpdateProfile(c *gin.Context) { - userValue, exists := c.Get("user") - if !exists { - response.Unauthorized(c, "User not authenticated") - return - } - - user, ok := userValue.(*model.User) + subject, ok := middleware2.GetAuthSubjectFromContext(c) if !ok { - response.InternalError(c, "Invalid user context") + response.Unauthorized(c, "User not authenticated") return } @@ -118,7 +101,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { Username: req.Username, Wechat: req.Wechat, } - updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq) + updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq) if err != nil { response.ErrorFrom(c, err) return @@ -127,5 +110,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { // 清空notes字段,普通用户不应看到备注 updatedUser.Notes = "" - response.Success(c, updatedUser) + response.Success(c, dto.UserFromService(updatedUser)) } diff --git a/backend/internal/infrastructure/database.go b/backend/internal/infrastructure/database.go index ffaf367b..da40bace 100644 --- a/backend/internal/infrastructure/database.go +++ b/backend/internal/infrastructure/database.go @@ -2,8 +2,8 @@ package infrastructure import ( "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/Wei-Shaw/sub2api/internal/repository" "gorm.io/driver/postgres" "gorm.io/gorm" @@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) { // 自动迁移(始终执行,确保数据库结构与代码同步) // GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的 - if err := model.AutoMigrate(db); err != nil { + if err := repository.AutoMigrate(db); err != nil { return nil, err } diff --git a/backend/internal/model/account.go b/backend/internal/model/account.go deleted file mode 100644 index 9b09b114..00000000 --- a/backend/internal/model/account.go +++ /dev/null @@ -1,415 +0,0 @@ -package model - -import ( - "database/sql/driver" - "encoding/json" - "errors" - "time" - - "gorm.io/gorm" -) - -// JSONB 用于存储JSONB数据 -type JSONB map[string]any - -func (j JSONB) Value() (driver.Value, error) { - if j == nil { - return nil, nil - } - return json.Marshal(j) -} - -func (j *JSONB) Scan(value any) error { - if value == nil { - *j = nil - return nil - } - bytes, ok := value.([]byte) - if !ok { - return errors.New("type assertion to []byte failed") - } - return json.Unmarshal(bytes, j) -} - -type Account struct { - ID int64 `gorm:"primaryKey" json:"id"` - Name string `gorm:"size:100;not null" json:"name"` - Platform string `gorm:"size:50;not null" json:"platform"` // anthropic/openai/gemini - Type string `gorm:"size:20;not null" json:"type"` // oauth/apikey - Credentials JSONB `gorm:"type:jsonb;default:'{}'" json:"credentials"` // 凭证(加密存储) - Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息 - ProxyID *int64 `gorm:"index" json:"proxy_id"` - Concurrency int `gorm:"default:3;not null" json:"concurrency"` - Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高 - Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error - ErrorMessage string `gorm:"type:text" json:"error_message"` - LastUsedAt *time.Time `gorm:"index" json:"last_used_at"` - CreatedAt time.Time `gorm:"not null" json:"created_at"` - UpdatedAt time.Time `gorm:"not null" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` - - // 调度控制 - Schedulable bool `gorm:"default:true;not null" json:"schedulable"` - - // 限流状态 (429) - RateLimitedAt *time.Time `gorm:"index" json:"rate_limited_at"` - RateLimitResetAt *time.Time `gorm:"index" json:"rate_limit_reset_at"` - - // 过载状态 (529) - OverloadUntil *time.Time `gorm:"index" json:"overload_until"` - - // 5小时时间窗口 - SessionWindowStart *time.Time `json:"session_window_start"` - SessionWindowEnd *time.Time `json:"session_window_end"` - SessionWindowStatus string `gorm:"size:20" json:"session_window_status"` // allowed/allowed_warning/rejected - - // 关联 - Proxy *Proxy `gorm:"foreignKey:ProxyID" json:"proxy,omitempty"` - AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"` - - // 虚拟字段 (不存储到数据库) - GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"` - Groups []*Group `gorm:"-" json:"groups,omitempty"` -} - -func (Account) TableName() string { - return "accounts" -} - -// IsActive 检查是否激活 -func (a *Account) IsActive() bool { - return a.Status == "active" -} - -// IsSchedulable 检查账号是否可调度 -func (a *Account) IsSchedulable() bool { - if !a.IsActive() || !a.Schedulable { - return false - } - now := time.Now() - if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) { - return false - } - if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) { - return false - } - return true -} - -// IsRateLimited 检查是否处于限流状态 -func (a *Account) IsRateLimited() bool { - if a.RateLimitResetAt == nil { - return false - } - return time.Now().Before(*a.RateLimitResetAt) -} - -// IsOverloaded 检查是否处于过载状态 -func (a *Account) IsOverloaded() bool { - if a.OverloadUntil == nil { - return false - } - return time.Now().Before(*a.OverloadUntil) -} - -// IsOAuth 检查是否为OAuth类型账号(包括oauth和setup-token) -func (a *Account) IsOAuth() bool { - return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken -} - -// CanGetUsage 检查账号是否可以获取usage信息(只有oauth类型可以,setup-token没有profile权限) -func (a *Account) CanGetUsage() bool { - return a.Type == AccountTypeOAuth -} - -// GetCredential 获取凭证字段 -func (a *Account) GetCredential(key string) string { - if a.Credentials == nil { - return "" - } - if v, ok := a.Credentials[key]; ok { - if s, ok := v.(string); ok { - return s - } - } - return "" -} - -// GetModelMapping 获取模型映射配置 -// 返回格式: map[请求模型名]实际模型名 -func (a *Account) GetModelMapping() map[string]string { - if a.Credentials == nil { - return nil - } - raw, ok := a.Credentials["model_mapping"] - if !ok || raw == nil { - return nil - } - // 处理map[string]interface{}类型 - if m, ok := raw.(map[string]any); ok { - result := make(map[string]string) - for k, v := range m { - if s, ok := v.(string); ok { - result[k] = s - } - } - if len(result) > 0 { - return result - } - } - return nil -} - -// IsModelSupported 检查请求的模型是否被该账号支持 -// 如果没有设置模型映射,则支持所有模型 -func (a *Account) IsModelSupported(requestedModel string) bool { - mapping := a.GetModelMapping() - if len(mapping) == 0 { - return true // 没有映射配置,支持所有模型 - } - _, exists := mapping[requestedModel] - return exists -} - -// GetMappedModel 获取映射后的实际模型名 -// 如果没有映射,返回原始模型名 -func (a *Account) GetMappedModel(requestedModel string) string { - mapping := a.GetModelMapping() - if len(mapping) == 0 { - return requestedModel - } - if mappedModel, exists := mapping[requestedModel]; exists { - return mappedModel - } - return requestedModel -} - -// GetBaseURL 获取API基础URL(用于apikey类型账号) -func (a *Account) GetBaseURL() string { - if a.Type != AccountTypeApiKey { - return "" - } - baseURL := a.GetCredential("base_url") - if baseURL == "" { - return "https://api.anthropic.com" // 默认URL - } - return baseURL -} - -// GetExtraString 从Extra字段获取字符串值 -func (a *Account) GetExtraString(key string) string { - if a.Extra == nil { - return "" - } - if v, ok := a.Extra[key]; ok { - if s, ok := v.(string); ok { - return s - } - } - return "" -} - -// IsCustomErrorCodesEnabled 检查是否启用自定义错误码功能(仅适用于 apikey 类型) -func (a *Account) IsCustomErrorCodesEnabled() bool { - if a.Type != AccountTypeApiKey || a.Credentials == nil { - return false - } - if v, ok := a.Credentials["custom_error_codes_enabled"]; ok { - if enabled, ok := v.(bool); ok { - return enabled - } - } - return false -} - -// GetCustomErrorCodes 获取自定义错误码列表 -func (a *Account) GetCustomErrorCodes() []int { - if a.Credentials == nil { - return nil - } - raw, ok := a.Credentials["custom_error_codes"] - if !ok || raw == nil { - return nil - } - // 处理 []interface{} 类型(JSON反序列化后的格式) - if arr, ok := raw.([]any); ok { - result := make([]int, 0, len(arr)) - for _, v := range arr { - // JSON 数字默认解析为 float64 - if f, ok := v.(float64); ok { - result = append(result, int(f)) - } - } - return result - } - return nil -} - -// ShouldHandleErrorCode 检查指定错误码是否应该被处理(停止调度/标记限流等) -// 如果未启用自定义错误码或列表为空,返回 true(使用默认策略) -// 如果启用且列表非空,只有在列表中的错误码才返回 true -func (a *Account) ShouldHandleErrorCode(statusCode int) bool { - if !a.IsCustomErrorCodesEnabled() { - return true // 未启用,使用默认策略 - } - codes := a.GetCustomErrorCodes() - if len(codes) == 0 { - return true // 启用但列表为空,fallback到默认策略 - } - // 检查是否在自定义列表中 - for _, code := range codes { - if code == statusCode { - return true - } - } - return false -} - -// IsInterceptWarmupEnabled 检查是否启用预热请求拦截 -// 启用后,标题生成、Warmup等预热请求将返回mock响应,不消耗上游token -func (a *Account) IsInterceptWarmupEnabled() bool { - if a.Credentials == nil { - return false - } - if v, ok := a.Credentials["intercept_warmup_requests"]; ok { - if enabled, ok := v.(bool); ok { - return enabled - } - } - return false -} - -// =============== OpenAI 相关方法 =============== - -// IsOpenAI 检查是否为 OpenAI 平台账号 -func (a *Account) IsOpenAI() bool { - return a.Platform == PlatformOpenAI -} - -// IsAnthropic 检查是否为 Anthropic 平台账号 -func (a *Account) IsAnthropic() bool { - return a.Platform == PlatformAnthropic -} - -// IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号 -func (a *Account) IsOpenAIOAuth() bool { - return a.IsOpenAI() && a.Type == AccountTypeOAuth -} - -// IsOpenAIApiKey 检查是否为 OpenAI API Key 类型账号(Response 账号) -func (a *Account) IsOpenAIApiKey() bool { - return a.IsOpenAI() && a.Type == AccountTypeApiKey -} - -// GetOpenAIBaseURL 获取 OpenAI API 基础 URL -// 对于 API Key 类型账号,从 credentials 中获取 base_url -// 对于 OAuth 类型账号,返回默认的 OpenAI API URL -func (a *Account) GetOpenAIBaseURL() string { - if !a.IsOpenAI() { - return "" - } - if a.Type == AccountTypeApiKey { - baseURL := a.GetCredential("base_url") - if baseURL != "" { - return baseURL - } - } - return "https://api.openai.com" // OpenAI 默认 API URL -} - -// GetOpenAIAccessToken 获取 OpenAI 访问令牌 -func (a *Account) GetOpenAIAccessToken() string { - if !a.IsOpenAI() { - return "" - } - return a.GetCredential("access_token") -} - -// GetOpenAIRefreshToken 获取 OpenAI 刷新令牌 -func (a *Account) GetOpenAIRefreshToken() string { - if !a.IsOpenAIOAuth() { - return "" - } - return a.GetCredential("refresh_token") -} - -// GetOpenAIIDToken 获取 OpenAI ID Token(JWT,包含用户信息) -func (a *Account) GetOpenAIIDToken() string { - if !a.IsOpenAIOAuth() { - return "" - } - return a.GetCredential("id_token") -} - -// GetOpenAIApiKey 获取 OpenAI API Key(用于 Response 账号) -func (a *Account) GetOpenAIApiKey() string { - if !a.IsOpenAIApiKey() { - return "" - } - return a.GetCredential("api_key") -} - -// GetOpenAIUserAgent 获取 OpenAI 自定义 User-Agent -// 返回空字符串表示透传原始 User-Agent -func (a *Account) GetOpenAIUserAgent() string { - if !a.IsOpenAI() { - return "" - } - return a.GetCredential("user_agent") -} - -// GetChatGPTAccountID 获取 ChatGPT 账号 ID(从 ID Token 解析) -func (a *Account) GetChatGPTAccountID() string { - if !a.IsOpenAIOAuth() { - return "" - } - return a.GetCredential("chatgpt_account_id") -} - -// GetChatGPTUserID 获取 ChatGPT 用户 ID(从 ID Token 解析) -func (a *Account) GetChatGPTUserID() string { - if !a.IsOpenAIOAuth() { - return "" - } - return a.GetCredential("chatgpt_user_id") -} - -// GetOpenAIOrganizationID 获取 OpenAI 组织 ID -func (a *Account) GetOpenAIOrganizationID() string { - if !a.IsOpenAIOAuth() { - return "" - } - return a.GetCredential("organization_id") -} - -// GetOpenAITokenExpiresAt 获取 OpenAI Token 过期时间 -func (a *Account) GetOpenAITokenExpiresAt() *time.Time { - if !a.IsOpenAIOAuth() { - return nil - } - expiresAtStr := a.GetCredential("expires_at") - if expiresAtStr == "" { - return nil - } - // 尝试解析时间 - t, err := time.Parse(time.RFC3339, expiresAtStr) - if err != nil { - // 尝试解析为 Unix 时间戳 - if v, ok := a.Credentials["expires_at"].(float64); ok { - t = time.Unix(int64(v), 0) - return &t - } - return nil - } - return &t -} - -// IsOpenAITokenExpired 检查 OpenAI Token 是否过期 -func (a *Account) IsOpenAITokenExpired() bool { - expiresAt := a.GetOpenAITokenExpiresAt() - if expiresAt == nil { - return false // 没有过期时间信息,假设未过期 - } - // 提前 60 秒认为过期,便于刷新 - return time.Now().Add(60 * time.Second).After(*expiresAt) -} diff --git a/backend/internal/model/account_group.go b/backend/internal/model/account_group.go deleted file mode 100644 index 9f48b6ce..00000000 --- a/backend/internal/model/account_group.go +++ /dev/null @@ -1,20 +0,0 @@ -package model - -import ( - "time" -) - -type AccountGroup struct { - AccountID int64 `gorm:"primaryKey" json:"account_id"` - GroupID int64 `gorm:"primaryKey" json:"group_id"` - Priority int `gorm:"default:50;not null" json:"priority"` // 分组内优先级 - CreatedAt time.Time `gorm:"not null" json:"created_at"` - - // 关联 - Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"` - Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"` -} - -func (AccountGroup) TableName() string { - return "account_groups" -} diff --git a/backend/internal/model/api_key.go b/backend/internal/model/api_key.go deleted file mode 100644 index 11017081..00000000 --- a/backend/internal/model/api_key.go +++ /dev/null @@ -1,32 +0,0 @@ -package model - -import ( - "time" - - "gorm.io/gorm" -) - -type ApiKey struct { - ID int64 `gorm:"primaryKey" json:"id"` - UserID int64 `gorm:"index;not null" json:"user_id"` - Key string `gorm:"uniqueIndex;size:128;not null" json:"key"` // sk-xxx - Name string `gorm:"size:100;not null" json:"name"` - GroupID *int64 `gorm:"index" json:"group_id"` - Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled - CreatedAt time.Time `gorm:"not null" json:"created_at"` - UpdatedAt time.Time `gorm:"not null" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` - - // 关联 - User *User `gorm:"foreignKey:UserID" json:"user,omitempty"` - Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"` -} - -func (ApiKey) TableName() string { - return "api_keys" -} - -// IsActive 检查是否激活 -func (k *ApiKey) IsActive() bool { - return k.Status == "active" -} diff --git a/backend/internal/model/group.go b/backend/internal/model/group.go deleted file mode 100644 index b1bbe527..00000000 --- a/backend/internal/model/group.go +++ /dev/null @@ -1,73 +0,0 @@ -package model - -import ( - "time" - - "gorm.io/gorm" -) - -// 订阅类型常量 -const ( - SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费) - SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制) -) - -type Group struct { - ID int64 `gorm:"primaryKey" json:"id"` - Name string `gorm:"uniqueIndex;size:100;not null" json:"name"` - Description string `gorm:"type:text" json:"description"` - Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini - RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"` - IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"` - Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled - - // 订阅功能字段 - SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription - DailyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"daily_limit_usd"` - WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"weekly_limit_usd"` - MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"monthly_limit_usd"` - - CreatedAt time.Time `gorm:"not null" json:"created_at"` - UpdatedAt time.Time `gorm:"not null" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` - - // 关联 - AccountGroups []AccountGroup `gorm:"foreignKey:GroupID" json:"account_groups,omitempty"` - - // 虚拟字段 (不存储到数据库) - AccountCount int64 `gorm:"-" json:"account_count,omitempty"` -} - -func (Group) TableName() string { - return "groups" -} - -// IsActive 检查是否激活 -func (g *Group) IsActive() bool { - return g.Status == "active" -} - -// IsSubscriptionType 检查是否为订阅类型分组 -func (g *Group) IsSubscriptionType() bool { - return g.SubscriptionType == SubscriptionTypeSubscription -} - -// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额) -func (g *Group) IsFreeSubscription() bool { - return g.IsSubscriptionType() && g.RateMultiplier == 0 -} - -// HasDailyLimit 检查是否有日限额 -func (g *Group) HasDailyLimit() bool { - return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0 -} - -// HasWeeklyLimit 检查是否有周限额 -func (g *Group) HasWeeklyLimit() bool { - return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0 -} - -// HasMonthlyLimit 检查是否有月限额 -func (g *Group) HasMonthlyLimit() bool { - return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0 -} diff --git a/backend/internal/model/model.go b/backend/internal/model/model.go deleted file mode 100644 index 203f552a..00000000 --- a/backend/internal/model/model.go +++ /dev/null @@ -1,64 +0,0 @@ -package model - -import ( - "gorm.io/gorm" -) - -// AutoMigrate 自动迁移所有模型 -func AutoMigrate(db *gorm.DB) error { - return db.AutoMigrate( - &User{}, - &ApiKey{}, - &Group{}, - &Account{}, - &AccountGroup{}, - &Proxy{}, - &RedeemCode{}, - &UsageLog{}, - &Setting{}, - &UserSubscription{}, - ) -} - -// 状态常量 -const ( - StatusActive = "active" - StatusDisabled = "disabled" - StatusError = "error" - StatusUnused = "unused" - StatusUsed = "used" - StatusExpired = "expired" -) - -// 角色常量 -const ( - RoleAdmin = "admin" - RoleUser = "user" -) - -// 平台常量 -const ( - PlatformAnthropic = "anthropic" - PlatformOpenAI = "openai" - PlatformGemini = "gemini" -) - -// 账号类型常量 -const ( - AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) - AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) - AccountTypeApiKey = "apikey" // API Key类型账号 -) - -// 卡密类型常量 -const ( - RedeemTypeBalance = "balance" - RedeemTypeConcurrency = "concurrency" - RedeemTypeSubscription = "subscription" -) - -// 管理员调整类型常量 -const ( - AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额 - AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数 -) diff --git a/backend/internal/model/proxy.go b/backend/internal/model/proxy.go deleted file mode 100644 index af27dbe6..00000000 --- a/backend/internal/model/proxy.go +++ /dev/null @@ -1,45 +0,0 @@ -package model - -import ( - "fmt" - "time" - - "gorm.io/gorm" -) - -type Proxy struct { - ID int64 `gorm:"primaryKey" json:"id"` - Name string `gorm:"size:100;not null" json:"name"` - Protocol string `gorm:"size:20;not null" json:"protocol"` // http/https/socks5 - Host string `gorm:"size:255;not null" json:"host"` - Port int `gorm:"not null" json:"port"` - Username string `gorm:"size:100" json:"username"` - Password string `gorm:"size:100" json:"-"` - Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled - CreatedAt time.Time `gorm:"not null" json:"created_at"` - UpdatedAt time.Time `gorm:"not null" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` -} - -func (Proxy) TableName() string { - return "proxies" -} - -// IsActive 检查是否激活 -func (p *Proxy) IsActive() bool { - return p.Status == "active" -} - -// URL 返回代理URL -func (p *Proxy) URL() string { - if p.Username != "" && p.Password != "" { - return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port) - } - return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port) -} - -// ProxyWithAccountCount extends Proxy with account count information -type ProxyWithAccountCount struct { - Proxy - AccountCount int64 `json:"account_count"` -} diff --git a/backend/internal/model/redeem_code.go b/backend/internal/model/redeem_code.go deleted file mode 100644 index 6857c410..00000000 --- a/backend/internal/model/redeem_code.go +++ /dev/null @@ -1,50 +0,0 @@ -package model - -import ( - "crypto/rand" - "encoding/hex" - "time" -) - -type RedeemCode struct { - ID int64 `gorm:"primaryKey" json:"id"` - Code string `gorm:"uniqueIndex;size:32;not null" json:"code"` - Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription - Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数 - Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used - UsedBy *int64 `gorm:"index" json:"used_by"` - UsedAt *time.Time `json:"used_at"` - Notes string `gorm:"type:text" json:"notes"` - CreatedAt time.Time `gorm:"not null" json:"created_at"` - - // 订阅类型专用字段 - GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用) - ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用) - - // 关联 - User *User `gorm:"foreignKey:UsedBy" json:"user,omitempty"` - Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"` -} - -func (RedeemCode) TableName() string { - return "redeem_codes" -} - -// IsUsed 检查是否已使用 -func (r *RedeemCode) IsUsed() bool { - return r.Status == "used" -} - -// CanUse 检查是否可以使用 -func (r *RedeemCode) CanUse() bool { - return r.Status == "unused" -} - -// GenerateRedeemCode 生成唯一的兑换码 -func GenerateRedeemCode() (string, error) { - b := make([]byte, 16) - if _, err := rand.Read(b); err != nil { - return "", err - } - return hex.EncodeToString(b), nil -} diff --git a/backend/internal/model/setting.go b/backend/internal/model/setting.go deleted file mode 100644 index ec937e16..00000000 --- a/backend/internal/model/setting.go +++ /dev/null @@ -1,104 +0,0 @@ -package model - -import ( - "time" -) - -// Setting 系统设置模型(Key-Value存储) -type Setting struct { - ID int64 `gorm:"primaryKey" json:"id"` - Key string `gorm:"uniqueIndex;size:100;not null" json:"key"` - Value string `gorm:"type:text;not null" json:"value"` - UpdatedAt time.Time `gorm:"not null" json:"updated_at"` -} - -func (Setting) TableName() string { - return "settings" -} - -// 设置Key常量 -const ( - // 注册设置 - SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 - SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 - - // 邮件服务设置 - SettingKeySmtpHost = "smtp_host" // SMTP服务器地址 - SettingKeySmtpPort = "smtp_port" // SMTP端口 - SettingKeySmtpUsername = "smtp_username" // SMTP用户名 - SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储) - SettingKeySmtpFrom = "smtp_from" // 发件人地址 - SettingKeySmtpFromName = "smtp_from_name" // 发件人名称 - SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS - - // Cloudflare Turnstile 设置 - SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证 - SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key - SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key - - // OEM设置 - SettingKeySiteName = "site_name" // 网站名称 - SettingKeySiteLogo = "site_logo" // 网站Logo (base64) - SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 - SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入) - SettingKeyContactInfo = "contact_info" // 客服联系方式 - SettingKeyDocUrl = "doc_url" // 文档链接 - - // 默认配置 - SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 - SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 - - // 管理员 API Key - SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) -) - -// 管理员 API Key 前缀(与用户 sk- 前缀区分) -const AdminApiKeyPrefix = "admin-" - -// SystemSettings 系统设置结构体(用于API响应) -type SystemSettings struct { - // 注册设置 - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - - // 邮件服务设置 - SmtpHost string `json:"smtp_host"` - SmtpPort int `json:"smtp_port"` - SmtpUsername string `json:"smtp_username"` - SmtpPassword string `json:"smtp_password,omitempty"` // 不返回明文密码 - SmtpFrom string `json:"smtp_from_email"` - SmtpFromName string `json:"smtp_from_name"` - SmtpUseTLS bool `json:"smtp_use_tls"` - - // Cloudflare Turnstile 设置 - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key"` - TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` // 不返回明文密钥 - - // OEM设置 - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - ApiBaseUrl string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocUrl string `json:"doc_url"` - - // 默认配置 - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` -} - -// PublicSettings 公开设置(无需登录即可获取) -type PublicSettings struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - ApiBaseUrl string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocUrl string `json:"doc_url"` - Version string `json:"version"` -} diff --git a/backend/internal/model/usage_log.go b/backend/internal/model/usage_log.go deleted file mode 100644 index b9ca0a77..00000000 --- a/backend/internal/model/usage_log.go +++ /dev/null @@ -1,67 +0,0 @@ -package model - -import ( - "time" -) - -// 消费类型常量 -const ( - BillingTypeBalance int8 = 0 // 钱包余额 - BillingTypeSubscription int8 = 1 // 订阅套餐 -) - -type UsageLog struct { - ID int64 `gorm:"primaryKey" json:"id"` - UserID int64 `gorm:"index;not null" json:"user_id"` - ApiKeyID int64 `gorm:"index;not null" json:"api_key_id"` - AccountID int64 `gorm:"index;not null" json:"account_id"` - RequestID string `gorm:"size:64" json:"request_id"` - Model string `gorm:"size:100;index;not null" json:"model"` - - // 订阅关联(可选) - GroupID *int64 `gorm:"index" json:"group_id"` - SubscriptionID *int64 `gorm:"index" json:"subscription_id"` - - // Token使用量(4类) - InputTokens int `gorm:"default:0;not null" json:"input_tokens"` - OutputTokens int `gorm:"default:0;not null" json:"output_tokens"` - CacheCreationTokens int `gorm:"default:0;not null" json:"cache_creation_tokens"` - CacheReadTokens int `gorm:"default:0;not null" json:"cache_read_tokens"` - - // 详细的缓存创建分类 - CacheCreation5mTokens int `gorm:"default:0;not null" json:"cache_creation_5m_tokens"` - CacheCreation1hTokens int `gorm:"default:0;not null" json:"cache_creation_1h_tokens"` - - // 费用(USD) - InputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"` - OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"` - CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"` - CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"` - TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用 - ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用 - RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率 - - // 元数据 - BillingType int8 `gorm:"type:smallint;default:0;not null" json:"billing_type"` // 0=余额 1=订阅 - Stream bool `gorm:"default:false;not null" json:"stream"` - DurationMs *int `json:"duration_ms"` - FirstTokenMs *int `json:"first_token_ms"` // 首字时间(流式请求) - - CreatedAt time.Time `gorm:"index;not null" json:"created_at"` - - // 关联 - User *User `gorm:"foreignKey:UserID" json:"user,omitempty"` - ApiKey *ApiKey `gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"` - Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"` - Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"` - Subscription *UserSubscription `gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"` -} - -func (UsageLog) TableName() string { - return "usage_logs" -} - -// TotalTokens 总token数 -func (u *UsageLog) TotalTokens() int { - return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens -} diff --git a/backend/internal/model/user.go b/backend/internal/model/user.go deleted file mode 100644 index 2f90d8a1..00000000 --- a/backend/internal/model/user.go +++ /dev/null @@ -1,78 +0,0 @@ -package model - -import ( - "time" - - "github.com/lib/pq" - "golang.org/x/crypto/bcrypt" - "gorm.io/gorm" -) - -type User struct { - ID int64 `gorm:"primaryKey" json:"id"` - Email string `gorm:"uniqueIndex;size:255;not null" json:"email"` - Username string `gorm:"size:100;default:''" json:"username"` - Wechat string `gorm:"size:100;default:''" json:"wechat"` - Notes string `gorm:"type:text;default:''" json:"notes"` - PasswordHash string `gorm:"size:255;not null" json:"-"` - Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user - Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"` - Concurrency int `gorm:"default:5;not null" json:"concurrency"` - Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled - AllowedGroups pq.Int64Array `gorm:"type:bigint[]" json:"allowed_groups"` - CreatedAt time.Time `gorm:"not null" json:"created_at"` - UpdatedAt time.Time `gorm:"not null" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` - - // 关联 - ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"` - Subscriptions []UserSubscription `gorm:"foreignKey:UserID" json:"subscriptions,omitempty"` -} - -func (User) TableName() string { - return "users" -} - -// IsAdmin 检查是否管理员 -func (u *User) IsAdmin() bool { - return u.Role == "admin" -} - -// IsActive 检查是否激活 -func (u *User) IsActive() bool { - return u.Status == "active" -} - -// CanBindGroup 检查是否可以绑定指定分组 -// 对于标准类型分组: -// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组 -// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组 -func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool { - // 如果设置了 allowed_groups 且不为空,只能绑定指定的分组 - if len(u.AllowedGroups) > 0 { - for _, id := range u.AllowedGroups { - if id == groupID { - return true - } - } - return false - } - // 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组 - return !isExclusive -} - -// SetPassword 设置密码(哈希存储) -func (u *User) SetPassword(password string) error { - hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return err - } - u.PasswordHash = string(hash) - return nil -} - -// CheckPassword 验证密码 -func (u *User) CheckPassword(password string) bool { - err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) - return err == nil -} diff --git a/backend/internal/model/user_subscription.go b/backend/internal/model/user_subscription.go deleted file mode 100644 index 2bdcd1b5..00000000 --- a/backend/internal/model/user_subscription.go +++ /dev/null @@ -1,157 +0,0 @@ -package model - -import ( - "time" -) - -// 订阅状态常量 -const ( - SubscriptionStatusActive = "active" - SubscriptionStatusExpired = "expired" - SubscriptionStatusSuspended = "suspended" -) - -// UserSubscription 用户订阅模型 -type UserSubscription struct { - ID int64 `gorm:"primaryKey" json:"id"` - UserID int64 `gorm:"index;not null" json:"user_id"` - GroupID int64 `gorm:"index;not null" json:"group_id"` - - // 订阅有效期 - StartsAt time.Time `gorm:"not null" json:"starts_at"` - ExpiresAt time.Time `gorm:"not null" json:"expires_at"` - Status string `gorm:"size:20;default:active;not null" json:"status"` // active/expired/suspended - - // 滑动窗口起始时间(nil = 未激活) - DailyWindowStart *time.Time `json:"daily_window_start"` - WeeklyWindowStart *time.Time `json:"weekly_window_start"` - MonthlyWindowStart *time.Time `json:"monthly_window_start"` - - // 当前窗口已用额度(USD,基于 total_cost 计算) - DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"daily_usage_usd"` - WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"weekly_usage_usd"` - MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"monthly_usage_usd"` - - // 管理员分配信息 - AssignedBy *int64 `gorm:"index" json:"assigned_by"` - AssignedAt time.Time `gorm:"not null" json:"assigned_at"` - Notes string `gorm:"type:text" json:"notes"` - - CreatedAt time.Time `gorm:"not null" json:"created_at"` - UpdatedAt time.Time `gorm:"not null" json:"updated_at"` - - // 关联 - User *User `gorm:"foreignKey:UserID" json:"user,omitempty"` - Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"` - AssignedByUser *User `gorm:"foreignKey:AssignedBy" json:"assigned_by_user,omitempty"` -} - -func (UserSubscription) TableName() string { - return "user_subscriptions" -} - -// IsActive 检查订阅是否有效(状态为active且未过期) -func (s *UserSubscription) IsActive() bool { - return s.Status == SubscriptionStatusActive && time.Now().Before(s.ExpiresAt) -} - -// IsExpired 检查订阅是否已过期 -func (s *UserSubscription) IsExpired() bool { - return time.Now().After(s.ExpiresAt) -} - -// DaysRemaining 返回订阅剩余天数 -func (s *UserSubscription) DaysRemaining() int { - if s.IsExpired() { - return 0 - } - return int(time.Until(s.ExpiresAt).Hours() / 24) -} - -// IsWindowActivated 检查窗口是否已激活 -func (s *UserSubscription) IsWindowActivated() bool { - return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil -} - -// NeedsDailyReset 检查日窗口是否需要重置 -func (s *UserSubscription) NeedsDailyReset() bool { - if s.DailyWindowStart == nil { - return false - } - return time.Since(*s.DailyWindowStart) >= 24*time.Hour -} - -// NeedsWeeklyReset 检查周窗口是否需要重置 -func (s *UserSubscription) NeedsWeeklyReset() bool { - if s.WeeklyWindowStart == nil { - return false - } - return time.Since(*s.WeeklyWindowStart) >= 7*24*time.Hour -} - -// NeedsMonthlyReset 检查月窗口是否需要重置 -func (s *UserSubscription) NeedsMonthlyReset() bool { - if s.MonthlyWindowStart == nil { - return false - } - return time.Since(*s.MonthlyWindowStart) >= 30*24*time.Hour -} - -// DailyResetTime 返回日窗口重置时间 -func (s *UserSubscription) DailyResetTime() *time.Time { - if s.DailyWindowStart == nil { - return nil - } - t := s.DailyWindowStart.Add(24 * time.Hour) - return &t -} - -// WeeklyResetTime 返回周窗口重置时间 -func (s *UserSubscription) WeeklyResetTime() *time.Time { - if s.WeeklyWindowStart == nil { - return nil - } - t := s.WeeklyWindowStart.Add(7 * 24 * time.Hour) - return &t -} - -// MonthlyResetTime 返回月窗口重置时间 -func (s *UserSubscription) MonthlyResetTime() *time.Time { - if s.MonthlyWindowStart == nil { - return nil - } - t := s.MonthlyWindowStart.Add(30 * 24 * time.Hour) - return &t -} - -// CheckDailyLimit 检查是否超出日限额 -func (s *UserSubscription) CheckDailyLimit(group *Group, additionalCost float64) bool { - if !group.HasDailyLimit() { - return true // 无限制 - } - return s.DailyUsageUSD+additionalCost <= *group.DailyLimitUSD -} - -// CheckWeeklyLimit 检查是否超出周限额 -func (s *UserSubscription) CheckWeeklyLimit(group *Group, additionalCost float64) bool { - if !group.HasWeeklyLimit() { - return true // 无限制 - } - return s.WeeklyUsageUSD+additionalCost <= *group.WeeklyLimitUSD -} - -// CheckMonthlyLimit 检查是否超出月限额 -func (s *UserSubscription) CheckMonthlyLimit(group *Group, additionalCost float64) bool { - if !group.HasMonthlyLimit() { - return true // 无限制 - } - return s.MonthlyUsageUSD+additionalCost <= *group.MonthlyLimitUSD -} - -// CheckAllLimits 检查所有限额 -func (s *UserSubscription) CheckAllLimits(group *Group, additionalCost float64) (daily, weekly, monthly bool) { - daily = s.CheckDailyLimit(group, additionalCost) - weekly = s.CheckWeeklyLimit(group, additionalCost) - monthly = s.CheckMonthlyLimit(group, additionalCost) - return -} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index d07bc741..6074d3f5 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -5,10 +5,10 @@ import ( "errors" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" + "gorm.io/datatypes" "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -21,69 +21,66 @@ func NewAccountRepository(db *gorm.DB) service.AccountRepository { return &accountRepository{db: db} } -func (r *accountRepository) Create(ctx context.Context, account *model.Account) error { - return r.db.WithContext(ctx).Create(account).Error +func (r *accountRepository) Create(ctx context.Context, account *service.Account) error { + m := accountModelFromService(account) + err := r.db.WithContext(ctx).Create(m).Error + if err == nil { + applyAccountModelToService(account, m) + } + return err } -func (r *accountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) { - var account model.Account - err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error +func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) { + var m accountModel + err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&m, id).Error if err != nil { return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil) } - // 填充 GroupIDs 和 Groups 虚拟字段 - account.GroupIDs = make([]int64, 0, len(account.AccountGroups)) - account.Groups = make([]*model.Group, 0, len(account.AccountGroups)) - for _, ag := range account.AccountGroups { - account.GroupIDs = append(account.GroupIDs, ag.GroupID) - if ag.Group != nil { - account.Groups = append(account.Groups, ag.Group) - } - } - return &account, nil + return accountModelToService(&m), nil } -func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) { +func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) { if crsAccountID == "" { return nil, nil } - var account model.Account - err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error + var m accountModel + err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&m).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return nil, err } - return &account, nil + return accountModelToService(&m), nil } -func (r *accountRepository) Update(ctx context.Context, account *model.Account) error { - return r.db.WithContext(ctx).Save(account).Error +func (r *accountRepository) Update(ctx context.Context, account *service.Account) error { + m := accountModelFromService(account) + err := r.db.WithContext(ctx).Save(m).Error + if err == nil { + applyAccountModelToService(account, m) + } + return err } func (r *accountRepository) Delete(ctx context.Context, id int64) error { - // 先删除账号与分组的绑定关系 - if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil { + if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&accountGroupModel{}).Error; err != nil { return err } - // 再删除账号 - return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error + return r.db.WithContext(ctx).Delete(&accountModel{}, id).Error } -func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { +func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", "", "") } -// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query -func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) { - var accounts []model.Account +func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { + var accounts []accountModel var total int64 - db := r.db.WithContext(ctx).Model(&model.Account{}) + db := r.db.WithContext(ctx).Model(&accountModel{}) - // Apply filters if platform != "" { db = db.Where("platform = ?", platform) } @@ -106,67 +103,84 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati return nil, nil, err } - // 填充每个 Account 的虚拟字段(GroupIDs 和 Groups) + outAccounts := make([]service.Account, 0, len(accounts)) for i := range accounts { - accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups)) - accounts[i].Groups = make([]*model.Group, 0, len(accounts[i].AccountGroups)) - for _, ag := range accounts[i].AccountGroups { - accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID) - if ag.Group != nil { - accounts[i].Groups = append(accounts[i].Groups, ag.Group) - } - } + outAccounts = append(outAccounts, *accountModelToService(&accounts[i])) } - pages := int(total) / params.Limit() - if int(total)%params.Limit() > 0 { - pages++ - } - - return accounts, &pagination.PaginationResult{ - Total: total, - Page: params.Page, - PageSize: params.Limit(), - Pages: pages, - }, nil + return outAccounts, paginationResultFromTotal(total, params), nil } -func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) { - var accounts []model.Account +func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { + var accounts []accountModel err := r.db.WithContext(ctx). Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). - Where("account_groups.group_id = ? AND accounts.status = ?", groupID, model.StatusActive). + Where("account_groups.group_id = ? AND accounts.status = ?", groupID, service.StatusActive). Preload("Proxy"). Order("account_groups.priority ASC, accounts.priority ASC"). Find(&accounts).Error - return accounts, err + if err != nil { + return nil, err + } + + outAccounts := make([]service.Account, 0, len(accounts)) + for i := range accounts { + outAccounts = append(outAccounts, *accountModelToService(&accounts[i])) + } + return outAccounts, nil } -func (r *accountRepository) ListActive(ctx context.Context) ([]model.Account, error) { - var accounts []model.Account +func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) { + var accounts []accountModel err := r.db.WithContext(ctx). - Where("status = ?", model.StatusActive). + Where("status = ?", service.StatusActive). Preload("Proxy"). Order("priority ASC"). Find(&accounts).Error - return accounts, err + if err != nil { + return nil, err + } + + outAccounts := make([]service.Account, 0, len(accounts)) + for i := range accounts { + outAccounts = append(outAccounts, *accountModelToService(&accounts[i])) + } + return outAccounts, nil +} + +func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + var accounts []accountModel + err := r.db.WithContext(ctx). + Where("platform = ? AND status = ?", platform, service.StatusActive). + Preload("Proxy"). + Order("priority ASC"). + Find(&accounts).Error + if err != nil { + return nil, err + } + + outAccounts := make([]service.Account, 0, len(accounts)) + for i := range accounts { + outAccounts = append(outAccounts, *accountModelToService(&accounts[i])) + } + return outAccounts, nil } func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error { now := time.Now() - return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error + return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Update("last_used_at", now).Error } func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error { - return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). + return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). Updates(map[string]any{ - "status": model.StatusError, + "status": service.StatusError, "error_message": errorMsg, }).Error } func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { - ag := &model.AccountGroup{ + ag := &accountGroupModel{ AccountID: accountID, GroupID: groupID, Priority: priority, @@ -176,131 +190,148 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error { return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID). - Delete(&model.AccountGroup{}).Error + Delete(&accountGroupModel{}).Error } -func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) { - var groups []model.Group +func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) { + var groups []groupModel err := r.db.WithContext(ctx). Joins("JOIN account_groups ON account_groups.group_id = groups.id"). Where("account_groups.account_id = ?", accountID). Find(&groups).Error - return groups, err -} + if err != nil { + return nil, err + } -func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) { - var accounts []model.Account - err := r.db.WithContext(ctx). - Where("platform = ? AND status = ?", platform, model.StatusActive). - Preload("Proxy"). - Order("priority ASC"). - Find(&accounts).Error - return accounts, err + outGroups := make([]service.Group, 0, len(groups)) + for i := range groups { + outGroups = append(outGroups, *groupModelToService(&groups[i])) + } + return outGroups, nil } func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { - // 删除现有绑定 - if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil { + if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&accountGroupModel{}).Error; err != nil { return err } - // 添加新绑定 - if len(groupIDs) > 0 { - accountGroups := make([]model.AccountGroup, 0, len(groupIDs)) - for i, groupID := range groupIDs { - accountGroups = append(accountGroups, model.AccountGroup{ - AccountID: accountID, - GroupID: groupID, - Priority: i + 1, // 使用索引作为优先级 - }) - } - return r.db.WithContext(ctx).Create(&accountGroups).Error + if len(groupIDs) == 0 { + return nil } - return nil + accountGroups := make([]accountGroupModel, 0, len(groupIDs)) + for i, groupID := range groupIDs { + accountGroups = append(accountGroups, accountGroupModel{ + AccountID: accountID, + GroupID: groupID, + Priority: i + 1, + }) + } + return r.db.WithContext(ctx).Create(&accountGroups).Error } -// ListSchedulable 获取所有可调度的账号 -func (r *accountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) { - var accounts []model.Account +func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) { + var accounts []accountModel now := time.Now() err := r.db.WithContext(ctx). - Where("status = ? AND schedulable = ?", model.StatusActive, true). + Where("status = ? AND schedulable = ?", service.StatusActive, true). Where("(overload_until IS NULL OR overload_until <= ?)", now). Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now). Preload("Proxy"). Order("priority ASC"). Find(&accounts).Error - return accounts, err + if err != nil { + return nil, err + } + outAccounts := make([]service.Account, 0, len(accounts)) + for i := range accounts { + outAccounts = append(outAccounts, *accountModelToService(&accounts[i])) + } + return outAccounts, nil } -// ListSchedulableByGroupID 按组获取可调度的账号 -func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) { - var accounts []model.Account +func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) { + var accounts []accountModel now := time.Now() err := r.db.WithContext(ctx). Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). Where("account_groups.group_id = ?", groupID). - Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true). + Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true). Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now). Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now). Preload("Proxy"). Order("account_groups.priority ASC, accounts.priority ASC"). Find(&accounts).Error - return accounts, err + if err != nil { + return nil, err + } + outAccounts := make([]service.Account, 0, len(accounts)) + for i := range accounts { + outAccounts = append(outAccounts, *accountModelToService(&accounts[i])) + } + return outAccounts, nil } -// ListSchedulableByPlatform 按平台获取可调度的账号 -func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) { - var accounts []model.Account +func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + var accounts []accountModel now := time.Now() err := r.db.WithContext(ctx). Where("platform = ?", platform). - Where("status = ? AND schedulable = ?", model.StatusActive, true). + Where("status = ? AND schedulable = ?", service.StatusActive, true). Where("(overload_until IS NULL OR overload_until <= ?)", now). Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now). Preload("Proxy"). Order("priority ASC"). Find(&accounts).Error - return accounts, err + if err != nil { + return nil, err + } + outAccounts := make([]service.Account, 0, len(accounts)) + for i := range accounts { + outAccounts = append(outAccounts, *accountModelToService(&accounts[i])) + } + return outAccounts, nil } -// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号 -func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) { - var accounts []model.Account +func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) { + var accounts []accountModel now := time.Now() err := r.db.WithContext(ctx). Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). Where("account_groups.group_id = ?", groupID). Where("accounts.platform = ?", platform). - Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true). + Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true). Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now). Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now). Preload("Proxy"). Order("account_groups.priority ASC, accounts.priority ASC"). Find(&accounts).Error - return accounts, err + if err != nil { + return nil, err + } + outAccounts := make([]service.Account, 0, len(accounts)) + for i := range accounts { + outAccounts = append(outAccounts, *accountModelToService(&accounts[i])) + } + return outAccounts, nil } -// SetRateLimited 标记账号为限流状态(429) func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { now := time.Now() - return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). + return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). Updates(map[string]any{ "rate_limited_at": now, "rate_limit_reset_at": resetAt, }).Error } -// SetOverloaded 标记账号为过载状态(529) func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error { - return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). + return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). Update("overload_until", until).Error } -// ClearRateLimit 清除账号的限流状态 func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error { - return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). + return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). Updates(map[string]any{ "rate_limited_at": nil, "rate_limit_reset_at": nil, @@ -308,7 +339,6 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error }).Error } -// UpdateSessionWindow 更新账号的5小时时间窗口信息 func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { updates := map[string]any{ "session_window_status": status, @@ -319,45 +349,35 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s if end != nil { updates["session_window_end"] = end } - return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Updates(updates).Error + return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Updates(updates).Error } -// SetSchedulable 设置账号的调度开关 func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { - return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). + return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). Update("schedulable", schedulable).Error } -// UpdateExtra updates specific fields in account's Extra JSONB field -// It merges the updates into existing Extra data without overwriting other fields func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { if len(updates) == 0 { return nil } - // Get current account to preserve existing Extra data - var account model.Account + var account accountModel if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil { return err } - // Initialize Extra if nil if account.Extra == nil { - account.Extra = make(model.JSONB) + account.Extra = datatypes.JSONMap{} } - - // Merge updates into existing Extra for k, v := range updates { account.Extra[k] = v } - // Save updated Extra - return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). + return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id). Update("extra", account.Extra).Error } -// BulkUpdate updates multiple accounts with the provided fields. -// It merges credentials/extra JSONB fields instead of overwriting them. func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { if len(ids) == 0 { return 0, nil @@ -381,10 +401,10 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates updateMap["status"] = *updates.Status } if len(updates.Credentials) > 0 { - updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials) + updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", datatypes.JSONMap(updates.Credentials)) } if len(updates.Extra) > 0 { - updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra) + updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", datatypes.JSONMap(updates.Extra)) } if len(updateMap) == 0 { @@ -392,10 +412,178 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates } result := r.db.WithContext(ctx). - Model(&model.Account{}). + Model(&accountModel{}). Where("id IN ?", ids). Clauses(clause.Returning{}). Updates(updateMap) return result.RowsAffected, result.Error } + +type accountModel struct { + ID int64 `gorm:"primaryKey"` + Name string `gorm:"size:100;not null"` + Platform string `gorm:"size:50;not null"` + Type string `gorm:"size:20;not null"` + Credentials datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"` + Extra datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"` + ProxyID *int64 `gorm:"index"` + Concurrency int `gorm:"default:3;not null"` + Priority int `gorm:"default:50;not null"` + Status string `gorm:"size:20;default:active;not null"` + ErrorMessage string `gorm:"type:text"` + LastUsedAt *time.Time `gorm:"index"` + CreatedAt time.Time `gorm:"not null"` + UpdatedAt time.Time `gorm:"not null"` + DeletedAt gorm.DeletedAt `gorm:"index"` + + Schedulable bool `gorm:"default:true;not null"` + + RateLimitedAt *time.Time `gorm:"index"` + RateLimitResetAt *time.Time `gorm:"index"` + OverloadUntil *time.Time `gorm:"index"` + + SessionWindowStart *time.Time + SessionWindowEnd *time.Time + SessionWindowStatus string `gorm:"size:20"` + + Proxy *proxyModel `gorm:"foreignKey:ProxyID"` + AccountGroups []accountGroupModel `gorm:"foreignKey:AccountID"` +} + +func (accountModel) TableName() string { return "accounts" } + +type accountGroupModel struct { + AccountID int64 `gorm:"primaryKey"` + GroupID int64 `gorm:"primaryKey"` + Priority int `gorm:"default:50;not null"` + CreatedAt time.Time `gorm:"not null"` + + Account *accountModel `gorm:"foreignKey:AccountID"` + Group *groupModel `gorm:"foreignKey:GroupID"` +} + +func (accountGroupModel) TableName() string { return "account_groups" } + +func accountGroupModelToService(m *accountGroupModel) *service.AccountGroup { + if m == nil { + return nil + } + return &service.AccountGroup{ + AccountID: m.AccountID, + GroupID: m.GroupID, + Priority: m.Priority, + CreatedAt: m.CreatedAt, + Account: accountModelToService(m.Account), + Group: groupModelToService(m.Group), + } +} + +func accountModelToService(m *accountModel) *service.Account { + if m == nil { + return nil + } + + var credentials map[string]any + if m.Credentials != nil { + credentials = map[string]any(m.Credentials) + } + + var extra map[string]any + if m.Extra != nil { + extra = map[string]any(m.Extra) + } + + account := &service.Account{ + ID: m.ID, + Name: m.Name, + Platform: m.Platform, + Type: m.Type, + Credentials: credentials, + Extra: extra, + ProxyID: m.ProxyID, + Concurrency: m.Concurrency, + Priority: m.Priority, + Status: m.Status, + ErrorMessage: m.ErrorMessage, + LastUsedAt: m.LastUsedAt, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + Schedulable: m.Schedulable, + RateLimitedAt: m.RateLimitedAt, + RateLimitResetAt: m.RateLimitResetAt, + OverloadUntil: m.OverloadUntil, + SessionWindowStart: m.SessionWindowStart, + SessionWindowEnd: m.SessionWindowEnd, + SessionWindowStatus: m.SessionWindowStatus, + Proxy: proxyModelToService(m.Proxy), + } + + if len(m.AccountGroups) > 0 { + account.AccountGroups = make([]service.AccountGroup, 0, len(m.AccountGroups)) + account.GroupIDs = make([]int64, 0, len(m.AccountGroups)) + account.Groups = make([]*service.Group, 0, len(m.AccountGroups)) + for i := range m.AccountGroups { + ag := accountGroupModelToService(&m.AccountGroups[i]) + if ag == nil { + continue + } + account.AccountGroups = append(account.AccountGroups, *ag) + account.GroupIDs = append(account.GroupIDs, ag.GroupID) + if ag.Group != nil { + account.Groups = append(account.Groups, ag.Group) + } + } + } + + return account +} + +func accountModelFromService(a *service.Account) *accountModel { + if a == nil { + return nil + } + + var credentials datatypes.JSONMap + if a.Credentials != nil { + credentials = datatypes.JSONMap(a.Credentials) + } + + var extra datatypes.JSONMap + if a.Extra != nil { + extra = datatypes.JSONMap(a.Extra) + } + + return &accountModel{ + ID: a.ID, + Name: a.Name, + Platform: a.Platform, + Type: a.Type, + Credentials: credentials, + Extra: extra, + ProxyID: a.ProxyID, + Concurrency: a.Concurrency, + Priority: a.Priority, + Status: a.Status, + ErrorMessage: a.ErrorMessage, + LastUsedAt: a.LastUsedAt, + CreatedAt: a.CreatedAt, + UpdatedAt: a.UpdatedAt, + Schedulable: a.Schedulable, + RateLimitedAt: a.RateLimitedAt, + RateLimitResetAt: a.RateLimitResetAt, + OverloadUntil: a.OverloadUntil, + SessionWindowStart: a.SessionWindowStart, + SessionWindowEnd: a.SessionWindowEnd, + SessionWindowStatus: a.SessionWindowStatus, + } +} + +func applyAccountModelToService(account *service.Account, m *accountModel) { + if account == nil || m == nil { + return + } + account.ID = m.ID + account.CreatedAt = m.CreatedAt + account.UpdatedAt = m.UpdatedAt +} diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index 2f52ac5c..d35ce053 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -7,10 +7,10 @@ import ( "testing" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" + "gorm.io/datatypes" "gorm.io/gorm" ) @@ -34,11 +34,16 @@ func TestAccountRepoSuite(t *testing.T) { // --- Create / GetByID / Update / Delete --- func (s *AccountRepoSuite) TestCreate() { - account := &model.Account{ - Name: "test-create", - Platform: model.PlatformAnthropic, - Type: model.AccountTypeOAuth, - Status: model.StatusActive, + account := &service.Account{ + Name: "test-create", + Platform: service.PlatformAnthropic, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + Credentials: map[string]any{}, + Extra: map[string]any{}, + Concurrency: 3, + Priority: 50, + Schedulable: true, } err := s.repo.Create(s.ctx, account) @@ -56,7 +61,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() { } func (s *AccountRepoSuite) TestUpdate() { - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "original"}) + account := accountModelToService(mustCreateAccount(s.T(), s.db, &accountModel{Name: "original"})) account.Name = "updated" err := s.repo.Update(s.ctx, account) @@ -68,7 +73,7 @@ func (s *AccountRepoSuite) TestUpdate() { } func (s *AccountRepoSuite) TestDelete() { - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "to-delete"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "to-delete"}) err := s.repo.Delete(s.ctx, account.ID) s.Require().NoError(err, "Delete") @@ -78,23 +83,23 @@ func (s *AccountRepoSuite) TestDelete() { } func (s *AccountRepoSuite) TestDelete_WithGroupBindings() { - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"}) mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1) err := s.repo.Delete(s.ctx, account.ID) s.Require().NoError(err, "Delete should cascade remove bindings") var count int64 - s.db.Model(&model.AccountGroup{}).Where("account_id = ?", account.ID).Count(&count) + s.db.Model(&accountGroupModel{}).Where("account_id = ?", account.ID).Count(&count) s.Require().Zero(count, "expected bindings to be removed") } // --- List / ListWithFilters --- func (s *AccountRepoSuite) TestList() { - mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc2"}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc2"}) accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) s.Require().NoError(err, "List") @@ -111,53 +116,53 @@ func (s *AccountRepoSuite) TestListWithFilters() { status string search string wantCount int - validate func(accounts []model.Account) + validate func(accounts []service.Account) }{ { name: "filter_by_platform", setup: func(db *gorm.DB) { - mustCreateAccount(s.T(), db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic}) - mustCreateAccount(s.T(), db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI}) + mustCreateAccount(s.T(), db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic}) + mustCreateAccount(s.T(), db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI}) }, - platform: model.PlatformOpenAI, + platform: service.PlatformOpenAI, wantCount: 1, - validate: func(accounts []model.Account) { - s.Require().Equal(model.PlatformOpenAI, accounts[0].Platform) + validate: func(accounts []service.Account) { + s.Require().Equal(service.PlatformOpenAI, accounts[0].Platform) }, }, { name: "filter_by_type", setup: func(db *gorm.DB) { - mustCreateAccount(s.T(), db, &model.Account{Name: "t1", Type: model.AccountTypeOAuth}) - mustCreateAccount(s.T(), db, &model.Account{Name: "t2", Type: model.AccountTypeApiKey}) + mustCreateAccount(s.T(), db, &accountModel{Name: "t1", Type: service.AccountTypeOAuth}) + mustCreateAccount(s.T(), db, &accountModel{Name: "t2", Type: service.AccountTypeApiKey}) }, - accType: model.AccountTypeApiKey, + accType: service.AccountTypeApiKey, wantCount: 1, - validate: func(accounts []model.Account) { - s.Require().Equal(model.AccountTypeApiKey, accounts[0].Type) + validate: func(accounts []service.Account) { + s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type) }, }, { name: "filter_by_status", setup: func(db *gorm.DB) { - mustCreateAccount(s.T(), db, &model.Account{Name: "s1", Status: model.StatusActive}) - mustCreateAccount(s.T(), db, &model.Account{Name: "s2", Status: model.StatusDisabled}) + mustCreateAccount(s.T(), db, &accountModel{Name: "s1", Status: service.StatusActive}) + mustCreateAccount(s.T(), db, &accountModel{Name: "s2", Status: service.StatusDisabled}) }, - status: model.StatusDisabled, + status: service.StatusDisabled, wantCount: 1, - validate: func(accounts []model.Account) { - s.Require().Equal(model.StatusDisabled, accounts[0].Status) + validate: func(accounts []service.Account) { + s.Require().Equal(service.StatusDisabled, accounts[0].Status) }, }, { name: "filter_by_search", setup: func(db *gorm.DB) { - mustCreateAccount(s.T(), db, &model.Account{Name: "alpha-account"}) - mustCreateAccount(s.T(), db, &model.Account{Name: "beta-account"}) + mustCreateAccount(s.T(), db, &accountModel{Name: "alpha-account"}) + mustCreateAccount(s.T(), db, &accountModel{Name: "beta-account"}) }, search: "alpha", wantCount: 1, - validate: func(accounts []model.Account) { + validate: func(accounts []service.Account) { s.Require().Contains(accounts[0].Name, "alpha") }, }, @@ -185,9 +190,9 @@ func (s *AccountRepoSuite) TestListWithFilters() { // --- ListByGroup / ListActive / ListByPlatform --- func (s *AccountRepoSuite) TestListByGroup() { - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) - acc1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Status: model.StatusActive}) - acc2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Status: model.StatusActive}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"}) + acc1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Status: service.StatusActive}) + acc2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Status: service.StatusActive}) mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2) mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1) @@ -199,8 +204,8 @@ func (s *AccountRepoSuite) TestListByGroup() { } func (s *AccountRepoSuite) TestListActive() { - mustCreateAccount(s.T(), s.db, &model.Account{Name: "active1", Status: model.StatusActive}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "inactive1", Status: model.StatusDisabled}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "active1", Status: service.StatusActive}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "inactive1", Status: service.StatusDisabled}) accounts, err := s.repo.ListActive(s.ctx) s.Require().NoError(err, "ListActive") @@ -209,22 +214,22 @@ func (s *AccountRepoSuite) TestListActive() { } func (s *AccountRepoSuite) TestListByPlatform() { - mustCreateAccount(s.T(), s.db, &model.Account{Name: "p1", Platform: model.PlatformAnthropic, Status: model.StatusActive}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "p2", Platform: model.PlatformOpenAI, Status: model.StatusActive}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive}) - accounts, err := s.repo.ListByPlatform(s.ctx, model.PlatformAnthropic) + accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic) s.Require().NoError(err, "ListByPlatform") s.Require().Len(accounts, 1) - s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform) + s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform) } // --- Preload and VirtualFields --- func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { - proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"}) + proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{ + account := mustCreateAccount(s.T(), s.db, &accountModel{ Name: "acc1", ProxyID: &proxy.ID, }) @@ -252,9 +257,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { // --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups --- func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() { - g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"}) - g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc"}) + g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"}) + g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc"}) s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup") groups, err := s.repo.GetGroups(s.ctx, account.ID) @@ -274,8 +279,8 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() { } func (s *AccountRepoSuite) TestBindGroups_EmptyList() { - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-empty"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-empty"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"}) mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1) s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty") @@ -289,13 +294,13 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() { func (s *AccountRepoSuite) TestListSchedulable() { now := time.Now() - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"}) - okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true}) + okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true}) mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1) future := now.Add(10 * time.Minute) - overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future}) + overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future}) mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1) sched, err := s.repo.ListSchedulable(s.ctx) @@ -307,16 +312,16 @@ func (s *AccountRepoSuite) TestListSchedulable() { func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() { now := time.Now() - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"}) - okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true}) + okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true}) mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1) future := now.Add(10 * time.Minute) - overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future}) + overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future}) mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1) - rateLimited := mustCreateAccount(s.T(), s.db, &model.Account{Name: "rl", Schedulable: true}) + rateLimited := mustCreateAccount(s.T(), s.db, &accountModel{Name: "rl", Schedulable: true}) mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1) s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited") @@ -334,30 +339,30 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu } func (s *AccountRepoSuite) TestListSchedulableByPlatform() { - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true}) - accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, model.PlatformAnthropic) + accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic) s.Require().NoError(err) s.Require().Len(accounts, 1) - s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform) + s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform) } func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() { - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sp"}) - a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true}) - a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sp"}) + a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true}) + a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true}) mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2) - accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, model.PlatformAnthropic) + accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic) s.Require().NoError(err) s.Require().Len(accounts, 1) s.Require().Equal(a1.ID, accounts[0].ID) } func (s *AccountRepoSuite) TestSetSchedulable() { - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-sched", Schedulable: true}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-sched", Schedulable: true}) s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false)) @@ -369,7 +374,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() { // --- SetOverloaded / SetRateLimited / ClearRateLimit --- func (s *AccountRepoSuite) TestSetOverloaded() { - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-over"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-over"}) until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC) s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until)) @@ -381,7 +386,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() { } func (s *AccountRepoSuite) TestSetRateLimited() { - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-rl"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-rl"}) resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC) s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt)) @@ -394,7 +399,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() { } func (s *AccountRepoSuite) TestClearRateLimit() { - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-clear"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-clear"}) until := time.Now().Add(1 * time.Hour) s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until)) s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until)) @@ -411,7 +416,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() { // --- UpdateLastUsed --- func (s *AccountRepoSuite) TestUpdateLastUsed() { - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-used"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-used"}) s.Require().Nil(account.LastUsedAt) s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID)) @@ -424,20 +429,20 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() { // --- SetError --- func (s *AccountRepoSuite) TestSetError() { - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-err", Status: model.StatusActive}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-err", Status: service.StatusActive}) s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong")) got, err := s.repo.GetByID(s.ctx, account.ID) s.Require().NoError(err) - s.Require().Equal(model.StatusError, got.Status) + s.Require().Equal(service.StatusError, got.Status) s.Require().Equal("something went wrong", got.ErrorMessage) } // --- UpdateSessionWindow --- func (s *AccountRepoSuite) TestUpdateSessionWindow() { - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-win"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-win"}) start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC) end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC) @@ -453,9 +458,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() { // --- UpdateExtra --- func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() { - account := mustCreateAccount(s.T(), s.db, &model.Account{ + account := mustCreateAccount(s.T(), s.db, &accountModel{ Name: "acc-extra", - Extra: model.JSONB{"a": "1"}, + Extra: datatypes.JSONMap{"a": "1"}, }) s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra") @@ -466,12 +471,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() { } func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() { - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-extra-empty"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-extra-empty"}) s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{})) } func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-nil-extra", Extra: nil}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-nil-extra", Extra: nil}) s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"})) got, err := s.repo.GetByID(s.ctx, account.ID) @@ -483,9 +488,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() { func (s *AccountRepoSuite) TestGetByCRSAccountID() { crsID := "crs-12345" - mustCreateAccount(s.T(), s.db, &model.Account{ + mustCreateAccount(s.T(), s.db, &accountModel{ Name: "acc-crs", - Extra: model.JSONB{"crs_account_id": crsID}, + Extra: datatypes.JSONMap{"crs_account_id": crsID}, }) got, err := s.repo.GetByCRSAccountID(s.ctx, crsID) @@ -509,8 +514,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() { // --- BulkUpdate --- func (s *AccountRepoSuite) TestBulkUpdate() { - a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk1", Priority: 1}) - a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1}) + a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk1", Priority: 1}) + a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk2", Priority: 1}) newPriority := 99 affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{ @@ -526,13 +531,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() { } func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() { - a1 := mustCreateAccount(s.T(), s.db, &model.Account{ + a1 := mustCreateAccount(s.T(), s.db, &accountModel{ Name: "bulk-cred", - Credentials: model.JSONB{"existing": "value"}, + Credentials: datatypes.JSONMap{"existing": "value"}, }) _, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{ - Credentials: model.JSONB{"new_key": "new_value"}, + Credentials: datatypes.JSONMap{"new_key": "new_value"}, }) s.Require().NoError(err) @@ -542,13 +547,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() { } func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() { - a1 := mustCreateAccount(s.T(), s.db, &model.Account{ + a1 := mustCreateAccount(s.T(), s.db, &accountModel{ Name: "bulk-extra", - Extra: model.JSONB{"existing": "val"}, + Extra: datatypes.JSONMap{"existing": "val"}, }) _, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{ - Extra: model.JSONB{"new_key": "new_val"}, + Extra: datatypes.JSONMap{"new_key": "new_val"}, }) s.Require().NoError(err) @@ -564,14 +569,14 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() { } func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() { - a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"}) + a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk-empty"}) affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{}) s.Require().NoError(err) s.Require().Zero(affected) } -func idsOfAccounts(accounts []model.Account) []int64 { +func idsOfAccounts(accounts []service.Account) []int64 { out := make([]int64, 0, len(accounts)) for i := range accounts { out = append(out, accounts[i].ID) diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index e89fee75..718bef33 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -2,10 +2,10 @@ package repository import ( "context" + "time" "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "gorm.io/gorm" @@ -19,42 +19,51 @@ func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository { return &apiKeyRepository{db: db} } -func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error { - err := r.db.WithContext(ctx).Create(key).Error +func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error { + m := apiKeyModelFromService(key) + err := r.db.WithContext(ctx).Create(m).Error + if err == nil { + applyApiKeyModelToService(key, m) + } return translatePersistenceError(err, nil, service.ErrApiKeyExists) } -func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { - var key model.ApiKey - err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error +func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { + var m apiKeyModel + err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error if err != nil { return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil) } - return &key, nil + return apiKeyModelToService(&m), nil } -func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) { - var apiKey model.ApiKey - err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error +func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { + var m apiKeyModel + err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error if err != nil { return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil) } - return &apiKey, nil + return apiKeyModelToService(&m), nil } -func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error { - return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error +func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error { + m := apiKeyModelFromService(key) + err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error + if err == nil { + applyApiKeyModelToService(key, m) + } + return err } func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { - return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error + return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error } -func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { - var keys []model.ApiKey +func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { + var keys []apiKeyModel var total int64 - db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID) + db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID) if err := db.Count(&total).Error; err != nil { return nil, nil, err @@ -64,36 +73,31 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param return nil, nil, err } - pages := int(total) / params.Limit() - if int(total)%params.Limit() > 0 { - pages++ + outKeys := make([]service.ApiKey, 0, len(keys)) + for i := range keys { + outKeys = append(outKeys, *apiKeyModelToService(&keys[i])) } - return keys, &pagination.PaginationResult{ - Total: total, - Page: params.Page, - PageSize: params.Limit(), - Pages: pages, - }, nil + return outKeys, paginationResultFromTotal(total, params), nil } func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { var count int64 - err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error + err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error return count, err } func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) { var count int64 - err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error + err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error return count > 0, err } -func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { - var keys []model.ApiKey +func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { + var keys []apiKeyModel var total int64 - db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID) + db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID) if err := db.Count(&total).Error; err != nil { return nil, nil, err @@ -103,24 +107,19 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par return nil, nil, err } - pages := int(total) / params.Limit() - if int(total)%params.Limit() > 0 { - pages++ + outKeys := make([]service.ApiKey, 0, len(keys)) + for i := range keys { + outKeys = append(outKeys, *apiKeyModelToService(&keys[i])) } - return keys, &pagination.PaginationResult{ - Total: total, - Page: params.Page, - PageSize: params.Limit(), - Pages: pages, - }, nil + return outKeys, paginationResultFromTotal(total, params), nil } // SearchApiKeys searches API keys by user ID and/or keyword (name) -func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) { - var keys []model.ApiKey +func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { + var keys []apiKeyModel - db := r.db.WithContext(ctx).Model(&model.ApiKey{}) + db := r.db.WithContext(ctx).Model(&apiKeyModel{}) if userID > 0 { db = db.Where("user_id = ?", userID) @@ -135,12 +134,16 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw return nil, err } - return keys, nil + outKeys := make([]service.ApiKey, 0, len(keys)) + for i := range keys { + outKeys = append(outKeys, *apiKeyModelToService(&keys[i])) + } + return outKeys, nil } // ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { - result := r.db.WithContext(ctx).Model(&model.ApiKey{}). + result := r.db.WithContext(ctx).Model(&apiKeyModel{}). Where("group_id = ?", groupID). Update("group_id", nil) return result.RowsAffected, result.Error @@ -149,6 +152,66 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in // CountByGroupID 获取分组的 API Key 数量 func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { var count int64 - err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error + err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error return count, err } + +type apiKeyModel struct { + ID int64 `gorm:"primaryKey"` + UserID int64 `gorm:"index;not null"` + Key string `gorm:"uniqueIndex;size:128;not null"` + Name string `gorm:"size:100;not null"` + GroupID *int64 `gorm:"index"` + Status string `gorm:"size:20;default:active;not null"` + CreatedAt time.Time `gorm:"not null"` + UpdatedAt time.Time `gorm:"not null"` + DeletedAt gorm.DeletedAt `gorm:"index"` + + User *userModel `gorm:"foreignKey:UserID"` + Group *groupModel `gorm:"foreignKey:GroupID"` +} + +func (apiKeyModel) TableName() string { return "api_keys" } + +func apiKeyModelToService(m *apiKeyModel) *service.ApiKey { + if m == nil { + return nil + } + return &service.ApiKey{ + ID: m.ID, + UserID: m.UserID, + Key: m.Key, + Name: m.Name, + GroupID: m.GroupID, + Status: m.Status, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + User: userModelToService(m.User), + Group: groupModelToService(m.Group), + } +} + +func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel { + if k == nil { + return nil + } + return &apiKeyModel{ + ID: k.ID, + UserID: k.UserID, + Key: k.Key, + Name: k.Name, + GroupID: k.GroupID, + Status: k.Status, + CreatedAt: k.CreatedAt, + UpdatedAt: k.UpdatedAt, + } +} + +func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) { + if key == nil || m == nil { + return + } + key.ID = m.ID + key.CreatedAt = m.CreatedAt + key.UpdatedAt = m.UpdatedAt +} diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 7a599ede..384ee364 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -6,8 +6,8 @@ import ( "context" "testing" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" "gorm.io/gorm" ) @@ -32,13 +32,13 @@ func TestApiKeyRepoSuite(t *testing.T) { // --- Create / GetByID / GetByKey --- func (s *ApiKeyRepoSuite) TestCreate() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"}) - key := &model.ApiKey{ + key := &service.ApiKey{ UserID: user.ID, Key: "sk-create-test", Name: "Test Key", - Status: model.StatusActive, + Status: service.StatusActive, } err := s.repo.Create(s.ctx, key) @@ -56,15 +56,15 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() { } func (s *ApiKeyRepoSuite) TestGetByKey() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbykey@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-key"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbykey@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-key"}) - key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{ UserID: user.ID, Key: "sk-getbykey", Name: "My Key", GroupID: &group.ID, - Status: model.StatusActive, + Status: service.StatusActive, }) got, err := s.repo.GetByKey(s.ctx, key.Key) @@ -84,16 +84,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() { // --- Update --- func (s *ApiKeyRepoSuite) TestUpdate() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"}) - key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"}) + key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{ UserID: user.ID, Key: "sk-update", Name: "Original", - Status: model.StatusActive, - }) + Status: service.StatusActive, + })) key.Name = "Renamed" - key.Status = model.StatusDisabled + key.Status = service.StatusDisabled err := s.repo.Update(s.ctx, key) s.Require().NoError(err, "Update") @@ -102,18 +102,18 @@ func (s *ApiKeyRepoSuite) TestUpdate() { s.Require().Equal("sk-update", got.Key, "Update should not change key") s.Require().Equal(user.ID, got.UserID, "Update should not change user_id") s.Require().Equal("Renamed", got.Name) - s.Require().Equal(model.StatusDisabled, got.Status) + s.Require().Equal(service.StatusDisabled, got.Status) } func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargroup@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear"}) - key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargroup@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear"}) + key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{ UserID: user.ID, Key: "sk-clear-group", Name: "Group Key", GroupID: &group.ID, - }) + })) key.GroupID = nil err := s.repo.Update(s.ctx, key) @@ -127,8 +127,8 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() { // --- Delete --- func (s *ApiKeyRepoSuite) TestDelete() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) - key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"}) + key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{ UserID: user.ID, Key: "sk-delete", Name: "Delete Me", @@ -144,9 +144,9 @@ func (s *ApiKeyRepoSuite) TestDelete() { // --- ListByUserID / CountByUserID --- func (s *ApiKeyRepoSuite) TestListByUserID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"}) keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) s.Require().NoError(err, "ListByUserID") @@ -155,9 +155,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() { } func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "paging@test.com"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "paging@test.com"}) for i := 0; i < 5; i++ { - mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + mustCreateApiKey(s.T(), s.db, &apiKeyModel{ UserID: user.ID, Key: "sk-page-" + string(rune('a'+i)), Name: "Key", @@ -172,9 +172,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() { } func (s *ApiKeyRepoSuite) TestCountByUserID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "count@test.com"}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-1", Name: "K1"}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-2", Name: "K2"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "count@test.com"}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-1", Name: "K1"}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-2", Name: "K2"}) count, err := s.repo.CountByUserID(s.ctx, user.ID) s.Require().NoError(err, "CountByUserID") @@ -184,12 +184,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() { // --- ListByGroupID / CountByGroupID --- func (s *ApiKeyRepoSuite) TestListByGroupID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbygroup@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbygroup@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) s.Require().NoError(err, "ListByGroupID") @@ -200,10 +200,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() { } func (s *ApiKeyRepoSuite) TestCountByGroupID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "countgroup@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "countgroup@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID}) count, err := s.repo.CountByGroupID(s.ctx, group.ID) s.Require().NoError(err, "CountByGroupID") @@ -213,8 +213,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() { // --- ExistsByKey --- func (s *ApiKeyRepoSuite) TestExistsByKey() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-exists", Name: "K"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-exists", Name: "K"}) exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists") s.Require().NoError(err, "ExistsByKey") @@ -228,9 +228,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() { // --- SearchApiKeys --- func (s *ApiKeyRepoSuite) TestSearchApiKeys() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "search@test.com"}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "search@test.com"}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"}) found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10) s.Require().NoError(err, "SearchApiKeys") @@ -239,9 +239,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() { } func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnokw@test.com"}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-1", Name: "K1"}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-2", Name: "K2"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnokw@test.com"}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-1", Name: "K1"}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-2", Name: "K2"}) found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10) s.Require().NoError(err) @@ -249,8 +249,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() { } func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnouid@test.com"}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnouid@test.com"}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"}) found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10) s.Require().NoError(err) @@ -260,12 +260,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() { // --- ClearGroupIDByGroupID --- func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargrp@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear-bulk"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargrp@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear-bulk"}) - k1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID}) - k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group + k1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID}) + k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID) s.Require().NoError(err, "ClearGroupIDByGroupID") @@ -283,16 +283,16 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() { // --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) --- func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "k@example.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-k"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "k@example.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-k"}) - key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{ UserID: user.ID, Key: "sk-test-1", Name: "My Key", GroupID: &group.ID, - Status: model.StatusActive, - }) + Status: service.StatusActive, + })) got, err := s.repo.GetByKey(s.ctx, key.Key) s.Require().NoError(err, "GetByKey") @@ -303,7 +303,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { s.Require().Equal(group.ID, got.Group.ID) key.Name = "Renamed" - key.Status = model.StatusDisabled + key.Status = service.StatusDisabled key.GroupID = nil s.Require().NoError(s.repo.Update(s.ctx, key), "Update") @@ -312,7 +312,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { s.Require().Equal("sk-test-1", got2.Key, "Update should not change key") s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id") s.Require().Equal("Renamed", got2.Name) - s.Require().Equal(model.StatusDisabled, got2.Status) + s.Require().Equal(service.StatusDisabled, got2.Status) s.Require().Nil(got2.GroupID) keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) @@ -330,7 +330,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() { s.Require().Equal(key.ID, found[0].ID) // ClearGroupIDByGroupID - k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{ + k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{ UserID: user.ID, Key: "sk-test-2", Name: "Group Key", diff --git a/backend/internal/repository/auto_migrate.go b/backend/internal/repository/auto_migrate.go new file mode 100644 index 00000000..6ca28036 --- /dev/null +++ b/backend/internal/repository/auto_migrate.go @@ -0,0 +1,20 @@ +package repository + +import "gorm.io/gorm" + +// AutoMigrate runs schema migrations for all repository persistence models. +// Persistence models are defined within individual `*_repo.go` files. +func AutoMigrate(db *gorm.DB) error { + return db.AutoMigrate( + &userModel{}, + &apiKeyModel{}, + &groupModel{}, + &accountModel{}, + &accountGroupModel{}, + &proxyModel{}, + &redeemCodeModel{}, + &usageLogModel{}, + &settingModel{}, + &userSubscriptionModel{}, + ) +} diff --git a/backend/internal/repository/fixtures_integration_test.go b/backend/internal/repository/fixtures_integration_test.go index adeb8ac6..72c5c0d5 100644 --- a/backend/internal/repository/fixtures_integration_test.go +++ b/backend/internal/repository/fixtures_integration_test.go @@ -6,21 +6,25 @@ import ( "testing" "time" - "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" + "gorm.io/datatypes" "gorm.io/gorm" ) -func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User { +func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel { t.Helper() if u.PasswordHash == "" { u.PasswordHash = "test-password-hash" } if u.Role == "" { - u.Role = model.RoleUser + u.Role = service.RoleUser } if u.Status == "" { - u.Status = model.StatusActive + u.Status = service.StatusActive + } + if u.Concurrency == 0 { + u.Concurrency = 5 } if u.CreatedAt.IsZero() { u.CreatedAt = time.Now() @@ -32,16 +36,16 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User { return u } -func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group { +func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel { t.Helper() if g.Platform == "" { - g.Platform = model.PlatformAnthropic + g.Platform = service.PlatformAnthropic } if g.Status == "" { - g.Status = model.StatusActive + g.Status = service.StatusActive } if g.SubscriptionType == "" { - g.SubscriptionType = model.SubscriptionTypeStandard + g.SubscriptionType = service.SubscriptionTypeStandard } if g.CreatedAt.IsZero() { g.CreatedAt = time.Now() @@ -53,7 +57,7 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group { return g } -func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy { +func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel { t.Helper() if p.Protocol == "" { p.Protocol = "http" @@ -65,7 +69,7 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy { p.Port = 8080 } if p.Status == "" { - p.Status = model.StatusActive + p.Status = service.StatusActive } if p.CreatedAt.IsZero() { p.CreatedAt = time.Now() @@ -77,25 +81,25 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy { return p } -func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Account { +func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel { t.Helper() if a.Platform == "" { - a.Platform = model.PlatformAnthropic + a.Platform = service.PlatformAnthropic } if a.Type == "" { - a.Type = model.AccountTypeOAuth + a.Type = service.AccountTypeOAuth } if a.Status == "" { - a.Status = model.StatusActive + a.Status = service.StatusActive } if !a.Schedulable { a.Schedulable = true } if a.Credentials == nil { - a.Credentials = model.JSONB{} + a.Credentials = datatypes.JSONMap{} } if a.Extra == nil { - a.Extra = model.JSONB{} + a.Extra = datatypes.JSONMap{} } if a.CreatedAt.IsZero() { a.CreatedAt = time.Now() @@ -107,10 +111,10 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Accou return a } -func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey { +func mustCreateApiKey(t *testing.T, db *gorm.DB, k *apiKeyModel) *apiKeyModel { t.Helper() if k.Status == "" { - k.Status = model.StatusActive + k.Status = service.StatusActive } if k.CreatedAt.IsZero() { k.CreatedAt = time.Now() @@ -122,13 +126,13 @@ func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey return k } -func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model.RedeemCode { +func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *redeemCodeModel) *redeemCodeModel { t.Helper() if c.Status == "" { - c.Status = model.StatusUnused + c.Status = service.StatusUnused } if c.Type == "" { - c.Type = model.RedeemTypeBalance + c.Type = service.RedeemTypeBalance } if c.CreatedAt.IsZero() { c.CreatedAt = time.Now() @@ -137,10 +141,10 @@ func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model return c } -func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription) *model.UserSubscription { +func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel) *userSubscriptionModel { t.Helper() if s.Status == "" { - s.Status = model.SubscriptionStatusActive + s.Status = service.SubscriptionStatusActive } now := time.Now() if s.StartsAt.IsZero() { @@ -164,9 +168,10 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) { t.Helper() - require.NoError(t, db.Create(&model.AccountGroup{ + require.NoError(t, db.Create(&accountGroupModel{ AccountID: accountID, GroupID: groupID, Priority: priority, + CreatedAt: time.Now(), }).Error, "create account_group") } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index a2cb8e14..688d2655 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -2,10 +2,10 @@ package repository import ( "context" + "time" "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "gorm.io/gorm" @@ -20,38 +20,50 @@ func NewGroupRepository(db *gorm.DB) service.GroupRepository { return &groupRepository{db: db} } -func (r *groupRepository) Create(ctx context.Context, group *model.Group) error { - err := r.db.WithContext(ctx).Create(group).Error +func (r *groupRepository) Create(ctx context.Context, group *service.Group) error { + m := groupModelFromService(group) + err := r.db.WithContext(ctx).Create(m).Error + if err == nil { + applyGroupModelToService(group, m) + } return translatePersistenceError(err, nil, service.ErrGroupExists) } -func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) { - var group model.Group - err := r.db.WithContext(ctx).First(&group, id).Error +func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) { + var m groupModel + err := r.db.WithContext(ctx).First(&m, id).Error if err != nil { return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil) } - return &group, nil + group := groupModelToService(&m) + count, _ := r.GetAccountCount(ctx, group.ID) + group.AccountCount = count + return group, nil } -func (r *groupRepository) Update(ctx context.Context, group *model.Group) error { - return r.db.WithContext(ctx).Save(group).Error +func (r *groupRepository) Update(ctx context.Context, group *service.Group) error { + m := groupModelFromService(group) + err := r.db.WithContext(ctx).Save(m).Error + if err == nil { + applyGroupModelToService(group, m) + } + return err } func (r *groupRepository) Delete(ctx context.Context, id int64) error { - return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error + return r.db.WithContext(ctx).Delete(&groupModel{}, id).Error } -func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { +func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", nil) } // ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive -func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) { - var groups []model.Group +func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { + var groups []groupModel var total int64 - db := r.db.WithContext(ctx).Model(&model.Group{}) + db := r.db.WithContext(ctx).Model(&groupModel{}) // Apply filters if platform != "" { @@ -72,68 +84,71 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination return nil, nil, err } - // 获取每个分组的账号数量 + outGroups := make([]service.Group, 0, len(groups)) for i := range groups { - count, _ := r.GetAccountCount(ctx, groups[i].ID) - groups[i].AccountCount = count + outGroups = append(outGroups, *groupModelToService(&groups[i])) } - pages := int(total) / params.Limit() - if int(total)%params.Limit() > 0 { - pages++ + // 获取每个分组的账号数量 + for i := range outGroups { + count, _ := r.GetAccountCount(ctx, outGroups[i].ID) + outGroups[i].AccountCount = count } - return groups, &pagination.PaginationResult{ - Total: total, - Page: params.Page, - PageSize: params.Limit(), - Pages: pages, - }, nil + return outGroups, paginationResultFromTotal(total, params), nil } -func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) { - var groups []model.Group - err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error +func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) { + var groups []groupModel + err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Order("id ASC").Find(&groups).Error if err != nil { return nil, err } - // 获取每个分组的账号数量 + outGroups := make([]service.Group, 0, len(groups)) for i := range groups { - count, _ := r.GetAccountCount(ctx, groups[i].ID) - groups[i].AccountCount = count + outGroups = append(outGroups, *groupModelToService(&groups[i])) } - return groups, nil + // 获取每个分组的账号数量 + for i := range outGroups { + count, _ := r.GetAccountCount(ctx, outGroups[i].ID) + outGroups[i].AccountCount = count + } + return outGroups, nil } -func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) { - var groups []model.Group - err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error +func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) { + var groups []groupModel + err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", service.StatusActive, platform).Order("id ASC").Find(&groups).Error if err != nil { return nil, err } - // 获取每个分组的账号数量 + outGroups := make([]service.Group, 0, len(groups)) for i := range groups { - count, _ := r.GetAccountCount(ctx, groups[i].ID) - groups[i].AccountCount = count + outGroups = append(outGroups, *groupModelToService(&groups[i])) } - return groups, nil + // 获取每个分组的账号数量 + for i := range outGroups { + count, _ := r.GetAccountCount(ctx, outGroups[i].ID) + outGroups[i].AccountCount = count + } + return outGroups, nil } func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) { var count int64 - err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error + err := r.db.WithContext(ctx).Model(&groupModel{}).Where("name = ?", name).Count(&count).Error return count > 0, err } func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { var count int64 - err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error + err := r.db.WithContext(ctx).Table("account_groups").Where("group_id = ?", groupID).Count(&count).Error return count, err } // DeleteAccountGroupsByGroupID 删除分组与账号的关联关系 func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { - result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{}) + result := r.db.WithContext(ctx).Exec("DELETE FROM account_groups WHERE group_id = ?", groupID) return result.RowsAffected, result.Error } @@ -145,46 +160,42 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, var affectedUserIDs []int64 if group.IsSubscriptionType() { - var subscriptions []model.UserSubscription if err := r.db.WithContext(ctx). - Model(&model.UserSubscription{}). + Table("user_subscriptions"). Where("group_id = ?", id). - Select("user_id"). - Find(&subscriptions).Error; err != nil { + Pluck("user_id", &affectedUserIDs).Error; err != nil { return nil, err } - for _, sub := range subscriptions { - affectedUserIDs = append(affectedUserIDs, sub.UserID) - } } err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { // 1. 删除订阅类型分组的订阅记录 if group.IsSubscriptionType() { - if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil { + if err := tx.Exec("DELETE FROM user_subscriptions WHERE group_id = ?", id).Error; err != nil { return err } } // 2. 将 api_keys 中绑定该分组的 group_id 设为 nil - if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil { + if err := tx.Exec("UPDATE api_keys SET group_id = NULL WHERE group_id = ?", id).Error; err != nil { return err } // 3. 从 users.allowed_groups 数组中移除该分组 ID - if err := tx.Model(&model.User{}). - Where("? = ANY(allowed_groups)", id). - Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil { + if err := tx.Exec( + "UPDATE users SET allowed_groups = array_remove(allowed_groups, ?) WHERE ? = ANY(allowed_groups)", + id, id, + ).Error; err != nil { return err } // 4. 删除 account_groups 中间表的数据 - if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil { + if err := tx.Exec("DELETE FROM account_groups WHERE group_id = ?", id).Error; err != nil { return err } // 5. 删除分组本身(带锁,避免并发写) - if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil { + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&groupModel{}, id).Error; err != nil { return err } @@ -196,3 +207,75 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, return affectedUserIDs, nil } + +type groupModel struct { + ID int64 `gorm:"primaryKey"` + Name string `gorm:"uniqueIndex;size:100;not null"` + Description string `gorm:"type:text"` + Platform string `gorm:"size:50;default:anthropic;not null"` + RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null"` + IsExclusive bool `gorm:"default:false;not null"` + Status string `gorm:"size:20;default:active;not null"` + + SubscriptionType string `gorm:"size:20;default:standard;not null"` + DailyLimitUSD *float64 `gorm:"type:decimal(20,8)"` + WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)"` + MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)"` + + CreatedAt time.Time `gorm:"not null"` + UpdatedAt time.Time `gorm:"not null"` + DeletedAt gorm.DeletedAt `gorm:"index"` +} + +func (groupModel) TableName() string { return "groups" } + +func groupModelToService(m *groupModel) *service.Group { + if m == nil { + return nil + } + return &service.Group{ + ID: m.ID, + Name: m.Name, + Description: m.Description, + Platform: m.Platform, + RateMultiplier: m.RateMultiplier, + IsExclusive: m.IsExclusive, + Status: m.Status, + SubscriptionType: m.SubscriptionType, + DailyLimitUSD: m.DailyLimitUSD, + WeeklyLimitUSD: m.WeeklyLimitUSD, + MonthlyLimitUSD: m.MonthlyLimitUSD, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + } +} + +func groupModelFromService(sg *service.Group) *groupModel { + if sg == nil { + return nil + } + return &groupModel{ + ID: sg.ID, + Name: sg.Name, + Description: sg.Description, + Platform: sg.Platform, + RateMultiplier: sg.RateMultiplier, + IsExclusive: sg.IsExclusive, + Status: sg.Status, + SubscriptionType: sg.SubscriptionType, + DailyLimitUSD: sg.DailyLimitUSD, + WeeklyLimitUSD: sg.WeeklyLimitUSD, + MonthlyLimitUSD: sg.MonthlyLimitUSD, + CreatedAt: sg.CreatedAt, + UpdatedAt: sg.UpdatedAt, + } +} + +func applyGroupModelToService(group *service.Group, m *groupModel) { + if group == nil || m == nil { + return + } + group.ID = m.ID + group.CreatedAt = m.CreatedAt + group.UpdatedAt = m.UpdatedAt +} diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index 85fd27b2..33ff6326 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -6,8 +6,8 @@ import ( "context" "testing" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" "gorm.io/gorm" ) @@ -32,10 +32,10 @@ func TestGroupRepoSuite(t *testing.T) { // --- Create / GetByID / Update / Delete --- func (s *GroupRepoSuite) TestCreate() { - group := &model.Group{ + group := &service.Group{ Name: "test-create", - Platform: model.PlatformAnthropic, - Status: model.StatusActive, + Platform: service.PlatformAnthropic, + Status: service.StatusActive, } err := s.repo.Create(s.ctx, group) @@ -53,7 +53,7 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() { } func (s *GroupRepoSuite) TestUpdate() { - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "original"}) + group := groupModelToService(mustCreateGroup(s.T(), s.db, &groupModel{Name: "original"})) group.Name = "updated" err := s.repo.Update(s.ctx, group) @@ -65,7 +65,7 @@ func (s *GroupRepoSuite) TestUpdate() { } func (s *GroupRepoSuite) TestDelete() { - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "to-delete"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "to-delete"}) err := s.repo.Delete(s.ctx, group.ID) s.Require().NoError(err, "Delete") @@ -77,8 +77,8 @@ func (s *GroupRepoSuite) TestDelete() { // --- List / ListWithFilters --- func (s *GroupRepoSuite) TestList() { - mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"}) - mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"}) groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) s.Require().NoError(err, "List") @@ -87,28 +87,28 @@ func (s *GroupRepoSuite) TestList() { } func (s *GroupRepoSuite) TestListWithFilters_Platform() { - mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic}) - mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI}) - groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformOpenAI, "", nil) + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil) s.Require().NoError(err) s.Require().Len(groups, 1) - s.Require().Equal(model.PlatformOpenAI, groups[0].Platform) + s.Require().Equal(service.PlatformOpenAI, groups[0].Platform) } func (s *GroupRepoSuite) TestListWithFilters_Status() { - mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Status: model.StatusActive}) - mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Status: model.StatusDisabled}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Status: service.StatusActive}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Status: service.StatusDisabled}) - groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, nil) + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil) s.Require().NoError(err) s.Require().Len(groups, 1) - s.Require().Equal(model.StatusDisabled, groups[0].Status) + s.Require().Equal(service.StatusDisabled, groups[0].Status) } func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { - mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", IsExclusive: false}) - mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", IsExclusive: true}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", IsExclusive: false}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", IsExclusive: true}) isExclusive := true groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive) @@ -118,24 +118,24 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { } func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { - g1 := mustCreateGroup(s.T(), s.db, &model.Group{ + g1 := mustCreateGroup(s.T(), s.db, &groupModel{ Name: "g1", - Platform: model.PlatformAnthropic, - Status: model.StatusActive, + Platform: service.PlatformAnthropic, + Status: service.StatusActive, }) - g2 := mustCreateGroup(s.T(), s.db, &model.Group{ + g2 := mustCreateGroup(s.T(), s.db, &groupModel{ Name: "g2", - Platform: model.PlatformAnthropic, - Status: model.StatusActive, + Platform: service.PlatformAnthropic, + Status: service.StatusActive, IsExclusive: true, }) - a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"}) + a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"}) mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1) mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1) isExclusive := true - groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformAnthropic, model.StatusActive, &isExclusive) + groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive) s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total) s.Require().Len(groups, 1) @@ -146,8 +146,8 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { // --- ListActive / ListActiveByPlatform --- func (s *GroupRepoSuite) TestListActive() { - mustCreateGroup(s.T(), s.db, &model.Group{Name: "active1", Status: model.StatusActive}) - mustCreateGroup(s.T(), s.db, &model.Group{Name: "inactive1", Status: model.StatusDisabled}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "active1", Status: service.StatusActive}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "inactive1", Status: service.StatusDisabled}) groups, err := s.repo.ListActive(s.ctx) s.Require().NoError(err, "ListActive") @@ -156,11 +156,11 @@ func (s *GroupRepoSuite) TestListActive() { } func (s *GroupRepoSuite) TestListActiveByPlatform() { - mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic, Status: model.StatusActive}) - mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI, Status: model.StatusActive}) - mustCreateGroup(s.T(), s.db, &model.Group{Name: "g3", Platform: model.PlatformAnthropic, Status: model.StatusDisabled}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic, Status: service.StatusActive}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI, Status: service.StatusActive}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "g3", Platform: service.PlatformAnthropic, Status: service.StatusDisabled}) - groups, err := s.repo.ListActiveByPlatform(s.ctx, model.PlatformAnthropic) + groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic) s.Require().NoError(err, "ListActiveByPlatform") s.Require().Len(groups, 1) s.Require().Equal("g1", groups[0].Name) @@ -169,7 +169,7 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() { // --- ExistsByName --- func (s *GroupRepoSuite) TestExistsByName() { - mustCreateGroup(s.T(), s.db, &model.Group{Name: "existing-group"}) + mustCreateGroup(s.T(), s.db, &groupModel{Name: "existing-group"}) exists, err := s.repo.ExistsByName(s.ctx, "existing-group") s.Require().NoError(err, "ExistsByName") @@ -183,9 +183,9 @@ func (s *GroupRepoSuite) TestExistsByName() { // --- GetAccountCount --- func (s *GroupRepoSuite) TestGetAccountCount() { - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) - a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"}) - a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"}) + a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"}) + a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"}) mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1) mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2) @@ -195,7 +195,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() { } func (s *GroupRepoSuite) TestGetAccountCount_Empty() { - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"}) count, err := s.repo.GetAccountCount(s.ctx, group.ID) s.Require().NoError(err) @@ -205,8 +205,8 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() { // --- DeleteAccountGroupsByGroupID --- func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { - g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"}) - a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"}) + g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"}) + a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"}) mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1) affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID) @@ -219,10 +219,10 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { } func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { - g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-multi"}) - a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"}) - a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"}) - a3 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) + g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-multi"}) + a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"}) + a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"}) + a3 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"}) mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1) mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2) mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3) diff --git a/backend/internal/repository/integration_harness_test.go b/backend/internal/repository/integration_harness_test.go index 1588e078..ab248d06 100644 --- a/backend/internal/repository/integration_harness_test.go +++ b/backend/internal/repository/integration_harness_test.go @@ -15,7 +15,6 @@ import ( "testing" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -94,7 +93,7 @@ func TestMain(m *testing.M) { log.Printf("failed to open gorm db: %v", err) os.Exit(1) } - if err := model.AutoMigrate(integrationDB); err != nil { + if err := AutoMigrate(integrationDB); err != nil { log.Printf("failed to automigrate db: %v", err) os.Exit(1) } diff --git a/backend/internal/repository/pagination.go b/backend/internal/repository/pagination.go new file mode 100644 index 00000000..ff08c34b --- /dev/null +++ b/backend/internal/repository/pagination.go @@ -0,0 +1,16 @@ +package repository + +import "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + +func paginationResultFromTotal(total int64, params pagination.PaginationParams) *pagination.PaginationResult { + pages := int(total) / params.Limit() + if int(total)%params.Limit() > 0 { + pages++ + } + return &pagination.PaginationResult{ + Total: total, + Page: params.Page, + PageSize: params.Limit(), + Pages: pages, + } +} diff --git a/backend/internal/repository/proxy_repo.go b/backend/internal/repository/proxy_repo.go index 590c6a61..423584fb 100644 --- a/backend/internal/repository/proxy_repo.go +++ b/backend/internal/repository/proxy_repo.go @@ -2,10 +2,10 @@ package repository import ( "context" + "time" "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "gorm.io/gorm" @@ -19,37 +19,47 @@ func NewProxyRepository(db *gorm.DB) service.ProxyRepository { return &proxyRepository{db: db} } -func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error { - return r.db.WithContext(ctx).Create(proxy).Error +func (r *proxyRepository) Create(ctx context.Context, proxy *service.Proxy) error { + m := proxyModelFromService(proxy) + err := r.db.WithContext(ctx).Create(m).Error + if err == nil { + applyProxyModelToService(proxy, m) + } + return err } -func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { - var proxy model.Proxy - err := r.db.WithContext(ctx).First(&proxy, id).Error +func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy, error) { + var m proxyModel + err := r.db.WithContext(ctx).First(&m, id).Error if err != nil { return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil) } - return &proxy, nil + return proxyModelToService(&m), nil } -func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error { - return r.db.WithContext(ctx).Save(proxy).Error +func (r *proxyRepository) Update(ctx context.Context, proxy *service.Proxy) error { + m := proxyModelFromService(proxy) + err := r.db.WithContext(ctx).Save(m).Error + if err == nil { + applyProxyModelToService(proxy, m) + } + return err } func (r *proxyRepository) Delete(ctx context.Context, id int64) error { - return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error + return r.db.WithContext(ctx).Delete(&proxyModel{}, id).Error } -func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { +func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", "") } // ListWithFilters lists proxies with optional filtering by protocol, status, and search query -func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) { - var proxies []model.Proxy +func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) { + var proxies []proxyModel var total int64 - db := r.db.WithContext(ctx).Model(&model.Proxy{}) + db := r.db.WithContext(ctx).Model(&proxyModel{}) // Apply filters if protocol != "" { @@ -71,29 +81,31 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination return nil, nil, err } - pages := int(total) / params.Limit() - if int(total)%params.Limit() > 0 { - pages++ + outProxies := make([]service.Proxy, 0, len(proxies)) + for i := range proxies { + outProxies = append(outProxies, *proxyModelToService(&proxies[i])) } - return proxies, &pagination.PaginationResult{ - Total: total, - Page: params.Page, - PageSize: params.Limit(), - Pages: pages, - }, nil + return outProxies, paginationResultFromTotal(total, params), nil } -func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) { - var proxies []model.Proxy - err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error - return proxies, err +func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) { + var proxies []proxyModel + err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Find(&proxies).Error + if err != nil { + return nil, err + } + outProxies := make([]service.Proxy, 0, len(proxies)) + for i := range proxies { + outProxies = append(outProxies, *proxyModelToService(&proxies[i])) + } + return outProxies, nil } // ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { var count int64 - err := r.db.WithContext(ctx).Model(&model.Proxy{}). + err := r.db.WithContext(ctx).Model(&proxyModel{}). Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password). Count(&count).Error if err != nil { @@ -105,7 +117,7 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, // CountAccountsByProxyID returns the number of accounts using a specific proxy func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { var count int64 - err := r.db.WithContext(ctx).Model(&model.Account{}). + err := r.db.WithContext(ctx).Table("accounts"). Where("proxy_id = ?", proxyID). Count(&count).Error return count, err @@ -119,7 +131,7 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i } var results []result err := r.db.WithContext(ctx). - Model(&model.Account{}). + Table("accounts"). Select("proxy_id, COUNT(*) as count"). Where("proxy_id IS NOT NULL"). Group("proxy_id"). @@ -136,10 +148,10 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i } // ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending -func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) { - var proxies []model.Proxy +func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) { + var proxies []proxyModel err := r.db.WithContext(ctx). - Where("status = ?", model.StatusActive). + Where("status = ?", service.StatusActive). Order("created_at DESC"). Find(&proxies).Error if err != nil { @@ -153,13 +165,78 @@ func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]mod } // Build result with account counts - result := make([]model.ProxyWithAccountCount, len(proxies)) - for i, proxy := range proxies { - result[i] = model.ProxyWithAccountCount{ - Proxy: proxy, - AccountCount: counts[proxy.ID], + result := make([]service.ProxyWithAccountCount, 0, len(proxies)) + for i := range proxies { + proxy := proxyModelToService(&proxies[i]) + if proxy == nil { + continue } + result = append(result, service.ProxyWithAccountCount{ + Proxy: *proxy, + AccountCount: counts[proxy.ID], + }) } return result, nil } + +type proxyModel struct { + ID int64 `gorm:"primaryKey"` + Name string `gorm:"size:100;not null"` + Protocol string `gorm:"size:20;not null"` + Host string `gorm:"size:255;not null"` + Port int `gorm:"not null"` + Username string `gorm:"size:100"` + Password string `gorm:"size:100"` + Status string `gorm:"size:20;default:active;not null"` + CreatedAt time.Time `gorm:"not null"` + UpdatedAt time.Time `gorm:"not null"` + DeletedAt gorm.DeletedAt `gorm:"index"` +} + +func (proxyModel) TableName() string { return "proxies" } + +func proxyModelToService(m *proxyModel) *service.Proxy { + if m == nil { + return nil + } + return &service.Proxy{ + ID: m.ID, + Name: m.Name, + Protocol: m.Protocol, + Host: m.Host, + Port: m.Port, + Username: m.Username, + Password: m.Password, + Status: m.Status, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + } +} + +func proxyModelFromService(p *service.Proxy) *proxyModel { + if p == nil { + return nil + } + return &proxyModel{ + ID: p.ID, + Name: p.Name, + Protocol: p.Protocol, + Host: p.Host, + Port: p.Port, + Username: p.Username, + Password: p.Password, + Status: p.Status, + CreatedAt: p.CreatedAt, + UpdatedAt: p.UpdatedAt, + } +} + +func applyProxyModelToService(proxy *service.Proxy, m *proxyModel) { + if proxy == nil || m == nil { + return + } + proxy.ID = m.ID + proxy.CreatedAt = m.CreatedAt + proxy.UpdatedAt = m.UpdatedAt +} diff --git a/backend/internal/repository/proxy_repo_integration_test.go b/backend/internal/repository/proxy_repo_integration_test.go index 9e773398..3aa02176 100644 --- a/backend/internal/repository/proxy_repo_integration_test.go +++ b/backend/internal/repository/proxy_repo_integration_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" "gorm.io/gorm" ) @@ -33,12 +33,12 @@ func TestProxyRepoSuite(t *testing.T) { // --- Create / GetByID / Update / Delete --- func (s *ProxyRepoSuite) TestCreate() { - proxy := &model.Proxy{ + proxy := &service.Proxy{ Name: "test-create", Protocol: "http", Host: "127.0.0.1", Port: 8080, - Status: model.StatusActive, + Status: service.StatusActive, } err := s.repo.Create(s.ctx, proxy) @@ -56,7 +56,7 @@ func (s *ProxyRepoSuite) TestGetByID_NotFound() { } func (s *ProxyRepoSuite) TestUpdate() { - proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"}) + proxy := proxyModelToService(mustCreateProxy(s.T(), s.db, &proxyModel{Name: "original"})) proxy.Name = "updated" err := s.repo.Update(s.ctx, proxy) @@ -68,7 +68,7 @@ func (s *ProxyRepoSuite) TestUpdate() { } func (s *ProxyRepoSuite) TestDelete() { - proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"}) + proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "to-delete"}) err := s.repo.Delete(s.ctx, proxy.ID) s.Require().NoError(err, "Delete") @@ -80,8 +80,8 @@ func (s *ProxyRepoSuite) TestDelete() { // --- List / ListWithFilters --- func (s *ProxyRepoSuite) TestList() { - mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) - mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"}) + mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"}) + mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"}) proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) s.Require().NoError(err, "List") @@ -90,8 +90,8 @@ func (s *ProxyRepoSuite) TestList() { } func (s *ProxyRepoSuite) TestListWithFilters_Protocol() { - mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"}) - mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"}) + mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Protocol: "http"}) + mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Protocol: "socks5"}) proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "") s.Require().NoError(err) @@ -100,18 +100,18 @@ func (s *ProxyRepoSuite) TestListWithFilters_Protocol() { } func (s *ProxyRepoSuite) TestListWithFilters_Status() { - mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive}) - mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled}) + mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Status: service.StatusActive}) + mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Status: service.StatusDisabled}) - proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "") + proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "") s.Require().NoError(err) s.Require().Len(proxies, 1) - s.Require().Equal(model.StatusDisabled, proxies[0].Status) + s.Require().Equal(service.StatusDisabled, proxies[0].Status) } func (s *ProxyRepoSuite) TestListWithFilters_Search() { - mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"}) - mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"}) + mustCreateProxy(s.T(), s.db, &proxyModel{Name: "production-proxy"}) + mustCreateProxy(s.T(), s.db, &proxyModel{Name: "dev-proxy"}) proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod") s.Require().NoError(err) @@ -122,8 +122,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Search() { // --- ListActive --- func (s *ProxyRepoSuite) TestListActive() { - mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive}) - mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled}) + mustCreateProxy(s.T(), s.db, &proxyModel{Name: "active1", Status: service.StatusActive}) + mustCreateProxy(s.T(), s.db, &proxyModel{Name: "inactive1", Status: service.StatusDisabled}) proxies, err := s.repo.ListActive(s.ctx) s.Require().NoError(err, "ListActive") @@ -134,7 +134,7 @@ func (s *ProxyRepoSuite) TestListActive() { // --- ExistsByHostPortAuth --- func (s *ProxyRepoSuite) TestExistsByHostPortAuth() { - mustCreateProxy(s.T(), s.db, &model.Proxy{ + mustCreateProxy(s.T(), s.db, &proxyModel{ Name: "p1", Protocol: "http", Host: "1.2.3.4", @@ -153,7 +153,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth() { } func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() { - mustCreateProxy(s.T(), s.db, &model.Proxy{ + mustCreateProxy(s.T(), s.db, &proxyModel{ Name: "p-noauth", Protocol: "http", Host: "5.6.7.8", @@ -170,10 +170,10 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() { // --- CountAccountsByProxyID --- func (s *ProxyRepoSuite) TestCountAccountsByProxyID() { - proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy + proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-count"}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &proxy.ID}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &proxy.ID}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"}) // no proxy count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID) s.Require().NoError(err, "CountAccountsByProxyID") @@ -181,7 +181,7 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID() { } func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() { - proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"}) + proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-zero"}) count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID) s.Require().NoError(err) @@ -191,12 +191,12 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() { // --- GetAccountCountsForProxies --- func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() { - p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"}) - p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"}) + p1 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"}) + p2 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID}) counts, err := s.repo.GetAccountCountsForProxies(s.ctx) s.Require().NoError(err, "GetAccountCountsForProxies") @@ -215,24 +215,24 @@ func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() { func (s *ProxyRepoSuite) TestListActiveWithAccountCount() { base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) - p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{ + p1 := mustCreateProxy(s.T(), s.db, &proxyModel{ Name: "p1", - Status: model.StatusActive, + Status: service.StatusActive, CreatedAt: base.Add(-1 * time.Hour), }) - p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{ + p2 := mustCreateProxy(s.T(), s.db, &proxyModel{ Name: "p2", - Status: model.StatusActive, + Status: service.StatusActive, CreatedAt: base, }) - mustCreateProxy(s.T(), s.db, &model.Proxy{ + mustCreateProxy(s.T(), s.db, &proxyModel{ Name: "p3-inactive", - Status: model.StatusDisabled, + Status: service.StatusDisabled, }) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID}) withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx) s.Require().NoError(err, "ListActiveWithAccountCount") @@ -248,7 +248,7 @@ func (s *ProxyRepoSuite) TestListActiveWithAccountCount() { // --- Combined original test --- func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() { - p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{ + p1 := mustCreateProxy(s.T(), s.db, &proxyModel{ Name: "p1", Protocol: "http", Host: "1.2.3.4", @@ -258,7 +258,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() { CreatedAt: time.Now().Add(-1 * time.Hour), UpdatedAt: time.Now().Add(-1 * time.Hour), }) - p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{ + p2 := mustCreateProxy(s.T(), s.db, &proxyModel{ Name: "p2", Protocol: "http", Host: "5.6.7.8", @@ -273,9 +273,9 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() { s.Require().NoError(err, "ExistsByHostPortAuth") s.Require().True(exists, "expected proxy to exist") - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID}) count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID) s.Require().NoError(err, "CountAccountsByProxyID") diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go index aa6e7010..957f2677 100644 --- a/backend/internal/repository/redeem_code_repo.go +++ b/backend/internal/repository/redeem_code_repo.go @@ -4,10 +4,8 @@ import ( "context" "time" - "github.com/Wei-Shaw/sub2api/internal/service" - - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" "gorm.io/gorm" ) @@ -20,48 +18,61 @@ func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository { return &redeemCodeRepository{db: db} } -func (r *redeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error { - return r.db.WithContext(ctx).Create(code).Error +func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemCode) error { + m := redeemCodeModelFromService(code) + err := r.db.WithContext(ctx).Create(m).Error + if err == nil { + applyRedeemCodeModelToService(code, m) + } + return err } -func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error { - return r.db.WithContext(ctx).Create(&codes).Error +func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.RedeemCode) error { + if len(codes) == 0 { + return nil + } + models := make([]redeemCodeModel, 0, len(codes)) + for i := range codes { + m := redeemCodeModelFromService(&codes[i]) + if m != nil { + models = append(models, *m) + } + } + return r.db.WithContext(ctx).Create(&models).Error } -func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) { - var code model.RedeemCode - err := r.db.WithContext(ctx).First(&code, id).Error +func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) { + var m redeemCodeModel + err := r.db.WithContext(ctx).First(&m, id).Error if err != nil { return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil) } - return &code, nil + return redeemCodeModelToService(&m), nil } -func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) { - var redeemCode model.RedeemCode - err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error +func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) { + var m redeemCodeModel + err := r.db.WithContext(ctx).Where("code = ?", code).First(&m).Error if err != nil { return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil) } - return &redeemCode, nil + return redeemCodeModelToService(&m), nil } func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error { - return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error + return r.db.WithContext(ctx).Delete(&redeemCodeModel{}, id).Error } -func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) { +func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", "") } -// ListWithFilters lists redeem codes with optional filtering by type, status, and search query -func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) { - var codes []model.RedeemCode +func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) { + var codes []redeemCodeModel var total int64 - db := r.db.WithContext(ctx).Model(&model.RedeemCode{}) + db := r.db.WithContext(ctx).Model(&redeemCodeModel{}) - // Apply filters if codeType != "" { db = db.Where("type = ?", codeType) } @@ -81,29 +92,29 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin return nil, nil, err } - pages := int(total) / params.Limit() - if int(total)%params.Limit() > 0 { - pages++ + outCodes := make([]service.RedeemCode, 0, len(codes)) + for i := range codes { + outCodes = append(outCodes, *redeemCodeModelToService(&codes[i])) } - return codes, &pagination.PaginationResult{ - Total: total, - Page: params.Page, - PageSize: params.Limit(), - Pages: pages, - }, nil + return outCodes, paginationResultFromTotal(total, params), nil } -func (r *redeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error { - return r.db.WithContext(ctx).Save(code).Error +func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error { + m := redeemCodeModelFromService(code) + err := r.db.WithContext(ctx).Save(m).Error + if err == nil { + applyRedeemCodeModelToService(code, m) + } + return err } func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error { now := time.Now() - result := r.db.WithContext(ctx).Model(&model.RedeemCode{}). - Where("id = ? AND status = ?", id, model.StatusUnused). + result := r.db.WithContext(ctx).Model(&redeemCodeModel{}). + Where("id = ? AND status = ?", id, service.StatusUnused). Updates(map[string]any{ - "status": model.StatusUsed, + "status": service.StatusUsed, "used_by": userID, "used_at": now, }) @@ -116,22 +127,93 @@ func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error return nil } -// ListByUser returns all redeem codes used by a specific user -func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) { - var codes []model.RedeemCode +func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) { if limit <= 0 { limit = 10 } + var codes []redeemCodeModel err := r.db.WithContext(ctx). Preload("Group"). Where("used_by = ?", userID). Order("used_at DESC"). Limit(limit). Find(&codes).Error - if err != nil { return nil, err } - return codes, nil + + outCodes := make([]service.RedeemCode, 0, len(codes)) + for i := range codes { + outCodes = append(outCodes, *redeemCodeModelToService(&codes[i])) + } + return outCodes, nil +} + +type redeemCodeModel struct { + ID int64 `gorm:"primaryKey"` + Code string `gorm:"uniqueIndex;size:32;not null"` + Type string `gorm:"size:20;default:balance;not null"` + Value float64 `gorm:"type:decimal(20,8);not null"` + Status string `gorm:"size:20;default:unused;not null"` + UsedBy *int64 `gorm:"index"` + UsedAt *time.Time + Notes string `gorm:"type:text"` + CreatedAt time.Time `gorm:"not null"` + + GroupID *int64 `gorm:"index"` + ValidityDays int `gorm:"default:30"` + + User *userModel `gorm:"foreignKey:UsedBy"` + Group *groupModel `gorm:"foreignKey:GroupID"` +} + +func (redeemCodeModel) TableName() string { return "redeem_codes" } + +func redeemCodeModelToService(m *redeemCodeModel) *service.RedeemCode { + if m == nil { + return nil + } + return &service.RedeemCode{ + ID: m.ID, + Code: m.Code, + Type: m.Type, + Value: m.Value, + Status: m.Status, + UsedBy: m.UsedBy, + UsedAt: m.UsedAt, + Notes: m.Notes, + CreatedAt: m.CreatedAt, + GroupID: m.GroupID, + ValidityDays: m.ValidityDays, + User: userModelToService(m.User), + Group: groupModelToService(m.Group), + } +} + +func redeemCodeModelFromService(r *service.RedeemCode) *redeemCodeModel { + if r == nil { + return nil + } + return &redeemCodeModel{ + ID: r.ID, + Code: r.Code, + Type: r.Type, + Value: r.Value, + Status: r.Status, + UsedBy: r.UsedBy, + UsedAt: r.UsedAt, + Notes: r.Notes, + CreatedAt: r.CreatedAt, + GroupID: r.GroupID, + ValidityDays: r.ValidityDays, + } +} + +func applyRedeemCodeModelToService(code *service.RedeemCode, m *redeemCodeModel) { + if code == nil || m == nil { + return + } + code.ID = m.ID + code.CreatedAt = m.CreatedAt } diff --git a/backend/internal/repository/redeem_code_repo_integration_test.go b/backend/internal/repository/redeem_code_repo_integration_test.go index 5151f7e2..50427163 100644 --- a/backend/internal/repository/redeem_code_repo_integration_test.go +++ b/backend/internal/repository/redeem_code_repo_integration_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" @@ -34,11 +33,11 @@ func TestRedeemCodeRepoSuite(t *testing.T) { // --- Create / CreateBatch / GetByID / GetByCode --- func (s *RedeemCodeRepoSuite) TestCreate() { - code := &model.RedeemCode{ + code := &service.RedeemCode{ Code: "TEST-CREATE", - Type: model.RedeemTypeBalance, + Type: service.RedeemTypeBalance, Value: 100, - Status: model.StatusUnused, + Status: service.StatusUnused, } err := s.repo.Create(s.ctx, code) @@ -51,9 +50,9 @@ func (s *RedeemCodeRepoSuite) TestCreate() { } func (s *RedeemCodeRepoSuite) TestCreateBatch() { - codes := []model.RedeemCode{ - {Code: "BATCH-1", Type: model.RedeemTypeBalance, Value: 10, Status: model.StatusUnused}, - {Code: "BATCH-2", Type: model.RedeemTypeBalance, Value: 20, Status: model.StatusUnused}, + codes := []service.RedeemCode{ + {Code: "BATCH-1", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused}, + {Code: "BATCH-2", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused}, } err := s.repo.CreateBatch(s.ctx, codes) @@ -74,7 +73,7 @@ func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() { } func (s *RedeemCodeRepoSuite) TestGetByCode() { - mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "GET-BY-CODE", Type: model.RedeemTypeBalance}) + mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "GET-BY-CODE", Type: service.RedeemTypeBalance}) got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE") s.Require().NoError(err, "GetByCode") @@ -89,7 +88,7 @@ func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() { // --- Delete --- func (s *RedeemCodeRepoSuite) TestDelete() { - code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TO-DELETE", Type: model.RedeemTypeBalance}) + code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TO-DELETE", Type: service.RedeemTypeBalance}) err := s.repo.Delete(s.ctx, code.ID) s.Require().NoError(err, "Delete") @@ -101,8 +100,8 @@ func (s *RedeemCodeRepoSuite) TestDelete() { // --- List / ListWithFilters --- func (s *RedeemCodeRepoSuite) TestList() { - mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-1", Type: model.RedeemTypeBalance}) - mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-2", Type: model.RedeemTypeBalance}) + mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-1", Type: service.RedeemTypeBalance}) + mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-2", Type: service.RedeemTypeBalance}) codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) s.Require().NoError(err, "List") @@ -111,28 +110,28 @@ func (s *RedeemCodeRepoSuite) TestList() { } func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() { - mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-BAL", Type: model.RedeemTypeBalance}) - mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-SUB", Type: model.RedeemTypeSubscription}) + mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-BAL", Type: service.RedeemTypeBalance}) + mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription}) - codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, "", "") + codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, "", "") s.Require().NoError(err) s.Require().Len(codes, 1) - s.Require().Equal(model.RedeemTypeSubscription, codes[0].Type) + s.Require().Equal(service.RedeemTypeSubscription, codes[0].Type) } func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() { - mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-UNUSED", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) - mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed}) + mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Status: service.StatusUnused}) + mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed}) - codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusUsed, "") + codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUsed, "") s.Require().NoError(err) s.Require().Len(codes, 1) - s.Require().Equal(model.StatusUsed, codes[0].Status) + s.Require().Equal(service.StatusUsed, codes[0].Status) } func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() { - mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALPHA-CODE", Type: model.RedeemTypeBalance}) - mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "BETA-CODE", Type: model.RedeemTypeBalance}) + mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance}) + mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "BETA-CODE", Type: service.RedeemTypeBalance}) codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha") s.Require().NoError(err) @@ -141,10 +140,10 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() { } func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() { - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"}) - mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"}) + mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{ Code: "WITH-GROUP", - Type: model.RedeemTypeSubscription, + Type: service.RedeemTypeSubscription, GroupID: &group.ID, }) @@ -158,7 +157,7 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() { // --- Update --- func (s *RedeemCodeRepoSuite) TestUpdate() { - code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "UPDATE-ME", Type: model.RedeemTypeBalance, Value: 10}) + code := redeemCodeModelToService(mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "UPDATE-ME", Type: service.RedeemTypeBalance, Value: 10})) code.Value = 50 err := s.repo.Update(s.ctx, code) @@ -172,23 +171,23 @@ func (s *RedeemCodeRepoSuite) TestUpdate() { // --- Use --- func (s *RedeemCodeRepoSuite) TestUse() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "use@test.com"}) - code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "USE-ME", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "use@test.com"}) + code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "USE-ME", Type: service.RedeemTypeBalance, Status: service.StatusUnused}) err := s.repo.Use(s.ctx, code.ID, user.ID) s.Require().NoError(err, "Use") got, err := s.repo.GetByID(s.ctx, code.ID) s.Require().NoError(err) - s.Require().Equal(model.StatusUsed, got.Status) + s.Require().Equal(service.StatusUsed, got.Status) s.Require().NotNil(got.UsedBy) s.Require().Equal(user.ID, *got.UsedBy) s.Require().NotNil(got.UsedAt) } func (s *RedeemCodeRepoSuite) TestUse_Idempotency() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "idem@test.com"}) - code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "IDEM-CODE", Type: model.RedeemTypeBalance, Status: model.StatusUnused}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "idem@test.com"}) + code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Status: service.StatusUnused}) err := s.repo.Use(s.ctx, code.ID, user.ID) s.Require().NoError(err, "Use first time") @@ -200,8 +199,8 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() { } func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "already@test.com"}) - code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALREADY-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "already@test.com"}) + code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed}) err := s.repo.Use(s.ctx, code.ID, user.ID) s.Require().Error(err, "expected error for already used code") @@ -211,22 +210,22 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { // --- ListByUser --- func (s *RedeemCodeRepoSuite) TestListByUser() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"}) base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) // Create codes with explicit used_at for ordering - c1 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ + c1 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{ Code: "USER-1", - Type: model.RedeemTypeBalance, - Status: model.StatusUsed, + Type: service.RedeemTypeBalance, + Status: service.StatusUsed, UsedBy: &user.ID, }) s.db.Model(c1).Update("used_at", base) - c2 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ + c2 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{ Code: "USER-2", - Type: model.RedeemTypeBalance, - Status: model.StatusUsed, + Type: service.RedeemTypeBalance, + Status: service.StatusUsed, UsedBy: &user.ID, }) s.db.Model(c2).Update("used_at", base.Add(1*time.Hour)) @@ -240,13 +239,13 @@ func (s *RedeemCodeRepoSuite) TestListByUser() { } func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "grp@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listby"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "grp@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listby"}) - c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ + c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{ Code: "WITH-GRP", - Type: model.RedeemTypeSubscription, - Status: model.StatusUsed, + Type: service.RedeemTypeSubscription, + Status: service.StatusUsed, UsedBy: &user.ID, GroupID: &group.ID, }) @@ -260,11 +259,11 @@ func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() { } func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "deflimit@test.com"}) - c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "deflimit@test.com"}) + c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{ Code: "DEF-LIM", - Type: model.RedeemTypeBalance, - Status: model.StatusUsed, + Type: service.RedeemTypeBalance, + Status: service.StatusUsed, UsedBy: &user.ID, }) s.db.Model(c).Update("used_at", time.Now()) @@ -278,16 +277,16 @@ func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() { // --- Combined original test --- func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "rc@example.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-rc"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "rc@example.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-rc"}) - codes := []model.RedeemCode{ - {Code: "CODEA", Type: model.RedeemTypeBalance, Value: 1, Status: model.StatusUnused, CreatedAt: time.Now()}, - {Code: "CODEB", Type: model.RedeemTypeSubscription, Value: 0, Status: model.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()}, + codes := []service.RedeemCode{ + {Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, CreatedAt: time.Now()}, + {Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()}, } s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch") - list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, model.StatusUnused, "code") + list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, service.StatusUnused, "code") s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total) s.Require().Len(list, 1) @@ -305,9 +304,9 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser s.Require().NoError(err, "GetByCode") // Use fixed time instead of time.Sleep for deterministic ordering - s.db.Model(&model.RedeemCode{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)) + s.db.Model(&redeemCodeModel{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)) s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA") - s.db.Model(&model.RedeemCode{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC)) + s.db.Model(&redeemCodeModel{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC)) used, err := s.repo.ListByUser(s.ctx, user.ID, 10) s.Require().NoError(err, "ListByUser") diff --git a/backend/internal/repository/setting_repo.go b/backend/internal/repository/setting_repo.go index 43dd65d4..00d3776e 100644 --- a/backend/internal/repository/setting_repo.go +++ b/backend/internal/repository/setting_repo.go @@ -6,33 +6,27 @@ import ( "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/Wei-Shaw/sub2api/internal/model" - "gorm.io/gorm" "gorm.io/gorm/clause" ) -// SettingRepository 系统设置数据访问层 type settingRepository struct { db *gorm.DB } -// NewSettingRepository 创建系统设置仓库实例 func NewSettingRepository(db *gorm.DB) service.SettingRepository { return &settingRepository{db: db} } -// Get 根据Key获取设置值 -func (r *settingRepository) Get(ctx context.Context, key string) (*model.Setting, error) { - var setting model.Setting - err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error +func (r *settingRepository) Get(ctx context.Context, key string) (*service.Setting, error) { + var m settingModel + err := r.db.WithContext(ctx).Where("key = ?", key).First(&m).Error if err != nil { return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil) } - return &setting, nil + return settingModelToService(&m), nil } -// GetValue 获取设置值字符串 func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) { setting, err := r.Get(ctx, key) if err != nil { @@ -41,9 +35,8 @@ func (r *settingRepository) GetValue(ctx context.Context, key string) (string, e return setting.Value, nil } -// Set 设置值(存在则更新,不存在则创建) func (r *settingRepository) Set(ctx context.Context, key, value string) error { - setting := &model.Setting{ + m := &settingModel{ Key: key, Value: value, UpdatedAt: time.Now(), @@ -52,12 +45,11 @@ func (r *settingRepository) Set(ctx context.Context, key, value string) error { return r.db.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "key"}}, DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}), - }).Create(setting).Error + }).Create(m).Error } -// GetMultiple 批量获取设置 func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { - var settings []model.Setting + var settings []settingModel err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error if err != nil { return nil, err @@ -70,11 +62,10 @@ func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map return result, nil } -// SetMultiple 批量设置值 func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error { return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { for key, value := range settings { - setting := &model.Setting{ + m := &settingModel{ Key: key, Value: value, UpdatedAt: time.Now(), @@ -82,7 +73,7 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string if err := tx.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "key"}}, DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}), - }).Create(setting).Error; err != nil { + }).Create(m).Error; err != nil { return err } } @@ -90,9 +81,8 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string }) } -// GetAll 获取所有设置 func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) { - var settings []model.Setting + var settings []settingModel err := r.db.WithContext(ctx).Find(&settings).Error if err != nil { return nil, err @@ -105,7 +95,27 @@ func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, erro return result, nil } -// Delete 删除设置 func (r *settingRepository) Delete(ctx context.Context, key string) error { - return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error + return r.db.WithContext(ctx).Where("key = ?", key).Delete(&settingModel{}).Error +} + +type settingModel struct { + ID int64 `gorm:"primaryKey"` + Key string `gorm:"uniqueIndex;size:100;not null"` + Value string `gorm:"type:text;not null"` + UpdatedAt time.Time `gorm:"not null"` +} + +func (settingModel) TableName() string { return "settings" } + +func settingModelToService(m *settingModel) *service.Setting { + if m == nil { + return nil + } + return &service.Setting{ + ID: m.ID, + Key: m.Key, + Value: m.Value, + UpdatedAt: m.UpdatedAt, + } } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 038719ae..f16e5fd7 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -6,7 +6,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" @@ -30,7 +29,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int TokenCount int64 `gorm:"column:token_count"` } - db := r.db.WithContext(ctx).Model(&model.UsageLog{}). + db := r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` COUNT(*) as request_count, COALESCE(SUM(input_tokens + output_tokens), 0) as token_count @@ -46,24 +45,29 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int return perfStats.RequestCount / 5, perfStats.TokenCount / 5 } -func (r *usageLogRepository) Create(ctx context.Context, log *model.UsageLog) error { - return r.db.WithContext(ctx).Create(log).Error +func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error { + m := usageLogModelFromService(log) + err := r.db.WithContext(ctx).Create(m).Error + if err == nil { + applyUsageLogModelToService(log, m) + } + return err } -func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) { - var log model.UsageLog +func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) { + var log usageLogModel err := r.db.WithContext(ctx).First(&log, id).Error if err != nil { return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil) } - return &log, nil + return usageLogModelToService(&log), nil } -func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { - var logs []model.UsageLog +func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + var logs []usageLogModel var total int64 - db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("user_id = ?", userID) + db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("user_id = ?", userID) if err := db.Count(&total).Error; err != nil { return nil, nil, err @@ -73,24 +77,14 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param return nil, nil, err } - pages := int(total) / params.Limit() - if int(total)%params.Limit() > 0 { - pages++ - } - - return logs, &pagination.PaginationResult{ - Total: total, - Page: params.Page, - PageSize: params.Limit(), - Pages: pages, - }, nil + return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil } -func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { - var logs []model.UsageLog +func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + var logs []usageLogModel var total int64 - db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("api_key_id = ?", apiKeyID) + db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("api_key_id = ?", apiKeyID) if err := db.Count(&total).Error; err != nil { return nil, nil, err @@ -100,17 +94,7 @@ func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p return nil, nil, err } - pages := int(total) / params.Limit() - if int(total)%params.Limit() > 0 { - pages++ - } - - return logs, &pagination.PaginationResult{ - Total: total, - Page: params.Page, - PageSize: params.Limit(), - Pages: pages, - }, nil + return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil } // UserStats 用户使用统计 @@ -125,7 +109,7 @@ type UserStats struct { func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) { var stats UserStats - err := r.db.WithContext(ctx).Model(&model.UsageLog{}). + err := r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` COUNT(*) as total_requests, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, @@ -147,47 +131,47 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS today := timezone.Today() // 总用户数 - r.db.WithContext(ctx).Model(&model.User{}).Count(&stats.TotalUsers) + r.db.WithContext(ctx).Model(&userModel{}).Count(&stats.TotalUsers) // 今日新增用户数 - r.db.WithContext(ctx).Model(&model.User{}). + r.db.WithContext(ctx).Model(&userModel{}). Where("created_at >= ?", today). Count(&stats.TodayNewUsers) // 今日活跃用户数 (今日有请求的用户) - r.db.WithContext(ctx).Model(&model.UsageLog{}). + r.db.WithContext(ctx).Model(&usageLogModel{}). Distinct("user_id"). Where("created_at >= ?", today). Count(&stats.ActiveUsers) // 总 API Key 数 - r.db.WithContext(ctx).Model(&model.ApiKey{}).Count(&stats.TotalApiKeys) + r.db.WithContext(ctx).Model(&apiKeyModel{}).Count(&stats.TotalApiKeys) // 活跃 API Key 数 - r.db.WithContext(ctx).Model(&model.ApiKey{}). - Where("status = ?", model.StatusActive). + r.db.WithContext(ctx).Model(&apiKeyModel{}). + Where("status = ?", service.StatusActive). Count(&stats.ActiveApiKeys) // 总账户数 - r.db.WithContext(ctx).Model(&model.Account{}).Count(&stats.TotalAccounts) + r.db.WithContext(ctx).Model(&accountModel{}).Count(&stats.TotalAccounts) // 正常账户数 (schedulable=true, status=active) - r.db.WithContext(ctx).Model(&model.Account{}). - Where("status = ? AND schedulable = ?", model.StatusActive, true). + r.db.WithContext(ctx).Model(&accountModel{}). + Where("status = ? AND schedulable = ?", service.StatusActive, true). Count(&stats.NormalAccounts) // 异常账户数 (status=error) - r.db.WithContext(ctx).Model(&model.Account{}). - Where("status = ?", model.StatusError). + r.db.WithContext(ctx).Model(&accountModel{}). + Where("status = ?", service.StatusError). Count(&stats.ErrorAccounts) // 限流账户数 - r.db.WithContext(ctx).Model(&model.Account{}). + r.db.WithContext(ctx).Model(&accountModel{}). Where("rate_limited_at IS NOT NULL AND rate_limit_reset_at > ?", time.Now()). Count(&stats.RateLimitAccounts) // 过载账户数 - r.db.WithContext(ctx).Model(&model.Account{}). + r.db.WithContext(ctx).Model(&accountModel{}). Where("overload_until IS NOT NULL AND overload_until > ?", time.Now()). Count(&stats.OverloadAccounts) @@ -202,7 +186,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS TotalActualCost float64 `gorm:"column:total_actual_cost"` AverageDurationMs float64 `gorm:"column:avg_duration_ms"` } - r.db.WithContext(ctx).Model(&model.UsageLog{}). + r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` COUNT(*) as total_requests, COALESCE(SUM(input_tokens), 0) as total_input_tokens, @@ -235,7 +219,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS TodayCost float64 `gorm:"column:today_cost"` TodayActualCost float64 `gorm:"column:today_actual_cost"` } - r.db.WithContext(ctx).Model(&model.UsageLog{}). + r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` COUNT(*) as today_requests, COALESCE(SUM(input_tokens), 0) as today_input_tokens, @@ -263,11 +247,11 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS return &stats, nil } -func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { - var logs []model.UsageLog +func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + var logs []usageLogModel var total int64 - db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("account_id = ?", accountID) + db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("account_id = ?", accountID) if err := db.Count(&total).Error; err != nil { return nil, nil, err @@ -277,57 +261,47 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, return nil, nil, err } - pages := int(total) / params.Limit() - if int(total)%params.Limit() > 0 { - pages++ - } - - return logs, &pagination.PaginationResult{ - Total: total, - Page: params.Page, - PageSize: params.Limit(), - Pages: pages, - }, nil + return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil } -func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { - var logs []model.UsageLog +func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + var logs []usageLogModel err := r.db.WithContext(ctx). Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime). Order("id DESC"). Find(&logs).Error - return logs, nil, err + return usageLogModelsToService(logs), nil, err } -func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { - var logs []model.UsageLog +func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + var logs []usageLogModel err := r.db.WithContext(ctx). Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime). Order("id DESC"). Find(&logs).Error - return logs, nil, err + return usageLogModelsToService(logs), nil, err } -func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { - var logs []model.UsageLog +func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + var logs []usageLogModel err := r.db.WithContext(ctx). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). Order("id DESC"). Find(&logs).Error - return logs, nil, err + return usageLogModelsToService(logs), nil, err } -func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { - var logs []model.UsageLog +func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + var logs []usageLogModel err := r.db.WithContext(ctx). Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime). Order("id DESC"). Find(&logs).Error - return logs, nil, err + return usageLogModelsToService(logs), nil, err } func (r *usageLogRepository) Delete(ctx context.Context, id int64) error { - return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error + return r.db.WithContext(ctx).Delete(&usageLogModel{}, id).Error } // GetAccountTodayStats 获取账号今日统计 @@ -340,7 +314,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID Cost float64 `gorm:"column:cost"` } - err := r.db.WithContext(ctx).Model(&model.UsageLog{}). + err := r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` COUNT(*) as requests, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, @@ -368,7 +342,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI Cost float64 `gorm:"column:cost"` } - err := r.db.WithContext(ctx).Model(&model.UsageLog{}). + err := r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` COUNT(*) as requests, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, @@ -499,12 +473,12 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i today := timezone.Today() // API Key 统计 - r.db.WithContext(ctx).Model(&model.ApiKey{}). + r.db.WithContext(ctx).Model(&apiKeyModel{}). Where("user_id = ?", userID). Count(&stats.TotalApiKeys) - r.db.WithContext(ctx).Model(&model.ApiKey{}). - Where("user_id = ? AND status = ?", userID, model.StatusActive). + r.db.WithContext(ctx).Model(&apiKeyModel{}). + Where("user_id = ? AND status = ?", userID, service.StatusActive). Count(&stats.ActiveApiKeys) // 累计 Token 统计 @@ -518,7 +492,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i TotalActualCost float64 `gorm:"column:total_actual_cost"` AverageDurationMs float64 `gorm:"column:avg_duration_ms"` } - r.db.WithContext(ctx).Model(&model.UsageLog{}). + r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` COUNT(*) as total_requests, COALESCE(SUM(input_tokens), 0) as total_input_tokens, @@ -552,7 +526,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i TodayCost float64 `gorm:"column:today_cost"` TodayActualCost float64 `gorm:"column:today_actual_cost"` } - r.db.WithContext(ctx).Model(&model.UsageLog{}). + r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` COUNT(*) as today_requests, COALESCE(SUM(input_tokens), 0) as today_input_tokens, @@ -591,7 +565,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user dateFormat = "YYYY-MM-DD" } - err := r.db.WithContext(ctx).Model(&model.UsageLog{}). + err := r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` TO_CHAR(created_at, ?) as date, COUNT(*) as requests, @@ -618,7 +592,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) { var results []ModelStat - err := r.db.WithContext(ctx).Model(&model.UsageLog{}). + err := r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` model, COUNT(*) as requests, @@ -644,11 +618,11 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64 type UsageLogFilters = usagestats.UsageLogFilters // ListWithFilters lists usage logs with optional filters (for admin) -func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) { - var logs []model.UsageLog +func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { + var logs []usageLogModel var total int64 - db := r.db.WithContext(ctx).Model(&model.UsageLog{}) + db := r.db.WithContext(ctx).Model(&usageLogModel{}) // Apply filters if filters.UserID > 0 { @@ -675,17 +649,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat return nil, nil, err } - pages := int(total) / params.Limit() - if int(total)%params.Limit() > 0 { - pages++ - } - - return logs, &pagination.PaginationResult{ - Total: total, - Page: params.Page, - PageSize: params.Limit(), - Pages: pages, - }, nil + return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil } // UsageStats represents usage statistics @@ -713,7 +677,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs UserID int64 `gorm:"column:user_id"` TotalCost float64 `gorm:"column:total_cost"` } - err := r.db.WithContext(ctx).Model(&model.UsageLog{}). + err := r.db.WithContext(ctx).Model(&usageLogModel{}). Select("user_id, COALESCE(SUM(actual_cost), 0) as total_cost"). Where("user_id IN ?", userIDs). Group("user_id"). @@ -733,7 +697,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs UserID int64 `gorm:"column:user_id"` TodayCost float64 `gorm:"column:today_cost"` } - err = r.db.WithContext(ctx).Model(&model.UsageLog{}). + err = r.db.WithContext(ctx).Model(&usageLogModel{}). Select("user_id, COALESCE(SUM(actual_cost), 0) as today_cost"). Where("user_id IN ? AND created_at >= ?", userIDs, today). Group("user_id"). @@ -773,7 +737,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe ApiKeyID int64 `gorm:"column:api_key_id"` TotalCost float64 `gorm:"column:total_cost"` } - err := r.db.WithContext(ctx).Model(&model.UsageLog{}). + err := r.db.WithContext(ctx).Model(&usageLogModel{}). Select("api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost"). Where("api_key_id IN ?", apiKeyIDs). Group("api_key_id"). @@ -793,7 +757,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe ApiKeyID int64 `gorm:"column:api_key_id"` TodayCost float64 `gorm:"column:today_cost"` } - err = r.db.WithContext(ctx).Model(&model.UsageLog{}). + err = r.db.WithContext(ctx).Model(&usageLogModel{}). Select("api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost"). Where("api_key_id IN ? AND created_at >= ?", apiKeyIDs, today). Group("api_key_id"). @@ -822,7 +786,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start dateFormat = "YYYY-MM-DD" } - db := r.db.WithContext(ctx).Model(&model.UsageLog{}). + db := r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` TO_CHAR(created_at, ?) as date, COUNT(*) as requests, @@ -854,7 +818,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) { var results []ModelStat - db := r.db.WithContext(ctx).Model(&model.UsageLog{}). + db := r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` model, COUNT(*) as requests, @@ -896,7 +860,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT AverageDurationMs float64 `gorm:"column:avg_duration_ms"` } - err := r.db.WithContext(ctx).Model(&model.UsageLog{}). + err := r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` COUNT(*) as total_requests, COALESCE(SUM(input_tokens), 0) as total_input_tokens, @@ -950,7 +914,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID ActualCost float64 `gorm:"column:actual_cost"` } - err := r.db.WithContext(ctx).Model(&model.UsageLog{}). + err := r.db.WithContext(ctx).Model(&usageLogModel{}). Select(` TO_CHAR(created_at, 'YYYY-MM-DD') as date, COUNT(*) as requests, @@ -1011,7 +975,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID var avgDuration struct { AvgDurationMs float64 `gorm:"column:avg_duration_ms"` } - r.db.WithContext(ctx).Model(&model.UsageLog{}). + r.db.WithContext(ctx).Model(&usageLogModel{}). Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms"). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). Scan(&avgDuration) @@ -1090,3 +1054,137 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID Models: models, }, nil } + +type usageLogModel struct { + ID int64 `gorm:"primaryKey"` + UserID int64 `gorm:"index;not null"` + ApiKeyID int64 `gorm:"index;not null"` + AccountID int64 `gorm:"index;not null"` + RequestID string `gorm:"size:64"` + Model string `gorm:"size:100;index;not null"` + + GroupID *int64 `gorm:"index"` + SubscriptionID *int64 `gorm:"index"` + + InputTokens int `gorm:"default:0;not null"` + OutputTokens int `gorm:"default:0;not null"` + CacheCreationTokens int `gorm:"default:0;not null"` + CacheReadTokens int `gorm:"default:0;not null"` + + CacheCreation5mTokens int `gorm:"default:0;not null"` + CacheCreation1hTokens int `gorm:"default:0;not null"` + + InputCost float64 `gorm:"type:decimal(20,10);default:0;not null"` + OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null"` + CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null"` + CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null"` + TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null"` + ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null"` + RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null"` + + BillingType int8 `gorm:"type:smallint;default:0;not null"` + Stream bool `gorm:"default:false;not null"` + DurationMs *int + FirstTokenMs *int + + CreatedAt time.Time `gorm:"index;not null"` + + User *userModel `gorm:"foreignKey:UserID"` + ApiKey *apiKeyModel `gorm:"foreignKey:ApiKeyID"` + Account *accountModel `gorm:"foreignKey:AccountID"` + Group *groupModel `gorm:"foreignKey:GroupID"` + Subscription *userSubscriptionModel `gorm:"foreignKey:SubscriptionID"` +} + +func (usageLogModel) TableName() string { return "usage_logs" } + +func usageLogModelToService(m *usageLogModel) *service.UsageLog { + if m == nil { + return nil + } + return &service.UsageLog{ + ID: m.ID, + UserID: m.UserID, + ApiKeyID: m.ApiKeyID, + AccountID: m.AccountID, + RequestID: m.RequestID, + Model: m.Model, + GroupID: m.GroupID, + SubscriptionID: m.SubscriptionID, + InputTokens: m.InputTokens, + OutputTokens: m.OutputTokens, + CacheCreationTokens: m.CacheCreationTokens, + CacheReadTokens: m.CacheReadTokens, + CacheCreation5mTokens: m.CacheCreation5mTokens, + CacheCreation1hTokens: m.CacheCreation1hTokens, + InputCost: m.InputCost, + OutputCost: m.OutputCost, + CacheCreationCost: m.CacheCreationCost, + CacheReadCost: m.CacheReadCost, + TotalCost: m.TotalCost, + ActualCost: m.ActualCost, + RateMultiplier: m.RateMultiplier, + BillingType: m.BillingType, + Stream: m.Stream, + DurationMs: m.DurationMs, + FirstTokenMs: m.FirstTokenMs, + CreatedAt: m.CreatedAt, + User: userModelToService(m.User), + ApiKey: apiKeyModelToService(m.ApiKey), + Account: accountModelToService(m.Account), + Group: groupModelToService(m.Group), + Subscription: userSubscriptionModelToService(m.Subscription), + } +} + +func usageLogModelsToService(models []usageLogModel) []service.UsageLog { + out := make([]service.UsageLog, 0, len(models)) + for i := range models { + if s := usageLogModelToService(&models[i]); s != nil { + out = append(out, *s) + } + } + return out +} + +func usageLogModelFromService(log *service.UsageLog) *usageLogModel { + if log == nil { + return nil + } + return &usageLogModel{ + ID: log.ID, + UserID: log.UserID, + ApiKeyID: log.ApiKeyID, + AccountID: log.AccountID, + RequestID: log.RequestID, + Model: log.Model, + GroupID: log.GroupID, + SubscriptionID: log.SubscriptionID, + InputTokens: log.InputTokens, + OutputTokens: log.OutputTokens, + CacheCreationTokens: log.CacheCreationTokens, + CacheReadTokens: log.CacheReadTokens, + CacheCreation5mTokens: log.CacheCreation5mTokens, + CacheCreation1hTokens: log.CacheCreation1hTokens, + InputCost: log.InputCost, + OutputCost: log.OutputCost, + CacheCreationCost: log.CacheCreationCost, + CacheReadCost: log.CacheReadCost, + TotalCost: log.TotalCost, + ActualCost: log.ActualCost, + RateMultiplier: log.RateMultiplier, + BillingType: log.BillingType, + Stream: log.Stream, + DurationMs: log.DurationMs, + FirstTokenMs: log.FirstTokenMs, + CreatedAt: log.CreatedAt, + } +} + +func applyUsageLogModelToService(log *service.UsageLog, m *usageLogModel) { + if log == nil || m == nil { + return + } + log.ID = m.ID + log.CreatedAt = m.CreatedAt +} diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 6423de71..4533a0ab 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -7,10 +7,10 @@ import ( "testing" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" "gorm.io/gorm" ) @@ -32,8 +32,8 @@ func TestUsageLogRepoSuite(t *testing.T) { suite.Run(t, new(UsageLogRepoSuite)) } -func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKey, account *model.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *model.UsageLog { - log := &model.UsageLog{ +func (s *UsageLogRepoSuite) createUsageLog(user *userModel, apiKey *apiKeyModel, account *accountModel, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog { + log := &service.UsageLog{ UserID: user.ID, ApiKeyID: apiKey.ID, AccountID: account.ID, @@ -51,11 +51,11 @@ func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKe // --- Create / GetByID --- func (s *UsageLogRepoSuite) TestCreate() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-create"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-create", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-create"}) - log := &model.UsageLog{ + log := &service.UsageLog{ UserID: user.ID, ApiKeyID: apiKey.ID, AccountID: account.ID, @@ -72,9 +72,9 @@ func (s *UsageLogRepoSuite) TestCreate() { } func (s *UsageLogRepoSuite) TestGetByID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbyid@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-getbyid"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbyid@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-getbyid", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-getbyid"}) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) @@ -92,9 +92,9 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() { // --- Delete --- func (s *UsageLogRepoSuite) TestDelete() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-delete"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-delete", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-delete"}) log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) @@ -108,9 +108,9 @@ func (s *UsageLogRepoSuite) TestDelete() { // --- ListByUser --- func (s *UsageLogRepoSuite) TestListByUser() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyuser"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyuser", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyuser"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now()) @@ -124,9 +124,9 @@ func (s *UsageLogRepoSuite) TestListByUser() { // --- ListByApiKey --- func (s *UsageLogRepoSuite) TestListByApiKey() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyapikey@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyapikey"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyapikey@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyapikey"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now()) @@ -140,9 +140,9 @@ func (s *UsageLogRepoSuite) TestListByApiKey() { // --- ListByAccount --- func (s *UsageLogRepoSuite) TestListByAccount() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyaccount@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyaccount"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyaccount@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyaccount"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) @@ -155,9 +155,9 @@ func (s *UsageLogRepoSuite) TestListByAccount() { // --- GetUserStats --- func (s *UsageLogRepoSuite) TestGetUserStats() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "userstats@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userstats"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "userstats@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userstats", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userstats"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) @@ -175,9 +175,9 @@ func (s *UsageLogRepoSuite) TestGetUserStats() { // --- ListWithFilters --- func (s *UsageLogRepoSuite) TestListWithFilters() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "filters@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filters"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "filters@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filters", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filters"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) @@ -194,29 +194,29 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { now := time.Now() todayStart := timezone.Today() - userToday := mustCreateUser(s.T(), s.db, &model.User{ + userToday := mustCreateUser(s.T(), s.db, &userModel{ Email: "today@example.com", CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)), UpdatedAt: now, }) - userOld := mustCreateUser(s.T(), s.db, &model.User{ + userOld := mustCreateUser(s.T(), s.db, &userModel{ Email: "old@example.com", CreatedAt: todayStart.Add(-24 * time.Hour), UpdatedAt: todayStart.Add(-24 * time.Hour), }) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-ul"}) - apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"}) - mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: model.StatusDisabled}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-ul"}) + apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"}) + mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled}) resetAt := now.Add(10 * time.Minute) - accNormal := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-normal", Schedulable: true}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-error", Status: model.StatusError, Schedulable: true}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true}) - mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true}) + accNormal := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-normal", Schedulable: true}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-error", Status: service.StatusError, Schedulable: true}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true}) + mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true}) d1, d2, d3 := 100, 200, 300 - logToday := &model.UsageLog{ + logToday := &service.UsageLog{ UserID: userToday.ID, ApiKeyID: apiKey1.ID, AccountID: accNormal.ID, @@ -233,7 +233,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { } s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday") - logOld := &model.UsageLog{ + logOld := &service.UsageLog{ UserID: userOld.ID, ApiKeyID: apiKey1.ID, AccountID: accNormal.ID, @@ -247,7 +247,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { } s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld") - logPerf := &model.UsageLog{ + logPerf := &service.UsageLog{ UserID: userToday.ID, ApiKeyID: apiKey1.ID, AccountID: accNormal.ID, @@ -293,9 +293,9 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { // --- GetUserDashboardStats --- func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "userdash@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userdash"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "userdash@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userdash", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userdash"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) @@ -308,9 +308,9 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { // --- GetAccountTodayStats --- func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctoday@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-today"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctoday@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctoday", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-today"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) @@ -323,11 +323,11 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { // --- GetBatchUserUsageStats --- func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { - user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch1@test.com"}) - user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch2@test.com"}) - apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"}) - apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batch"}) + user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch1@test.com"}) + user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch2@test.com"}) + apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-batch1", Name: "k"}) + apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-batch2", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batch"}) s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now()) @@ -348,10 +348,10 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() { // --- GetBatchApiKeyUsageStats --- func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "batchkey@test.com"}) - apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"}) - apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batchkey"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "batchkey@test.com"}) + apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"}) + apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batchkey"}) s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) @@ -370,9 +370,9 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() { // --- GetGlobalStats --- func (s *UsageLogRepoSuite) TestGetGlobalStats() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "global@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-global"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "global@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-global", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-global"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) @@ -395,9 +395,9 @@ func maxTime(a, b time.Time) time.Time { // --- ListByUserAndTimeRange --- func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "timerange@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-timerange"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "timerange@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-timerange", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-timerange"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) @@ -414,9 +414,9 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() { // --- ListByApiKeyAndTimeRange --- func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytimerange@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytimerange"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytimerange@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytimerange", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytimerange"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) @@ -433,9 +433,9 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() { // --- ListByAccountAndTimeRange --- func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctimerange@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-acctimerange"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctimerange@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctimerange", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-acctimerange"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) @@ -452,14 +452,14 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() { // --- ListByModelAndTimeRange --- func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "modeltimerange@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modeltimerange"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "modeltimerange@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modeltimerange"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) // Create logs with different models - log1 := &model.UsageLog{ + log1 := &service.UsageLog{ UserID: user.ID, ApiKeyID: apiKey.ID, AccountID: account.ID, @@ -472,7 +472,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { } s.Require().NoError(s.repo.Create(s.ctx, log1)) - log2 := &model.UsageLog{ + log2 := &service.UsageLog{ UserID: user.ID, ApiKeyID: apiKey.ID, AccountID: account.ID, @@ -485,7 +485,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { } s.Require().NoError(s.repo.Create(s.ctx, log2)) - log3 := &model.UsageLog{ + log3 := &service.UsageLog{ UserID: user.ID, ApiKeyID: apiKey.ID, AccountID: account.ID, @@ -508,9 +508,9 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() { // --- GetAccountWindowStats --- func (s *UsageLogRepoSuite) TestGetAccountWindowStats() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "windowstats@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-windowstats"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "windowstats@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-windowstats", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-windowstats"}) now := time.Now() windowStart := now.Add(-10 * time.Minute) @@ -528,9 +528,9 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() { // --- GetUserUsageTrendByUserID --- func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrend"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrend", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrend"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) @@ -545,9 +545,9 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() { } func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrendhourly@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrendhourly"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrendhourly@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrendhourly"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) @@ -564,14 +564,14 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() { // --- GetUserModelStats --- func (s *UsageLogRepoSuite) TestGetUserModelStats() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelstats@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelstats"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelstats@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelstats", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelstats"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) // Create logs with different models - log1 := &model.UsageLog{ + log1 := &service.UsageLog{ UserID: user.ID, ApiKeyID: apiKey.ID, AccountID: account.ID, @@ -584,7 +584,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() { } s.Require().NoError(s.repo.Create(s.ctx, log1)) - log2 := &model.UsageLog{ + log2 := &service.UsageLog{ UserID: user.ID, ApiKeyID: apiKey.ID, AccountID: account.ID, @@ -611,9 +611,9 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() { // --- GetUsageTrendWithFilters --- func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) @@ -639,9 +639,9 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { } func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters-h@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters-h"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters-h@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters-h"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) @@ -658,13 +658,13 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { // --- GetModelStatsWithFilters --- func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelfilters@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelfilters"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelfilters@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelfilters", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelfilters"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) - log1 := &model.UsageLog{ + log1 := &service.UsageLog{ UserID: user.ID, ApiKeyID: apiKey.ID, AccountID: account.ID, @@ -677,7 +677,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { } s.Require().NoError(s.repo.Create(s.ctx, log1)) - log2 := &model.UsageLog{ + log2 := &service.UsageLog{ UserID: user.ID, ApiKeyID: apiKey.ID, AccountID: account.ID, @@ -712,14 +712,14 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { // --- GetAccountUsageStats --- func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "accstats@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-accstats"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "accstats@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-accstats", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-accstats"}) base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC) // Create logs on different days - log1 := &model.UsageLog{ + log1 := &service.UsageLog{ UserID: user.ID, ApiKeyID: apiKey.ID, AccountID: account.ID, @@ -732,7 +732,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { } s.Require().NoError(s.repo.Create(s.ctx, log1)) - log2 := &model.UsageLog{ + log2 := &service.UsageLog{ UserID: user.ID, ApiKeyID: apiKey.ID, AccountID: account.ID, @@ -758,7 +758,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() { } func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() { - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-emptystats"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-emptystats"}) base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC) startTime := base @@ -774,11 +774,11 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() { // --- GetUserUsageTrend --- func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { - user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend1@test.com"}) - user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend2@test.com"}) - apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"}) - apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrends"}) + user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend1@test.com"}) + user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend2@test.com"}) + apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"}) + apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrends"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base) @@ -796,10 +796,10 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { // --- GetApiKeyUsageTrend --- func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrend@test.com"}) - apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"}) - apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrends"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrend@test.com"}) + apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"}) + apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrends"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base) @@ -815,9 +815,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() { } func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrendh@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrendh"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrendh@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrendh", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrendh"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base) @@ -834,9 +834,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() { // --- ListWithFilters (additional filter tests) --- func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterskey@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterskey"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterskey@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterskey", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterskey"}) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now()) @@ -848,9 +848,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() { } func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterstime@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterstime"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterstime@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterstime", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterstime"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) @@ -867,9 +867,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() { } func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterscombined@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"}) - account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterscombined"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterscombined@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterscombined", Name: "k"}) + account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterscombined"}) base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base) diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index c87e3838..37e1e173 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -2,12 +2,13 @@ package repository import ( "context" + "time" "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/lib/pq" "gorm.io/gorm" ) @@ -19,48 +20,56 @@ func NewUserRepository(db *gorm.DB) service.UserRepository { return &userRepository{db: db} } -func (r *userRepository) Create(ctx context.Context, user *model.User) error { - err := r.db.WithContext(ctx).Create(user).Error +func (r *userRepository) Create(ctx context.Context, user *service.User) error { + m := userModelFromService(user) + err := r.db.WithContext(ctx).Create(m).Error + if err == nil { + applyUserModelToService(user, m) + } return translatePersistenceError(err, nil, service.ErrEmailExists) } -func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) { - var user model.User - err := r.db.WithContext(ctx).First(&user, id).Error +func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) { + var m userModel + err := r.db.WithContext(ctx).First(&m, id).Error if err != nil { return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) } - return &user, nil + return userModelToService(&m), nil } -func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) { - var user model.User - err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error +func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) { + var m userModel + err := r.db.WithContext(ctx).Where("email = ?", email).First(&m).Error if err != nil { return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) } - return &user, nil + return userModelToService(&m), nil } -func (r *userRepository) Update(ctx context.Context, user *model.User) error { - err := r.db.WithContext(ctx).Save(user).Error +func (r *userRepository) Update(ctx context.Context, user *service.User) error { + m := userModelFromService(user) + err := r.db.WithContext(ctx).Save(m).Error + if err == nil { + applyUserModelToService(user, m) + } return translatePersistenceError(err, nil, service.ErrEmailExists) } func (r *userRepository) Delete(ctx context.Context, id int64) error { - return r.db.WithContext(ctx).Delete(&model.User{}, id).Error + return r.db.WithContext(ctx).Delete(&userModel{}, id).Error } -func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) { +func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", "") } // ListWithFilters lists users with optional filtering by status, role, and search query -func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) { - var users []model.User +func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) { + var users []userModel var total int64 - db := r.db.WithContext(ctx).Model(&model.User{}) + db := r.db.WithContext(ctx).Model(&userModel{}) // Apply filters if status != "" { @@ -89,17 +98,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. // Batch load subscriptions for all users (avoid N+1) if len(users) > 0 { userIDs := make([]int64, len(users)) - userMap := make(map[int64]*model.User, len(users)) + userMap := make(map[int64]*service.User, len(users)) + outUsers := make([]service.User, 0, len(users)) for i := range users { userIDs[i] = users[i].ID - userMap[users[i].ID] = &users[i] + u := userModelToService(&users[i]) + outUsers = append(outUsers, *u) + userMap[u.ID] = &outUsers[len(outUsers)-1] } // Query active subscriptions with groups in one query - var subscriptions []model.UserSubscription + var subscriptions []userSubscriptionModel if err := r.db.WithContext(ctx). Preload("Group"). - Where("user_id IN ? AND status = ?", userIDs, model.SubscriptionStatusActive). + Where("user_id IN ? AND status = ?", userIDs, service.SubscriptionStatusActive). Find(&subscriptions).Error; err != nil { return nil, nil, err } @@ -107,32 +119,29 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. // Associate subscriptions with users for i := range subscriptions { if user, ok := userMap[subscriptions[i].UserID]; ok { - user.Subscriptions = append(user.Subscriptions, subscriptions[i]) + user.Subscriptions = append(user.Subscriptions, *userSubscriptionModelToService(&subscriptions[i])) } } + + return outUsers, paginationResultFromTotal(total, params), nil } - pages := int(total) / params.Limit() - if int(total)%params.Limit() > 0 { - pages++ + outUsers := make([]service.User, 0, len(users)) + for i := range users { + outUsers = append(outUsers, *userModelToService(&users[i])) } - return users, &pagination.PaginationResult{ - Total: total, - Page: params.Page, - PageSize: params.Limit(), - Pages: pages, - }, nil + return outUsers, paginationResultFromTotal(total, params), nil } func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { - return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). + return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id). Update("balance", gorm.Expr("balance + ?", amount)).Error } // DeductBalance 扣减用户余额,仅当余额充足时执行 func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { - result := r.db.WithContext(ctx).Model(&model.User{}). + result := r.db.WithContext(ctx).Model(&userModel{}). Where("id = ? AND balance >= ?", id, amount). Update("balance", gorm.Expr("balance - ?", amount)) if result.Error != nil { @@ -145,34 +154,104 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo } func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error { - return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). + return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id). Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error } func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { var count int64 - err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error + err := r.db.WithContext(ctx).Model(&userModel{}).Where("email = ?", email).Count(&count).Error return count > 0, err } // RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID // 使用 PostgreSQL 的 array_remove 函数 func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { - result := r.db.WithContext(ctx).Model(&model.User{}). + result := r.db.WithContext(ctx).Model(&userModel{}). Where("? = ANY(allowed_groups)", groupID). Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID)) return result.RowsAffected, result.Error } // GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证) -func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) { - var user model.User +func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) { + var m userModel err := r.db.WithContext(ctx). - Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive). + Where("role = ? AND status = ?", service.RoleAdmin, service.StatusActive). Order("id ASC"). - First(&user).Error + First(&m).Error if err != nil { return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) } - return &user, nil + return userModelToService(&m), nil +} + +type userModel struct { + ID int64 `gorm:"primaryKey"` + Email string `gorm:"uniqueIndex;size:255;not null"` + Username string `gorm:"size:100;default:''"` + Wechat string `gorm:"size:100;default:''"` + Notes string `gorm:"type:text;default:''"` + PasswordHash string `gorm:"size:255;not null"` + Role string `gorm:"size:20;default:user;not null"` + Balance float64 `gorm:"type:decimal(20,8);default:0;not null"` + Concurrency int `gorm:"default:5;not null"` + Status string `gorm:"size:20;default:active;not null"` + AllowedGroups pq.Int64Array `gorm:"type:bigint[]"` + CreatedAt time.Time `gorm:"not null"` + UpdatedAt time.Time `gorm:"not null"` + DeletedAt gorm.DeletedAt `gorm:"index"` +} + +func (userModel) TableName() string { return "users" } + +func userModelToService(m *userModel) *service.User { + if m == nil { + return nil + } + return &service.User{ + ID: m.ID, + Email: m.Email, + Username: m.Username, + Wechat: m.Wechat, + Notes: m.Notes, + PasswordHash: m.PasswordHash, + Role: m.Role, + Balance: m.Balance, + Concurrency: m.Concurrency, + Status: m.Status, + AllowedGroups: []int64(m.AllowedGroups), + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + } +} + +func userModelFromService(u *service.User) *userModel { + if u == nil { + return nil + } + return &userModel{ + ID: u.ID, + Email: u.Email, + Username: u.Username, + Wechat: u.Wechat, + Notes: u.Notes, + PasswordHash: u.PasswordHash, + Role: u.Role, + Balance: u.Balance, + Concurrency: u.Concurrency, + Status: u.Status, + AllowedGroups: pq.Int64Array(u.AllowedGroups), + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, + } +} + +func applyUserModelToService(dst *service.User, src *userModel) { + if dst == nil || src == nil { + return + } + dst.ID = src.ID + dst.CreatedAt = src.CreatedAt + dst.UpdatedAt = src.UpdatedAt } diff --git a/backend/internal/repository/user_repo_integration_test.go b/backend/internal/repository/user_repo_integration_test.go index 020e2e32..cd5254ee 100644 --- a/backend/internal/repository/user_repo_integration_test.go +++ b/backend/internal/repository/user_repo_integration_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" @@ -35,11 +34,12 @@ func TestUserRepoSuite(t *testing.T) { // --- Create / GetByID / GetByEmail / Update / Delete --- func (s *UserRepoSuite) TestCreate() { - user := &model.User{ - Email: "create@test.com", - Username: "testuser", - Role: model.RoleUser, - Status: model.StatusActive, + user := &service.User{ + Email: "create@test.com", + Username: "testuser", + PasswordHash: "test-password-hash", + Role: service.RoleUser, + Status: service.StatusActive, } err := s.repo.Create(s.ctx, user) @@ -57,7 +57,7 @@ func (s *UserRepoSuite) TestGetByID_NotFound() { } func (s *UserRepoSuite) TestGetByEmail() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "byemail@test.com"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "byemail@test.com"}) got, err := s.repo.GetByEmail(s.ctx, user.Email) s.Require().NoError(err, "GetByEmail") @@ -70,7 +70,7 @@ func (s *UserRepoSuite) TestGetByEmail_NotFound() { } func (s *UserRepoSuite) TestUpdate() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com", Username: "original"}) + user := userModelToService(mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com", Username: "original"})) user.Username = "updated" err := s.repo.Update(s.ctx, user) @@ -82,7 +82,7 @@ func (s *UserRepoSuite) TestUpdate() { } func (s *UserRepoSuite) TestDelete() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"}) err := s.repo.Delete(s.ctx, user.ID) s.Require().NoError(err, "Delete") @@ -94,8 +94,8 @@ func (s *UserRepoSuite) TestDelete() { // --- List / ListWithFilters --- func (s *UserRepoSuite) TestList() { - mustCreateUser(s.T(), s.db, &model.User{Email: "list1@test.com"}) - mustCreateUser(s.T(), s.db, &model.User{Email: "list2@test.com"}) + mustCreateUser(s.T(), s.db, &userModel{Email: "list1@test.com"}) + mustCreateUser(s.T(), s.db, &userModel{Email: "list2@test.com"}) users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) s.Require().NoError(err, "List") @@ -104,28 +104,28 @@ func (s *UserRepoSuite) TestList() { } func (s *UserRepoSuite) TestListWithFilters_Status() { - mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com", Status: model.StatusActive}) - mustCreateUser(s.T(), s.db, &model.User{Email: "disabled@test.com", Status: model.StatusDisabled}) + mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com", Status: service.StatusActive}) + mustCreateUser(s.T(), s.db, &userModel{Email: "disabled@test.com", Status: service.StatusDisabled}) - users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, "", "") + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "") s.Require().NoError(err) s.Require().Len(users, 1) - s.Require().Equal(model.StatusActive, users[0].Status) + s.Require().Equal(service.StatusActive, users[0].Status) } func (s *UserRepoSuite) TestListWithFilters_Role() { - mustCreateUser(s.T(), s.db, &model.User{Email: "user@test.com", Role: model.RoleUser}) - mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin}) + mustCreateUser(s.T(), s.db, &userModel{Email: "user@test.com", Role: service.RoleUser}) + mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin}) - users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.RoleAdmin, "") + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "") s.Require().NoError(err) s.Require().Len(users, 1) - s.Require().Equal(model.RoleAdmin, users[0].Role) + s.Require().Equal(service.RoleAdmin, users[0].Role) } func (s *UserRepoSuite) TestListWithFilters_Search() { - mustCreateUser(s.T(), s.db, &model.User{Email: "alice@test.com", Username: "Alice"}) - mustCreateUser(s.T(), s.db, &model.User{Email: "bob@test.com", Username: "Bob"}) + mustCreateUser(s.T(), s.db, &userModel{Email: "alice@test.com", Username: "Alice"}) + mustCreateUser(s.T(), s.db, &userModel{Email: "bob@test.com", Username: "Bob"}) users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice") s.Require().NoError(err) @@ -134,8 +134,8 @@ func (s *UserRepoSuite) TestListWithFilters_Search() { } func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() { - mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com", Username: "JohnDoe"}) - mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com", Username: "JaneSmith"}) + mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com", Username: "JohnDoe"}) + mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com", Username: "JaneSmith"}) users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john") s.Require().NoError(err) @@ -144,8 +144,8 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() { } func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() { - mustCreateUser(s.T(), s.db, &model.User{Email: "w1@test.com", Wechat: "wx_hello"}) - mustCreateUser(s.T(), s.db, &model.User{Email: "w2@test.com", Wechat: "wx_world"}) + mustCreateUser(s.T(), s.db, &userModel{Email: "w1@test.com", Wechat: "wx_hello"}) + mustCreateUser(s.T(), s.db, &userModel{Email: "w2@test.com", Wechat: "wx_world"}) users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello") s.Require().NoError(err) @@ -154,19 +154,19 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() { } func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub@test.com", Status: model.StatusActive}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sub"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub@test.com", Status: service.StatusActive}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sub"}) - _ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + _ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(1 * time.Hour), }) - _ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + _ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusExpired, + Status: service.SubscriptionStatusExpired, ExpiresAt: time.Now().Add(-1 * time.Hour), }) @@ -179,29 +179,29 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() { } func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() { - mustCreateUser(s.T(), s.db, &model.User{ + mustCreateUser(s.T(), s.db, &userModel{ Email: "a@example.com", Username: "Alice", Wechat: "wx_a", - Role: model.RoleUser, - Status: model.StatusActive, + Role: service.RoleUser, + Status: service.StatusActive, Balance: 10, }) - target := mustCreateUser(s.T(), s.db, &model.User{ + target := mustCreateUser(s.T(), s.db, &userModel{ Email: "b@example.com", Username: "Bob", Wechat: "wx_b", - Role: model.RoleAdmin, - Status: model.StatusActive, + Role: service.RoleAdmin, + Status: service.StatusActive, Balance: 1, }) - mustCreateUser(s.T(), s.db, &model.User{ + mustCreateUser(s.T(), s.db, &userModel{ Email: "c@example.com", - Role: model.RoleAdmin, - Status: model.StatusDisabled, + Role: service.RoleAdmin, + Status: service.StatusDisabled, }) - users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, model.RoleAdmin, "b@") + users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, service.RoleAdmin, "b@") s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") s.Require().Len(users, 1, "ListWithFilters len mismatch") @@ -211,7 +211,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() { // --- Balance operations --- func (s *UserRepoSuite) TestUpdateBalance() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "bal@test.com", Balance: 10}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "bal@test.com", Balance: 10}) err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5) s.Require().NoError(err, "UpdateBalance") @@ -222,7 +222,7 @@ func (s *UserRepoSuite) TestUpdateBalance() { } func (s *UserRepoSuite) TestUpdateBalance_Negative() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "balneg@test.com", Balance: 10}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "balneg@test.com", Balance: 10}) err := s.repo.UpdateBalance(s.ctx, user.ID, -3) s.Require().NoError(err, "UpdateBalance with negative") @@ -233,7 +233,7 @@ func (s *UserRepoSuite) TestUpdateBalance_Negative() { } func (s *UserRepoSuite) TestDeductBalance() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "deduct@test.com", Balance: 10}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "deduct@test.com", Balance: 10}) err := s.repo.DeductBalance(s.ctx, user.ID, 5) s.Require().NoError(err, "DeductBalance") @@ -244,7 +244,7 @@ func (s *UserRepoSuite) TestDeductBalance() { } func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "insuf@test.com", Balance: 5}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "insuf@test.com", Balance: 5}) err := s.repo.DeductBalance(s.ctx, user.ID, 999) s.Require().Error(err, "expected error for insufficient balance") @@ -252,7 +252,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() { } func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "exact@test.com", Balance: 10}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "exact@test.com", Balance: 10}) err := s.repo.DeductBalance(s.ctx, user.ID, 10) s.Require().NoError(err, "DeductBalance exact amount") @@ -265,7 +265,7 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { // --- Concurrency --- func (s *UserRepoSuite) TestUpdateConcurrency() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "conc@test.com", Concurrency: 5}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "conc@test.com", Concurrency: 5}) err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3) s.Require().NoError(err, "UpdateConcurrency") @@ -276,7 +276,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency() { } func (s *UserRepoSuite) TestUpdateConcurrency_Negative() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "concneg@test.com", Concurrency: 5}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "concneg@test.com", Concurrency: 5}) err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2) s.Require().NoError(err, "UpdateConcurrency negative") @@ -289,7 +289,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency_Negative() { // --- ExistsByEmail --- func (s *UserRepoSuite) TestExistsByEmail() { - mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) + mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"}) exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com") s.Require().NoError(err, "ExistsByEmail") @@ -304,11 +304,11 @@ func (s *UserRepoSuite) TestExistsByEmail() { func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() { groupID := int64(42) - userA := mustCreateUser(s.T(), s.db, &model.User{ + userA := mustCreateUser(s.T(), s.db, &userModel{ Email: "a1@example.com", AllowedGroups: pq.Int64Array{groupID, 7}, }) - mustCreateUser(s.T(), s.db, &model.User{ + mustCreateUser(s.T(), s.db, &userModel{ Email: "a2@example.com", AllowedGroups: pq.Int64Array{7}, }) @@ -325,7 +325,7 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() { } func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() { - mustCreateUser(s.T(), s.db, &model.User{ + mustCreateUser(s.T(), s.db, &userModel{ Email: "nomatch@test.com", AllowedGroups: pq.Int64Array{1, 2, 3}, }) @@ -338,15 +338,15 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() { // --- GetFirstAdmin --- func (s *UserRepoSuite) TestGetFirstAdmin() { - admin1 := mustCreateUser(s.T(), s.db, &model.User{ + admin1 := mustCreateUser(s.T(), s.db, &userModel{ Email: "admin1@example.com", - Role: model.RoleAdmin, - Status: model.StatusActive, + Role: service.RoleAdmin, + Status: service.StatusActive, }) - mustCreateUser(s.T(), s.db, &model.User{ + mustCreateUser(s.T(), s.db, &userModel{ Email: "admin2@example.com", - Role: model.RoleAdmin, - Status: model.StatusActive, + Role: service.RoleAdmin, + Status: service.StatusActive, }) got, err := s.repo.GetFirstAdmin(s.ctx) @@ -355,10 +355,10 @@ func (s *UserRepoSuite) TestGetFirstAdmin() { } func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() { - mustCreateUser(s.T(), s.db, &model.User{ + mustCreateUser(s.T(), s.db, &userModel{ Email: "user@example.com", - Role: model.RoleUser, - Status: model.StatusActive, + Role: service.RoleUser, + Status: service.StatusActive, }) _, err := s.repo.GetFirstAdmin(s.ctx) @@ -366,15 +366,15 @@ func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() { } func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() { - mustCreateUser(s.T(), s.db, &model.User{ + mustCreateUser(s.T(), s.db, &userModel{ Email: "disabled@example.com", - Role: model.RoleAdmin, - Status: model.StatusDisabled, + Role: service.RoleAdmin, + Status: service.StatusDisabled, }) - activeAdmin := mustCreateUser(s.T(), s.db, &model.User{ + activeAdmin := mustCreateUser(s.T(), s.db, &userModel{ Email: "active@example.com", - Role: model.RoleAdmin, - Status: model.StatusActive, + Role: service.RoleAdmin, + Status: service.StatusActive, }) got, err := s.repo.GetFirstAdmin(s.ctx) @@ -385,26 +385,26 @@ func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() { // --- Combined original test --- func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { - user1 := mustCreateUser(s.T(), s.db, &model.User{ + user1 := mustCreateUser(s.T(), s.db, &userModel{ Email: "a@example.com", Username: "Alice", Wechat: "wx_a", - Role: model.RoleUser, - Status: model.StatusActive, + Role: service.RoleUser, + Status: service.StatusActive, Balance: 10, }) - user2 := mustCreateUser(s.T(), s.db, &model.User{ + user2 := mustCreateUser(s.T(), s.db, &userModel{ Email: "b@example.com", Username: "Bob", Wechat: "wx_b", - Role: model.RoleAdmin, - Status: model.StatusActive, + Role: service.RoleAdmin, + Status: service.StatusActive, Balance: 1, }) - _ = mustCreateUser(s.T(), s.db, &model.User{ + _ = mustCreateUser(s.T(), s.db, &userModel{ Email: "c@example.com", - Role: model.RoleAdmin, - Status: model.StatusDisabled, + Role: service.RoleAdmin, + Status: service.StatusDisabled, }) got, err := s.repo.GetByID(s.ctx, user1.ID) @@ -441,7 +441,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch") params := pagination.PaginationParams{Page: 1, PageSize: 10} - users, page, err := s.repo.ListWithFilters(s.ctx, params, model.StatusActive, model.RoleAdmin, "b@") + users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@") s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch") s.Require().Len(users, 1, "ListWithFilters len mismatch") diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go index 7fea8fb0..4c7768a8 100644 --- a/backend/internal/repository/user_subscription_repo.go +++ b/backend/internal/repository/user_subscription_repo.go @@ -4,111 +4,113 @@ import ( "context" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "gorm.io/gorm" ) -// UserSubscriptionRepository 用户订阅仓库 type userSubscriptionRepository struct { db *gorm.DB } -// NewUserSubscriptionRepository 创建用户订阅仓库 func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository { return &userSubscriptionRepository{db: db} } -// Create 创建订阅 -func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error { - err := r.db.WithContext(ctx).Create(sub).Error +func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error { + m := userSubscriptionModelFromService(sub) + err := r.db.WithContext(ctx).Create(m).Error + if err == nil { + applyUserSubscriptionModelToService(sub, m) + } return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists) } -// GetByID 根据ID获取订阅 -func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) { - var sub model.UserSubscription +func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) { + var m userSubscriptionModel err := r.db.WithContext(ctx). Preload("User"). Preload("Group"). Preload("AssignedByUser"). - First(&sub, id).Error + First(&m, id).Error if err != nil { return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } - return &sub, nil + return userSubscriptionModelToService(&m), nil } -// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅 -func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { - var sub model.UserSubscription +func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + var m userSubscriptionModel err := r.db.WithContext(ctx). Preload("Group"). Where("user_id = ? AND group_id = ?", userID, groupID). - First(&sub).Error + First(&m).Error if err != nil { return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } - return &sub, nil + return userSubscriptionModelToService(&m), nil } -// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅 -func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { - var sub model.UserSubscription +func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + var m userSubscriptionModel err := r.db.WithContext(ctx). Preload("Group"). Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?", - userID, groupID, model.SubscriptionStatusActive, time.Now()). - First(&sub).Error + userID, groupID, service.SubscriptionStatusActive, time.Now()). + First(&m).Error if err != nil { return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } - return &sub, nil + return userSubscriptionModelToService(&m), nil } -// Update 更新订阅 -func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error { +func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error { sub.UpdatedAt = time.Now() - return r.db.WithContext(ctx).Save(sub).Error + m := userSubscriptionModelFromService(sub) + err := r.db.WithContext(ctx).Save(m).Error + if err == nil { + applyUserSubscriptionModelToService(sub, m) + } + return err } -// Delete 删除订阅 func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error { - return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error + return r.db.WithContext(ctx).Delete(&userSubscriptionModel{}, id).Error } -// ListByUserID 获取用户的所有订阅 -func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { - var subs []model.UserSubscription +func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + var subs []userSubscriptionModel err := r.db.WithContext(ctx). Preload("Group"). Where("user_id = ?", userID). Order("created_at DESC"). Find(&subs).Error - return subs, err + if err != nil { + return nil, err + } + return userSubscriptionModelsToService(subs), nil } -// ListActiveByUserID 获取用户的所有有效订阅 -func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { - var subs []model.UserSubscription +func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + var subs []userSubscriptionModel err := r.db.WithContext(ctx). Preload("Group"). Where("user_id = ? AND status = ? AND expires_at > ?", - userID, model.SubscriptionStatusActive, time.Now()). + userID, service.SubscriptionStatusActive, time.Now()). Order("created_at DESC"). Find(&subs).Error - return subs, err + if err != nil { + return nil, err + } + return userSubscriptionModelsToService(subs), nil } -// ListByGroupID 获取分组的所有订阅(分页) -func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) { - var subs []model.UserSubscription +func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { + var subs []userSubscriptionModel var total int64 - query := r.db.WithContext(ctx).Model(&model.UserSubscription{}).Where("group_id = ?", groupID) - + query := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).Where("group_id = ?", groupID) if err := query.Count(&total).Error; err != nil { return nil, nil, err } @@ -124,26 +126,14 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID return nil, nil, err } - pages := int(total) / params.Limit() - if int(total)%params.Limit() > 0 { - pages++ - } - - return subs, &pagination.PaginationResult{ - Total: total, - Page: params.Page, - PageSize: params.Limit(), - Pages: pages, - }, nil + return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil } -// List 获取所有订阅(分页,支持筛选) -func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) { - var subs []model.UserSubscription +func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { + var subs []userSubscriptionModel var total int64 - query := r.db.WithContext(ctx).Model(&model.UserSubscription{}) - + query := r.db.WithContext(ctx).Model(&userSubscriptionModel{}) if userID != nil { query = query.Where("user_id = ?", *userID) } @@ -170,22 +160,87 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination return nil, nil, err } - pages := int(total) / params.Limit() - if int(total)%params.Limit() > 0 { - pages++ - } - - return subs, &pagination.PaginationResult{ - Total: total, - Page: params.Page, - PageSize: params.Limit(), - Pages: pages, - }, nil + return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil +} + +func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { + var count int64 + err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}). + Where("user_id = ? AND group_id = ?", userID, groupID). + Count(&count).Error + return count > 0, err +} + +func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error { + return r.db.WithContext(ctx).Model(&userSubscriptionModel{}). + Where("id = ?", subscriptionID). + Updates(map[string]any{ + "expires_at": newExpiresAt, + "updated_at": time.Now(), + }).Error +} + +func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error { + return r.db.WithContext(ctx).Model(&userSubscriptionModel{}). + Where("id = ?", subscriptionID). + Updates(map[string]any{ + "status": status, + "updated_at": time.Now(), + }).Error +} + +func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error { + return r.db.WithContext(ctx).Model(&userSubscriptionModel{}). + Where("id = ?", subscriptionID). + Updates(map[string]any{ + "notes": notes, + "updated_at": time.Now(), + }).Error +} + +func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error { + return r.db.WithContext(ctx).Model(&userSubscriptionModel{}). + Where("id = ?", id). + Updates(map[string]any{ + "daily_window_start": start, + "weekly_window_start": start, + "monthly_window_start": start, + "updated_at": time.Now(), + }).Error +} + +func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + return r.db.WithContext(ctx).Model(&userSubscriptionModel{}). + Where("id = ?", id). + Updates(map[string]any{ + "daily_usage_usd": 0, + "daily_window_start": newWindowStart, + "updated_at": time.Now(), + }).Error +} + +func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + return r.db.WithContext(ctx).Model(&userSubscriptionModel{}). + Where("id = ?", id). + Updates(map[string]any{ + "weekly_usage_usd": 0, + "weekly_window_start": newWindowStart, + "updated_at": time.Now(), + }).Error +} + +func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + return r.db.WithContext(ctx).Model(&userSubscriptionModel{}). + Where("id = ?", id). + Updates(map[string]any{ + "monthly_usage_usd": 0, + "monthly_window_start": newWindowStart, + "updated_at": time.Now(), + }).Error } -// IncrementUsage 增加使用量 func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { - return r.db.WithContext(ctx).Model(&model.UserSubscription{}). + return r.db.WithContext(ctx).Model(&userSubscriptionModel{}). Where("id = ?", id). Updates(map[string]any{ "daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD), @@ -195,131 +250,150 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6 }).Error } -// ResetDailyUsage 重置日使用量 -func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { - return r.db.WithContext(ctx).Model(&model.UserSubscription{}). - Where("id = ?", id). - Updates(map[string]any{ - "daily_usage_usd": 0, - "daily_window_start": newWindowStart, - "updated_at": time.Now(), - }).Error -} - -// ResetWeeklyUsage 重置周使用量 -func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { - return r.db.WithContext(ctx).Model(&model.UserSubscription{}). - Where("id = ?", id). - Updates(map[string]any{ - "weekly_usage_usd": 0, - "weekly_window_start": newWindowStart, - "updated_at": time.Now(), - }).Error -} - -// ResetMonthlyUsage 重置月使用量 -func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { - return r.db.WithContext(ctx).Model(&model.UserSubscription{}). - Where("id = ?", id). - Updates(map[string]any{ - "monthly_usage_usd": 0, - "monthly_window_start": newWindowStart, - "updated_at": time.Now(), - }).Error -} - -// ActivateWindows 激活所有窗口(首次使用时) -func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error { - return r.db.WithContext(ctx).Model(&model.UserSubscription{}). - Where("id = ?", id). - Updates(map[string]any{ - "daily_window_start": activateTime, - "weekly_window_start": activateTime, - "monthly_window_start": activateTime, - "updated_at": time.Now(), - }).Error -} - -// UpdateStatus 更新订阅状态 -func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error { - return r.db.WithContext(ctx).Model(&model.UserSubscription{}). - Where("id = ?", id). - Updates(map[string]any{ - "status": status, - "updated_at": time.Now(), - }).Error -} - -// ExtendExpiry 延长订阅过期时间 -func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error { - return r.db.WithContext(ctx).Model(&model.UserSubscription{}). - Where("id = ?", id). - Updates(map[string]any{ - "expires_at": newExpiresAt, - "updated_at": time.Now(), - }).Error -} - -// UpdateNotes 更新订阅备注 -func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error { - return r.db.WithContext(ctx).Model(&model.UserSubscription{}). - Where("id = ?", id). - Updates(map[string]any{ - "notes": notes, - "updated_at": time.Now(), - }).Error -} - -// ListExpired 获取所有已过期但状态仍为active的订阅 -func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) { - var subs []model.UserSubscription - err := r.db.WithContext(ctx). - Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). - Find(&subs).Error - return subs, err -} - -// BatchUpdateExpiredStatus 批量更新过期订阅状态 func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { - result := r.db.WithContext(ctx).Model(&model.UserSubscription{}). - Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). + result := r.db.WithContext(ctx).Model(&userSubscriptionModel{}). + Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()). Updates(map[string]any{ - "status": model.SubscriptionStatusExpired, + "status": service.SubscriptionStatusExpired, "updated_at": time.Now(), }) return result.RowsAffected, result.Error } -// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅 -func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { - var count int64 - err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). - Where("user_id = ? AND group_id = ?", userID, groupID). - Count(&count).Error - return count > 0, err +// Extra repository helpers (currently used only by integration tests). + +func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) { + var subs []userSubscriptionModel + err := r.db.WithContext(ctx). + Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()). + Find(&subs).Error + if err != nil { + return nil, err + } + return userSubscriptionModelsToService(subs), nil } -// CountByGroupID 获取分组的订阅数量 func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { var count int64 - err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). + err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}). Where("group_id = ?", groupID). Count(&count).Error return count, err } -// CountActiveByGroupID 获取分组的有效订阅数量 func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) { var count int64 - err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). + err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}). Where("group_id = ? AND status = ? AND expires_at > ?", - groupID, model.SubscriptionStatusActive, time.Now()). + groupID, service.SubscriptionStatusActive, time.Now()). Count(&count).Error return count, err } -// DeleteByGroupID 删除分组相关的所有订阅记录 func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) { - result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{}) + result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&userSubscriptionModel{}) return result.RowsAffected, result.Error } + +type userSubscriptionModel struct { + ID int64 `gorm:"primaryKey"` + UserID int64 `gorm:"index;not null"` + GroupID int64 `gorm:"index;not null"` + + StartsAt time.Time `gorm:"not null"` + ExpiresAt time.Time `gorm:"not null"` + Status string `gorm:"size:20;default:active;not null"` + + DailyWindowStart *time.Time + WeeklyWindowStart *time.Time + MonthlyWindowStart *time.Time + + DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"` + WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"` + MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"` + + AssignedBy *int64 `gorm:"index"` + AssignedAt time.Time `gorm:"not null"` + Notes string `gorm:"type:text"` + + CreatedAt time.Time `gorm:"not null"` + UpdatedAt time.Time `gorm:"not null"` + + User *userModel `gorm:"foreignKey:UserID"` + Group *groupModel `gorm:"foreignKey:GroupID"` + AssignedByUser *userModel `gorm:"foreignKey:AssignedBy"` +} + +func (userSubscriptionModel) TableName() string { return "user_subscriptions" } + +func userSubscriptionModelToService(m *userSubscriptionModel) *service.UserSubscription { + if m == nil { + return nil + } + return &service.UserSubscription{ + ID: m.ID, + UserID: m.UserID, + GroupID: m.GroupID, + StartsAt: m.StartsAt, + ExpiresAt: m.ExpiresAt, + Status: m.Status, + DailyWindowStart: m.DailyWindowStart, + WeeklyWindowStart: m.WeeklyWindowStart, + MonthlyWindowStart: m.MonthlyWindowStart, + DailyUsageUSD: m.DailyUsageUSD, + WeeklyUsageUSD: m.WeeklyUsageUSD, + MonthlyUsageUSD: m.MonthlyUsageUSD, + AssignedBy: m.AssignedBy, + AssignedAt: m.AssignedAt, + Notes: m.Notes, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + User: userModelToService(m.User), + Group: groupModelToService(m.Group), + AssignedByUser: userModelToService(m.AssignedByUser), + } +} + +func userSubscriptionModelsToService(models []userSubscriptionModel) []service.UserSubscription { + out := make([]service.UserSubscription, 0, len(models)) + for i := range models { + if s := userSubscriptionModelToService(&models[i]); s != nil { + out = append(out, *s) + } + } + return out +} + +func userSubscriptionModelFromService(s *service.UserSubscription) *userSubscriptionModel { + if s == nil { + return nil + } + return &userSubscriptionModel{ + ID: s.ID, + UserID: s.UserID, + GroupID: s.GroupID, + StartsAt: s.StartsAt, + ExpiresAt: s.ExpiresAt, + Status: s.Status, + DailyWindowStart: s.DailyWindowStart, + WeeklyWindowStart: s.WeeklyWindowStart, + MonthlyWindowStart: s.MonthlyWindowStart, + DailyUsageUSD: s.DailyUsageUSD, + WeeklyUsageUSD: s.WeeklyUsageUSD, + MonthlyUsageUSD: s.MonthlyUsageUSD, + AssignedBy: s.AssignedBy, + AssignedAt: s.AssignedAt, + Notes: s.Notes, + CreatedAt: s.CreatedAt, + UpdatedAt: s.UpdatedAt, + } +} + +func applyUserSubscriptionModelToService(sub *service.UserSubscription, m *userSubscriptionModel) { + if sub == nil || m == nil { + return + } + sub.ID = m.ID + sub.CreatedAt = m.CreatedAt + sub.UpdatedAt = m.UpdatedAt +} diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go index e6c1c850..f990b802 100644 --- a/backend/internal/repository/user_subscription_repo_integration_test.go +++ b/backend/internal/repository/user_subscription_repo_integration_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" "gorm.io/gorm" ) @@ -33,13 +33,13 @@ func TestUserSubscriptionRepoSuite(t *testing.T) { // --- Create / GetByID / Update / Delete --- func (s *UserSubscriptionRepoSuite) TestCreate() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub-create@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-create"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub-create@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-create"}) - sub := &model.UserSubscription{ + sub := &service.UserSubscription{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), } @@ -54,14 +54,14 @@ func (s *UserSubscriptionRepoSuite) TestCreate() { } func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "preload@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"}) - admin := mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "preload@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"}) + admin := mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin}) - sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), AssignedBy: &admin.ID, }) @@ -82,14 +82,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() { } func (s *UserSubscriptionRepoSuite) TestUpdate() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-update"}) - sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-update"}) + sub := userSubscriptionModelToService(mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), - }) + })) sub.Notes = "updated notes" err := s.repo.Update(s.ctx, sub) @@ -101,12 +101,12 @@ func (s *UserSubscriptionRepoSuite) TestUpdate() { } func (s *UserSubscriptionRepoSuite) TestDelete() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delete"}) - sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delete"}) + sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -120,12 +120,12 @@ func (s *UserSubscriptionRepoSuite) TestDelete() { // --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID --- func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "byuser@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-byuser"}) - sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "byuser@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-byuser"}) + sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -141,14 +141,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() { } func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-active"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-active"}) // Create active subscription (future expiry) - active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(2 * time.Hour), }) @@ -158,14 +158,14 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() { } func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "expired@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-expired"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "expired@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-expired"}) // Create expired subscription (past expiry but active status) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(-2 * time.Hour), }) @@ -176,20 +176,20 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnor // --- ListByUserID / ListActiveByUserID --- func (s *UserSubscriptionRepoSuite) TestListByUserID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"}) - g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list1"}) - g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list2"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"}) + g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list1"}) + g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list2"}) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: g1.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: g2.ID, - Status: model.SubscriptionStatusExpired, + Status: service.SubscriptionStatusExpired, ExpiresAt: time.Now().Add(-24 * time.Hour), }) @@ -202,46 +202,46 @@ func (s *UserSubscriptionRepoSuite) TestListByUserID() { } func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "listactive@test.com"}) - g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act1"}) - g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act2"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "listactive@test.com"}) + g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act1"}) + g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act2"}) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: g1.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: g2.ID, - Status: model.SubscriptionStatusExpired, + Status: service.SubscriptionStatusExpired, ExpiresAt: time.Now().Add(-24 * time.Hour), }) subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID) s.Require().NoError(err, "ListActiveByUserID") s.Require().Len(subs, 1) - s.Require().Equal(model.SubscriptionStatusActive, subs[0].Status) + s.Require().Equal(service.SubscriptionStatusActive, subs[0].Status) } // --- ListByGroupID --- func (s *UserSubscriptionRepoSuite) TestListByGroupID() { - user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com"}) - user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listgrp"}) + user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com"}) + user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listgrp"}) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user1.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user2.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -258,13 +258,13 @@ func (s *UserSubscriptionRepoSuite) TestListByGroupID() { // --- List with filters --- func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "list@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "list@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"}) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -275,20 +275,20 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { } func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { - user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter1@test.com"}) - user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter2@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-filter"}) + user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter1@test.com"}) + user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter2@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-filter"}) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user1.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user2.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -299,20 +299,20 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { } func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "grpfilter@test.com"}) - g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f1"}) - g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f2"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "grpfilter@test.com"}) + g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f1"}) + g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f2"}) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: g1.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: g2.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -323,37 +323,37 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { } func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "statfilter@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-stat"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "statfilter@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-stat"}) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusExpired, + Status: service.SubscriptionStatusExpired, ExpiresAt: time.Now().Add(-24 * time.Hour), }) - subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, model.SubscriptionStatusExpired) + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired) s.Require().NoError(err) s.Require().Len(subs, 1) - s.Require().Equal(model.SubscriptionStatusExpired, subs[0].Status) + s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status) } // --- Usage tracking --- func (s *UserSubscriptionRepoSuite) TestIncrementUsage() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "usage@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-usage"}) - sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "usage@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-usage"}) + sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -368,12 +368,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage() { } func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "accum@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-accum"}) - sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "accum@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-accum"}) + sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -386,12 +386,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() { } func (s *UserSubscriptionRepoSuite) TestActivateWindows() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "activate@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-activate"}) - sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "activate@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-activate"}) + sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -408,12 +408,12 @@ func (s *UserSubscriptionRepoSuite) TestActivateWindows() { } func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetd@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetd"}) - sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetd@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetd"}) + sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), DailyUsageUSD: 10.0, WeeklyUsageUSD: 20.0, @@ -431,12 +431,12 @@ func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() { } func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetw@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetw"}) - sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetw@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetw"}) + sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), WeeklyUsageUSD: 15.0, MonthlyUsageUSD: 30.0, @@ -454,12 +454,12 @@ func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() { } func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetm@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetm"}) - sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetm@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetm"}) + sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), MonthlyUsageUSD: 100.0, }) @@ -477,30 +477,30 @@ func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() { // --- UpdateStatus / ExtendExpiry / UpdateNotes --- func (s *UserSubscriptionRepoSuite) TestUpdateStatus() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "status@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-status"}) - sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "status@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-status"}) + sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) - err := s.repo.UpdateStatus(s.ctx, sub.ID, model.SubscriptionStatusExpired) + err := s.repo.UpdateStatus(s.ctx, sub.ID, service.SubscriptionStatusExpired) s.Require().NoError(err, "UpdateStatus") got, err := s.repo.GetByID(s.ctx, sub.ID) s.Require().NoError(err) - s.Require().Equal(model.SubscriptionStatusExpired, got.Status) + s.Require().Equal(service.SubscriptionStatusExpired, got.Status) } func (s *UserSubscriptionRepoSuite) TestExtendExpiry() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "extend@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-extend"}) - sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "extend@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-extend"}) + sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -514,12 +514,12 @@ func (s *UserSubscriptionRepoSuite) TestExtendExpiry() { } func (s *UserSubscriptionRepoSuite) TestUpdateNotes() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "notes@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-notes"}) - sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + user := mustCreateUser(s.T(), s.db, &userModel{Email: "notes@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-notes"}) + sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -534,19 +534,19 @@ func (s *UserSubscriptionRepoSuite) TestUpdateNotes() { // --- ListExpired / BatchUpdateExpiredStatus --- func (s *UserSubscriptionRepoSuite) TestListExpired() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "listexp@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listexp"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "listexp@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listexp"}) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(-24 * time.Hour), }) @@ -556,19 +556,19 @@ func (s *UserSubscriptionRepoSuite) TestListExpired() { } func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "batch@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-batch"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "batch@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-batch"}) - active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) - expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(-24 * time.Hour), }) @@ -577,22 +577,22 @@ func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() { s.Require().Equal(int64(1), affected) gotActive, _ := s.repo.GetByID(s.ctx, active.ID) - s.Require().Equal(model.SubscriptionStatusActive, gotActive.Status) + s.Require().Equal(service.SubscriptionStatusActive, gotActive.Status) gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID) - s.Require().Equal(model.SubscriptionStatusExpired, gotExpired.Status) + s.Require().Equal(service.SubscriptionStatusExpired, gotExpired.Status) } // --- ExistsByUserIDAndGroupID --- func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-exists"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-exists"}) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) @@ -608,20 +608,20 @@ func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() { // --- CountByGroupID / CountActiveByGroupID --- func (s *UserSubscriptionRepoSuite) TestCountByGroupID() { - user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt1@test.com"}) - user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt2@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"}) + user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt1@test.com"}) + user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt2@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"}) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user1.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user2.ID, GroupID: group.ID, - Status: model.SubscriptionStatusExpired, + Status: service.SubscriptionStatusExpired, ExpiresAt: time.Now().Add(-24 * time.Hour), }) @@ -631,20 +631,20 @@ func (s *UserSubscriptionRepoSuite) TestCountByGroupID() { } func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() { - user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact1@test.com"}) - user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact2@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-cntact"}) + user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact1@test.com"}) + user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact2@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-cntact"}) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user1.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user2.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time }) @@ -656,19 +656,19 @@ func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() { // --- DeleteByGroupID --- func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "delgrp@test.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delgrp"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "delgrp@test.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delgrp"}) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(24 * time.Hour), }) - mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusExpired, + Status: service.SubscriptionStatusExpired, ExpiresAt: time.Now().Add(-24 * time.Hour), }) @@ -683,19 +683,19 @@ func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() { // --- Combined original test --- func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() { - user := mustCreateUser(s.T(), s.db, &model.User{Email: "subr@example.com"}) - group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-subr"}) + user := mustCreateUser(s.T(), s.db, &userModel{Email: "subr@example.com"}) + group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-subr"}) - active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(2 * time.Hour), }) - expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{ + expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{ UserID: user.ID, GroupID: group.ID, - Status: model.SubscriptionStatusActive, + Status: service.SubscriptionStatusActive, ExpiresAt: time.Now().Add(-2 * time.Hour), }) @@ -729,5 +729,5 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba s.Require().Equal(int64(1), affected, "expected 1 affected row") updated, err := s.repo.GetByID(s.ctx, expiredActive.ID) s.Require().NoError(err, "GetByID expired") - s.Require().Equal(model.SubscriptionStatusExpired, updated.Status, "expected status expired") + s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired") } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go new file mode 100644 index 00000000..8833a4e6 --- /dev/null +++ b/backend/internal/server/api_contract_test.go @@ -0,0 +1,1039 @@ +//go:build unit + +package server_test + +import ( + "bytes" + "context" + "errors" + "io" + "math" + "net/http" + "net/http/httptest" + "sort" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/handler" + adminhandler "github.com/Wei-Shaw/sub2api/internal/handler/admin" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAPIContracts(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + setup func(t *testing.T, deps *contractDeps) + method string + path string + body string + headers map[string]string + wantStatus int + wantJSON string + }{ + { + name: "GET /api/v1/auth/me", + method: http.MethodGet, + path: "/api/v1/auth/me", + wantStatus: http.StatusOK, + wantJSON: `{ + "code": 0, + "message": "success", + "data": { + "id": 1, + "email": "alice@example.com", + "username": "alice", + "wechat": "wx_alice", + "notes": "hello", + "role": "user", + "balance": 12.5, + "concurrency": 5, + "status": "active", + "allowed_groups": null, + "created_at": "2025-01-02T03:04:05Z", + "updated_at": "2025-01-02T03:04:05Z" + } + }`, + }, + { + name: "POST /api/v1/keys", + method: http.MethodPost, + path: "/api/v1/keys", + body: `{"name":"Key One","custom_key":"sk_custom_1234567890"}`, + headers: map[string]string{ + "Content-Type": "application/json", + }, + wantStatus: http.StatusOK, + wantJSON: `{ + "code": 0, + "message": "success", + "data": { + "id": 100, + "user_id": 1, + "key": "sk_custom_1234567890", + "name": "Key One", + "group_id": null, + "status": "active", + "created_at": "2025-01-02T03:04:05Z", + "updated_at": "2025-01-02T03:04:05Z" + } + }`, + }, + { + name: "GET /api/v1/keys (paginated)", + setup: func(t *testing.T, deps *contractDeps) { + t.Helper() + deps.apiKeyRepo.MustSeed(&service.ApiKey{ + ID: 100, + UserID: 1, + Key: "sk_custom_1234567890", + Name: "Key One", + Status: service.StatusActive, + CreatedAt: deps.now, + UpdatedAt: deps.now, + }) + }, + method: http.MethodGet, + path: "/api/v1/keys?page=1&page_size=10", + wantStatus: http.StatusOK, + wantJSON: `{ + "code": 0, + "message": "success", + "data": { + "items": [ + { + "id": 100, + "user_id": 1, + "key": "sk_custom_1234567890", + "name": "Key One", + "group_id": null, + "status": "active", + "created_at": "2025-01-02T03:04:05Z", + "updated_at": "2025-01-02T03:04:05Z" + } + ], + "total": 1, + "page": 1, + "page_size": 10, + "pages": 1 + } + }`, + }, + { + name: "GET /api/v1/usage/stats", + setup: func(t *testing.T, deps *contractDeps) { + t.Helper() + deps.usageRepo.SetUserLogs(1, []service.UsageLog{ + { + ID: 1, + UserID: 1, + ApiKeyID: 100, + AccountID: 200, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + CacheCreationTokens: 1, + CacheReadTokens: 2, + TotalCost: 0.5, + ActualCost: 0.5, + DurationMs: ptr(100), + CreatedAt: deps.now, + }, + { + ID: 2, + UserID: 1, + ApiKeyID: 100, + AccountID: 200, + Model: "claude-3", + InputTokens: 5, + OutputTokens: 15, + TotalCost: 0.25, + ActualCost: 0.25, + DurationMs: ptr(300), + CreatedAt: deps.now, + }, + }) + }, + method: http.MethodGet, + path: "/api/v1/usage/stats?start_date=2025-01-01&end_date=2025-01-02", + wantStatus: http.StatusOK, + wantJSON: `{ + "code": 0, + "message": "success", + "data": { + "total_requests": 2, + "total_input_tokens": 15, + "total_output_tokens": 35, + "total_cache_tokens": 3, + "total_tokens": 53, + "total_cost": 0.75, + "total_actual_cost": 0.75, + "average_duration_ms": 200 + } + }`, + }, + { + name: "GET /api/v1/usage (paginated)", + setup: func(t *testing.T, deps *contractDeps) { + t.Helper() + deps.usageRepo.SetUserLogs(1, []service.UsageLog{ + { + ID: 1, + UserID: 1, + ApiKeyID: 100, + AccountID: 200, + RequestID: "req_123", + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + CacheCreationTokens: 1, + CacheReadTokens: 2, + TotalCost: 0.5, + ActualCost: 0.5, + RateMultiplier: 1, + BillingType: service.BillingTypeBalance, + Stream: true, + DurationMs: ptr(100), + FirstTokenMs: ptr(50), + CreatedAt: deps.now, + }, + }) + }, + method: http.MethodGet, + path: "/api/v1/usage?page=1&page_size=10", + wantStatus: http.StatusOK, + wantJSON: `{ + "code": 0, + "message": "success", + "data": { + "items": [ + { + "id": 1, + "user_id": 1, + "api_key_id": 100, + "account_id": 200, + "request_id": "req_123", + "model": "claude-3", + "group_id": null, + "subscription_id": null, + "input_tokens": 10, + "output_tokens": 20, + "cache_creation_tokens": 1, + "cache_read_tokens": 2, + "cache_creation_5m_tokens": 0, + "cache_creation_1h_tokens": 0, + "input_cost": 0, + "output_cost": 0, + "cache_creation_cost": 0, + "cache_read_cost": 0, + "total_cost": 0.5, + "actual_cost": 0.5, + "rate_multiplier": 1, + "billing_type": 0, + "stream": true, + "duration_ms": 100, + "first_token_ms": 50, + "created_at": "2025-01-02T03:04:05Z" + } + ], + "total": 1, + "page": 1, + "page_size": 10, + "pages": 1 + } + }`, + }, + { + name: "GET /api/v1/admin/settings", + setup: func(t *testing.T, deps *contractDeps) { + t.Helper() + deps.settingRepo.SetAll(map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyEmailVerifyEnabled: "false", + + service.SettingKeySmtpHost: "smtp.example.com", + service.SettingKeySmtpPort: "587", + service.SettingKeySmtpUsername: "user", + service.SettingKeySmtpPassword: "secret", + service.SettingKeySmtpFrom: "no-reply@example.com", + service.SettingKeySmtpFromName: "Sub2API", + service.SettingKeySmtpUseTLS: "true", + + service.SettingKeyTurnstileEnabled: "true", + service.SettingKeyTurnstileSiteKey: "site-key", + service.SettingKeyTurnstileSecretKey: "secret-key", + + service.SettingKeySiteName: "Sub2API", + service.SettingKeySiteLogo: "", + service.SettingKeySiteSubtitle: "Subtitle", + service.SettingKeyApiBaseUrl: "https://api.example.com", + service.SettingKeyContactInfo: "support", + service.SettingKeyDocUrl: "https://docs.example.com", + + service.SettingKeyDefaultConcurrency: "5", + service.SettingKeyDefaultBalance: "1.25", + }) + }, + method: http.MethodGet, + path: "/api/v1/admin/settings", + wantStatus: http.StatusOK, + wantJSON: `{ + "code": 0, + "message": "success", + "data": { + "registration_enabled": true, + "email_verify_enabled": false, + "smtp_host": "smtp.example.com", + "smtp_port": 587, + "smtp_username": "user", + "smtp_password": "secret", + "smtp_from_email": "no-reply@example.com", + "smtp_from_name": "Sub2API", + "smtp_use_tls": true, + "turnstile_enabled": true, + "turnstile_site_key": "site-key", + "turnstile_secret_key": "secret-key", + "site_name": "Sub2API", + "site_logo": "", + "site_subtitle": "Subtitle", + "api_base_url": "https://api.example.com", + "contact_info": "support", + "doc_url": "https://docs.example.com", + "default_concurrency": 5, + "default_balance": 1.25 + } + }`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + deps := newContractDeps(t) + if tt.setup != nil { + tt.setup(t, deps) + } + + status, body := doRequest(t, deps.router, tt.method, tt.path, tt.body, tt.headers) + require.Equal(t, tt.wantStatus, status) + require.JSONEq(t, tt.wantJSON, body) + }) + } +} + +type contractDeps struct { + now time.Time + router http.Handler + apiKeyRepo *stubApiKeyRepo + usageRepo *stubUsageLogRepo + settingRepo *stubSettingRepo +} + +func newContractDeps(t *testing.T) *contractDeps { + t.Helper() + + now := time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC) + + userRepo := &stubUserRepo{ + users: map[int64]*service.User{ + 1: { + ID: 1, + Email: "alice@example.com", + Username: "alice", + Wechat: "wx_alice", + Notes: "hello", + Role: service.RoleUser, + Balance: 12.5, + Concurrency: 5, + Status: service.StatusActive, + AllowedGroups: nil, + CreatedAt: now, + UpdatedAt: now, + }, + }, + } + + apiKeyRepo := newStubApiKeyRepo(now) + apiKeyCache := stubApiKeyCache{} + groupRepo := stubGroupRepo{} + userSubRepo := stubUserSubscriptionRepo{} + + cfg := &config.Config{ + Default: config.DefaultConfig{ + ApiKeyPrefix: "sk-", + }, + } + + userService := service.NewUserService(userRepo) + apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) + + usageRepo := newStubUsageLogRepo() + usageService := service.NewUsageService(usageRepo, userRepo) + + settingRepo := newStubSettingRepo() + settingService := service.NewSettingService(settingRepo, cfg) + + authHandler := handler.NewAuthHandler(nil, userService) + apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) + usageHandler := handler.NewUsageHandler(usageService, apiKeyService) + adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil) + + jwtAuth := func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 5, + }) + c.Set(string(middleware.ContextKeyUserRole), service.RoleUser) + c.Next() + } + adminAuth := func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 5, + }) + c.Set(string(middleware.ContextKeyUserRole), service.RoleAdmin) + c.Next() + } + + r := gin.New() + + v1 := r.Group("/api/v1") + + v1Auth := v1.Group("") + v1Auth.Use(jwtAuth) + v1Auth.GET("/auth/me", authHandler.GetCurrentUser) + + v1Keys := v1.Group("") + v1Keys.Use(jwtAuth) + v1Keys.GET("/keys", apiKeyHandler.List) + v1Keys.POST("/keys", apiKeyHandler.Create) + + v1Usage := v1.Group("") + v1Usage.Use(jwtAuth) + v1Usage.GET("/usage", usageHandler.List) + v1Usage.GET("/usage/stats", usageHandler.Stats) + + v1Admin := v1.Group("/admin") + v1Admin.Use(adminAuth) + v1Admin.GET("/settings", adminSettingHandler.GetSettings) + + return &contractDeps{ + now: now, + router: r, + apiKeyRepo: apiKeyRepo, + usageRepo: usageRepo, + settingRepo: settingRepo, + } +} + +func doRequest(t *testing.T, router http.Handler, method, path, body string, headers map[string]string) (int, string) { + t.Helper() + + req := httptest.NewRequest(method, path, bytes.NewBufferString(body)) + for k, v := range headers { + req.Header.Set(k, v) + } + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + respBody, err := io.ReadAll(w.Result().Body) + require.NoError(t, err) + + return w.Result().StatusCode, string(respBody) +} + +func ptr[T any](v T) *T { return &v } + +type stubUserRepo struct { + users map[int64]*service.User +} + +func (r *stubUserRepo) Create(ctx context.Context, user *service.User) error { + return errors.New("not implemented") +} + +func (r *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) { + user, ok := r.users[id] + if !ok { + return nil, service.ErrUserNotFound + } + clone := *user + return &clone, nil +} + +func (r *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) { + for _, user := range r.users { + if user.Email == email { + clone := *user + return &clone, nil + } + } + return nil, service.ErrUserNotFound +} + +func (r *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) { + for _, user := range r.users { + if user.Role == service.RoleAdmin && user.Status == service.StatusActive { + clone := *user + return &clone, nil + } + } + return nil, service.ErrUserNotFound +} + +func (r *stubUserRepo) Update(ctx context.Context, user *service.User) error { + return errors.New("not implemented") +} + +func (r *stubUserRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error { + return errors.New("not implemented") +} + +func (r *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error { + return errors.New("not implemented") +} + +func (r *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error { + return errors.New("not implemented") +} + +func (r *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { + return false, errors.New("not implemented") +} + +func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { + return 0, errors.New("not implemented") +} + +type stubApiKeyCache struct{} + +func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { + return 0, nil +} + +func (stubApiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error { + return nil +} + +func (stubApiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error { + return nil +} + +func (stubApiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error { + return nil +} + +func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { + return nil +} + +type stubGroupRepo struct{} + +func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error { + return errors.New("not implemented") +} + +func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, error) { + return nil, service.ErrGroupNotFound +} + +func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error { + return errors.New("not implemented") +} + +func (stubGroupRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (stubGroupRepo) DeleteCascade(ctx context.Context, id int64) ([]int64, error) { + return nil, errors.New("not implemented") +} + +func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { + return nil, errors.New("not implemented") +} + +func (stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) { + return nil, errors.New("not implemented") +} + +func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) { + return false, errors.New("not implemented") +} + +func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { + return 0, errors.New("not implemented") +} + +func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { + return 0, errors.New("not implemented") +} + +type stubUserSubscriptionRepo struct{} + +func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} +func (stubUserSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (stubUserSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (stubUserSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} +func (stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} +func (stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} +func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} +func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { + return false, errors.New("not implemented") +} +func (stubUserSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error { + return errors.New("not implemented") +} +func (stubUserSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error { + return errors.New("not implemented") +} +func (stubUserSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error { + return errors.New("not implemented") +} +func (stubUserSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error { + return errors.New("not implemented") +} +func (stubUserSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + return errors.New("not implemented") +} +func (stubUserSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + return errors.New("not implemented") +} +func (stubUserSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + return errors.New("not implemented") +} +func (stubUserSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + return errors.New("not implemented") +} +func (stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { + return 0, errors.New("not implemented") +} + +type stubApiKeyRepo struct { + now time.Time + + nextID int64 + byID map[int64]*service.ApiKey + byKey map[string]*service.ApiKey +} + +func newStubApiKeyRepo(now time.Time) *stubApiKeyRepo { + return &stubApiKeyRepo{ + now: now, + nextID: 100, + byID: make(map[int64]*service.ApiKey), + byKey: make(map[string]*service.ApiKey), + } +} + +func (r *stubApiKeyRepo) MustSeed(key *service.ApiKey) { + if key == nil { + return + } + clone := *key + r.byID[clone.ID] = &clone + r.byKey[clone.Key] = &clone +} + +func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { + if key == nil { + return errors.New("nil key") + } + if key.ID == 0 { + key.ID = r.nextID + r.nextID++ + } + if key.CreatedAt.IsZero() { + key.CreatedAt = r.now + } + if key.UpdatedAt.IsZero() { + key.UpdatedAt = r.now + } + clone := *key + r.byID[clone.ID] = &clone + r.byKey[clone.Key] = &clone + return nil +} + +func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { + key, ok := r.byID[id] + if !ok { + return nil, service.ErrApiKeyNotFound + } + clone := *key + return &clone, nil +} + +func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { + found, ok := r.byKey[key] + if !ok { + return nil, service.ErrApiKeyNotFound + } + clone := *found + return &clone, nil +} + +func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { + if key == nil { + return errors.New("nil key") + } + if _, ok := r.byID[key.ID]; !ok { + return service.ErrApiKeyNotFound + } + if key.UpdatedAt.IsZero() { + key.UpdatedAt = r.now + } + clone := *key + r.byID[clone.ID] = &clone + r.byKey[clone.Key] = &clone + return nil +} + +func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { + key, ok := r.byID[id] + if !ok { + return service.ErrApiKeyNotFound + } + delete(r.byID, id) + delete(r.byKey, key.Key) + return nil +} + +func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { + ids := make([]int64, 0, len(r.byID)) + for id := range r.byID { + if r.byID[id].UserID == userID { + ids = append(ids, id) + } + } + sort.Slice(ids, func(i, j int) bool { return ids[i] > ids[j] }) + + start := params.Offset() + if start > len(ids) { + start = len(ids) + } + end := start + params.Limit() + if end > len(ids) { + end = len(ids) + } + + out := make([]service.ApiKey, 0, end-start) + for _, id := range ids[start:end] { + clone := *r.byID[id] + out = append(out, clone) + } + + total := int64(len(ids)) + pageSize := params.Limit() + pages := int(math.Ceil(float64(total) / float64(pageSize))) + if pages < 1 { + pages = 1 + } + return out, &pagination.PaginationResult{ + Total: total, + Page: params.Page, + PageSize: pageSize, + Pages: pages, + }, nil +} + +func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { + var count int64 + for _, key := range r.byID { + if key.UserID == userID { + count++ + } + } + return count, nil +} + +func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { + _, ok := r.byKey[key] + return ok, nil +} + +func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { + return 0, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { + return 0, errors.New("not implemented") +} + +type stubUsageLogRepo struct { + userLogs map[int64][]service.UsageLog +} + +func newStubUsageLogRepo() *stubUsageLogRepo { + return &stubUsageLogRepo{userLogs: make(map[int64][]service.UsageLog)} +} + +func (r *stubUsageLogRepo) SetUserLogs(userID int64, logs []service.UsageLog) { + r.userLogs[userID] = logs +} + +func (r *stubUsageLogRepo) Create(ctx context.Context, log *service.UsageLog) error { + return errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (r *stubUsageLogRepo) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + logs := r.userLogs[userID] + total := int64(len(logs)) + out := paginateLogs(logs, params) + return out, paginationResult(total, params), nil +} + +func (r *stubUsageLogRepo) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + logs := r.userLogs[userID] + return logs, paginationResult(int64(len(logs)), pagination.PaginationParams{Page: 1, PageSize: 100}), nil +} + +func (r *stubUsageLogRepo) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) { + return nil, errors.New("not implemented") +} + +type stubSettingRepo struct { + all map[string]string +} + +func newStubSettingRepo() *stubSettingRepo { + return &stubSettingRepo{all: make(map[string]string)} +} + +func (r *stubSettingRepo) SetAll(values map[string]string) { + r.all = make(map[string]string, len(values)) + for k, v := range values { + r.all[k] = v + } +} + +func (r *stubSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) { + value, ok := r.all[key] + if !ok { + return nil, service.ErrSettingNotFound + } + return &service.Setting{Key: key, Value: value}, nil +} + +func (r *stubSettingRepo) GetValue(ctx context.Context, key string) (string, error) { + value, ok := r.all[key] + if !ok { + return "", service.ErrSettingNotFound + } + return value, nil +} + +func (r *stubSettingRepo) Set(ctx context.Context, key, value string) error { + r.all[key] = value + return nil +} + +func (r *stubSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + out[key] = r.all[key] + } + return out, nil +} + +func (r *stubSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error { + for k, v := range settings { + r.all[k] = v + } + return nil +} + +func (r *stubSettingRepo) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(r.all)) + for k, v := range r.all { + out[k] = v + } + return out, nil +} + +func (r *stubSettingRepo) Delete(ctx context.Context, key string) error { + delete(r.all, key) + return nil +} + +func paginateLogs(logs []service.UsageLog, params pagination.PaginationParams) []service.UsageLog { + start := params.Offset() + if start > len(logs) { + start = len(logs) + } + end := start + params.Limit() + if end > len(logs) { + end = len(logs) + } + out := make([]service.UsageLog, 0, end-start) + out = append(out, logs[start:end]...) + return out +} + +func paginationResult(total int64, params pagination.PaginationParams) *pagination.PaginationResult { + pageSize := params.Limit() + pages := int(math.Ceil(float64(total) / float64(pageSize))) + if pages < 1 { + pages = 1 + } + return &pagination.PaginationResult{ + Total: total, + Page: params.Page, + PageSize: pageSize, + Pages: pages, + } +} + +// Ensure compile-time interface compliance. +var ( + _ service.UserRepository = (*stubUserRepo)(nil) + _ service.ApiKeyRepository = (*stubApiKeyRepo)(nil) + _ service.ApiKeyCache = (*stubApiKeyCache)(nil) + _ service.GroupRepository = (*stubGroupRepo)(nil) + _ service.UserSubscriptionRepository = (*stubUserSubscriptionRepo)(nil) + _ service.UsageLogRepository = (*stubUsageLogRepo)(nil) + _ service.SettingRepository = (*stubSettingRepo)(nil) +) diff --git a/backend/internal/server/middleware/admin_auth.go b/backend/internal/server/middleware/admin_auth.go index 70e2f230..4f22d80c 100644 --- a/backend/internal/server/middleware/admin_auth.go +++ b/backend/internal/server/middleware/admin_auth.go @@ -5,7 +5,6 @@ import ( "errors" "strings" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -84,7 +83,11 @@ func validateAdminApiKey( return false } - c.Set(string(ContextKeyUser), admin) + c.Set(string(ContextKeyUser), AuthSubject{ + UserID: admin.ID, + Concurrency: admin.Concurrency, + }) + c.Set(string(ContextKeyUserRole), admin.Role) c.Set("auth_method", "admin_api_key") return true } @@ -121,12 +124,16 @@ func validateJWTForAdmin( } // 检查管理员权限 - if user.Role != model.RoleAdmin { + if !user.IsAdmin() { AbortWithError(c, 403, "FORBIDDEN", "Admin access required") return false } - c.Set(string(ContextKeyUser), user) + c.Set(string(ContextKeyUser), AuthSubject{ + UserID: user.ID, + Concurrency: user.Concurrency, + }) + c.Set(string(ContextKeyUserRole), user.Role) c.Set("auth_method", "jwt") return true diff --git a/backend/internal/server/middleware/admin_only.go b/backend/internal/server/middleware/admin_only.go index 5500e389..2cd697a3 100644 --- a/backend/internal/server/middleware/admin_only.go +++ b/backend/internal/server/middleware/admin_only.go @@ -1,7 +1,7 @@ package middleware import ( - "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) @@ -10,15 +10,14 @@ import ( // 必须在JWTAuth中间件之后使用 func AdminOnly() gin.HandlerFunc { return func(c *gin.Context) { - // 从上下文获取用户 - user, exists := GetUserFromContext(c) - if !exists { + role, ok := GetUserRoleFromContext(c) + if !ok { AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context") return } // 检查是否为管理员 - if user.Role != model.RoleAdmin { + if role != service.RoleAdmin { AbortWithError(c, 403, "FORBIDDEN", "Admin access required") return } diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 3a19e664..19d866b8 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -5,11 +5,9 @@ import ( "log" "strings" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" - "gorm.io/gorm" ) // NewApiKeyAuthMiddleware 创建 API Key 认证中间件 @@ -46,7 +44,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti // 从数据库验证API key apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, service.ErrApiKeyNotFound) { AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key") return } @@ -121,28 +119,32 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti // 将API key和用户信息存入上下文 c.Set(string(ContextKeyApiKey), apiKey) - c.Set(string(ContextKeyUser), apiKey.User) + c.Set(string(ContextKeyUser), AuthSubject{ + UserID: apiKey.User.ID, + Concurrency: apiKey.User.Concurrency, + }) + c.Set(string(ContextKeyUserRole), apiKey.User.Role) c.Next() } } // GetApiKeyFromContext 从上下文中获取API key -func GetApiKeyFromContext(c *gin.Context) (*model.ApiKey, bool) { +func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) { value, exists := c.Get(string(ContextKeyApiKey)) if !exists { return nil, false } - apiKey, ok := value.(*model.ApiKey) + apiKey, ok := value.(*service.ApiKey) return apiKey, ok } // GetSubscriptionFromContext 从上下文中获取订阅信息 -func GetSubscriptionFromContext(c *gin.Context) (*model.UserSubscription, bool) { +func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool) { value, exists := c.Get(string(ContextKeySubscription)) if !exists { return nil, false } - subscription, ok := value.(*model.UserSubscription) + subscription, ok := value.(*service.UserSubscription) return subscription, ok } diff --git a/backend/internal/server/middleware/auth_subject.go b/backend/internal/server/middleware/auth_subject.go new file mode 100644 index 00000000..200c7b77 --- /dev/null +++ b/backend/internal/server/middleware/auth_subject.go @@ -0,0 +1,28 @@ +package middleware + +import "github.com/gin-gonic/gin" + +// AuthSubject is the minimal authenticated identity stored in gin context. +// Decision: {UserID int64, Concurrency int} +type AuthSubject struct { + UserID int64 + Concurrency int +} + +func GetAuthSubjectFromContext(c *gin.Context) (AuthSubject, bool) { + value, exists := c.Get(string(ContextKeyUser)) + if !exists { + return AuthSubject{}, false + } + subject, ok := value.(AuthSubject) + return subject, ok +} + +func GetUserRoleFromContext(c *gin.Context) (string, bool) { + value, exists := c.Get(string(ContextKeyUserRole)) + if !exists { + return "", false + } + role, ok := value.(string) + return role, ok +} diff --git a/backend/internal/server/middleware/jwt_auth.go b/backend/internal/server/middleware/jwt_auth.go index 10ea8e7e..09239d0c 100644 --- a/backend/internal/server/middleware/jwt_auth.go +++ b/backend/internal/server/middleware/jwt_auth.go @@ -4,7 +4,6 @@ import ( "errors" "strings" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -62,19 +61,14 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService) return } - // 将用户信息存入上下文 - c.Set(string(ContextKeyUser), user) + c.Set(string(ContextKeyUser), AuthSubject{ + UserID: user.ID, + Concurrency: user.Concurrency, + }) + c.Set(string(ContextKeyUserRole), user.Role) c.Next() } } -// GetUserFromContext 从上下文中获取用户 -func GetUserFromContext(c *gin.Context) (*model.User, bool) { - value, exists := c.Get(string(ContextKeyUser)) - if !exists { - return nil, false - } - user, ok := value.(*model.User) - return user, ok -} +// Deprecated: prefer GetAuthSubjectFromContext in auth_subject.go. diff --git a/backend/internal/server/middleware/middleware.go b/backend/internal/server/middleware/middleware.go index aeda0387..1af8dbef 100644 --- a/backend/internal/server/middleware/middleware.go +++ b/backend/internal/server/middleware/middleware.go @@ -8,6 +8,8 @@ type ContextKey string const ( // ContextKeyUser 用户上下文键 ContextKeyUser ContextKey = "user" + // ContextKeyUserRole 当前用户角色(string) + ContextKeyUserRole ContextKey = "user_role" // ContextKeyApiKey API密钥上下文键 ContextKeyApiKey ContextKey = "api_key" // ContextKeySubscription 订阅上下文键 diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go new file mode 100644 index 00000000..373ad4a9 --- /dev/null +++ b/backend/internal/service/account.go @@ -0,0 +1,324 @@ +package service + +import "time" + +type Account struct { + ID int64 + Name string + Platform string + Type string + Credentials map[string]any + Extra map[string]any + ProxyID *int64 + Concurrency int + Priority int + Status string + ErrorMessage string + LastUsedAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time + + Schedulable bool + + RateLimitedAt *time.Time + RateLimitResetAt *time.Time + OverloadUntil *time.Time + + SessionWindowStart *time.Time + SessionWindowEnd *time.Time + SessionWindowStatus string + + Proxy *Proxy + AccountGroups []AccountGroup + GroupIDs []int64 + Groups []*Group +} + +func (a *Account) IsActive() bool { + return a.Status == StatusActive +} + +func (a *Account) IsSchedulable() bool { + if !a.IsActive() || !a.Schedulable { + return false + } + now := time.Now() + if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) { + return false + } + if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) { + return false + } + return true +} + +func (a *Account) IsRateLimited() bool { + if a.RateLimitResetAt == nil { + return false + } + return time.Now().Before(*a.RateLimitResetAt) +} + +func (a *Account) IsOverloaded() bool { + if a.OverloadUntil == nil { + return false + } + return time.Now().Before(*a.OverloadUntil) +} + +func (a *Account) IsOAuth() bool { + return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken +} + +func (a *Account) CanGetUsage() bool { + return a.Type == AccountTypeOAuth +} + +func (a *Account) GetCredential(key string) string { + if a.Credentials == nil { + return "" + } + if v, ok := a.Credentials[key]; ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +func (a *Account) GetModelMapping() map[string]string { + if a.Credentials == nil { + return nil + } + raw, ok := a.Credentials["model_mapping"] + if !ok || raw == nil { + return nil + } + if m, ok := raw.(map[string]any); ok { + result := make(map[string]string) + for k, v := range m { + if s, ok := v.(string); ok { + result[k] = s + } + } + if len(result) > 0 { + return result + } + } + return nil +} + +func (a *Account) IsModelSupported(requestedModel string) bool { + mapping := a.GetModelMapping() + if len(mapping) == 0 { + return true + } + _, exists := mapping[requestedModel] + return exists +} + +func (a *Account) GetMappedModel(requestedModel string) string { + mapping := a.GetModelMapping() + if len(mapping) == 0 { + return requestedModel + } + if mappedModel, exists := mapping[requestedModel]; exists { + return mappedModel + } + return requestedModel +} + +func (a *Account) GetBaseURL() string { + if a.Type != AccountTypeApiKey { + return "" + } + baseURL := a.GetCredential("base_url") + if baseURL == "" { + return "https://api.anthropic.com" + } + return baseURL +} + +func (a *Account) GetExtraString(key string) string { + if a.Extra == nil { + return "" + } + if v, ok := a.Extra[key]; ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +func (a *Account) IsCustomErrorCodesEnabled() bool { + if a.Type != AccountTypeApiKey || a.Credentials == nil { + return false + } + if v, ok := a.Credentials["custom_error_codes_enabled"]; ok { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + +func (a *Account) GetCustomErrorCodes() []int { + if a.Credentials == nil { + return nil + } + raw, ok := a.Credentials["custom_error_codes"] + if !ok || raw == nil { + return nil + } + if arr, ok := raw.([]any); ok { + result := make([]int, 0, len(arr)) + for _, v := range arr { + if f, ok := v.(float64); ok { + result = append(result, int(f)) + } + } + return result + } + return nil +} + +func (a *Account) ShouldHandleErrorCode(statusCode int) bool { + if !a.IsCustomErrorCodesEnabled() { + return true + } + codes := a.GetCustomErrorCodes() + if len(codes) == 0 { + return true + } + for _, code := range codes { + if code == statusCode { + return true + } + } + return false +} + +func (a *Account) IsInterceptWarmupEnabled() bool { + if a.Credentials == nil { + return false + } + if v, ok := a.Credentials["intercept_warmup_requests"]; ok { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + +func (a *Account) IsOpenAI() bool { + return a.Platform == PlatformOpenAI +} + +func (a *Account) IsAnthropic() bool { + return a.Platform == PlatformAnthropic +} + +func (a *Account) IsOpenAIOAuth() bool { + return a.IsOpenAI() && a.Type == AccountTypeOAuth +} + +func (a *Account) IsOpenAIApiKey() bool { + return a.IsOpenAI() && a.Type == AccountTypeApiKey +} + +func (a *Account) GetOpenAIBaseURL() string { + if !a.IsOpenAI() { + return "" + } + if a.Type == AccountTypeApiKey { + baseURL := a.GetCredential("base_url") + if baseURL != "" { + return baseURL + } + } + return "https://api.openai.com" +} + +func (a *Account) GetOpenAIAccessToken() string { + if !a.IsOpenAI() { + return "" + } + return a.GetCredential("access_token") +} + +func (a *Account) GetOpenAIRefreshToken() string { + if !a.IsOpenAIOAuth() { + return "" + } + return a.GetCredential("refresh_token") +} + +func (a *Account) GetOpenAIIDToken() string { + if !a.IsOpenAIOAuth() { + return "" + } + return a.GetCredential("id_token") +} + +func (a *Account) GetOpenAIApiKey() string { + if !a.IsOpenAIApiKey() { + return "" + } + return a.GetCredential("api_key") +} + +func (a *Account) GetOpenAIUserAgent() string { + if !a.IsOpenAI() { + return "" + } + return a.GetCredential("user_agent") +} + +func (a *Account) GetChatGPTAccountID() string { + if !a.IsOpenAIOAuth() { + return "" + } + return a.GetCredential("chatgpt_account_id") +} + +func (a *Account) GetChatGPTUserID() string { + if !a.IsOpenAIOAuth() { + return "" + } + return a.GetCredential("chatgpt_user_id") +} + +func (a *Account) GetOpenAIOrganizationID() string { + if !a.IsOpenAIOAuth() { + return "" + } + return a.GetCredential("organization_id") +} + +func (a *Account) GetOpenAITokenExpiresAt() *time.Time { + if !a.IsOpenAIOAuth() { + return nil + } + expiresAtStr := a.GetCredential("expires_at") + if expiresAtStr == "" { + return nil + } + t, err := time.Parse(time.RFC3339, expiresAtStr) + if err != nil { + if v, ok := a.Credentials["expires_at"].(float64); ok { + tt := time.Unix(int64(v), 0) + return &tt + } + return nil + } + return &t +} + +func (a *Account) IsOpenAITokenExpired() bool { + expiresAt := a.GetOpenAITokenExpiresAt() + if expiresAt == nil { + return false + } + return time.Now().Add(60 * time.Second).After(*expiresAt) +} diff --git a/backend/internal/service/account_group.go b/backend/internal/service/account_group.go new file mode 100644 index 00000000..ab702a08 --- /dev/null +++ b/backend/internal/service/account_group.go @@ -0,0 +1,13 @@ +package service + +import "time" + +type AccountGroup struct { + AccountID int64 + GroupID int64 + Priority int + CreatedAt time.Time + + Account *Account + Group *Group +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index e968ea57..a5b9cd7f 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -6,7 +6,6 @@ import ( "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) @@ -15,29 +14,29 @@ var ( ) type AccountRepository interface { - Create(ctx context.Context, account *model.Account) error - GetByID(ctx context.Context, id int64) (*model.Account, error) + Create(ctx context.Context, account *Account) error + GetByID(ctx context.Context, id int64) (*Account, error) // GetByCRSAccountID finds an account previously synced from CRS. // Returns (nil, nil) if not found. - GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) - Update(ctx context.Context, account *model.Account) error + GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) + Update(ctx context.Context, account *Account) error Delete(ctx context.Context, id int64) error - List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) - ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) - ListActive(ctx context.Context) ([]model.Account, error) - ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) + List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) + ListByGroup(ctx context.Context, groupID int64) ([]Account, error) + ListActive(ctx context.Context) ([]Account, error) + ListByPlatform(ctx context.Context, platform string) ([]Account, error) UpdateLastUsed(ctx context.Context, id int64) error SetError(ctx context.Context, id int64, errorMsg string) error SetSchedulable(ctx context.Context, id int64, schedulable bool) error BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error - ListSchedulable(ctx context.Context) ([]model.Account, error) - ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) - ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) - ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) + ListSchedulable(ctx context.Context) ([]Account, error) + ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) + ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) + ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error @@ -99,7 +98,7 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) } // Create 创建账号 -func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*model.Account, error) { +func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) { // 验证分组是否存在(如果指定了分组) if len(req.GroupIDs) > 0 { for _, groupID := range req.GroupIDs { @@ -111,7 +110,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( } // 创建账号 - account := &model.Account{ + account := &Account{ Name: req.Name, Platform: req.Platform, Type: req.Type, @@ -120,7 +119,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( ProxyID: req.ProxyID, Concurrency: req.Concurrency, Priority: req.Priority, - Status: model.StatusActive, + Status: StatusActive, } if err := s.accountRepo.Create(ctx, account); err != nil { @@ -138,7 +137,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( } // GetByID 根据ID获取账号 -func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) { +func (s *AccountService) GetByID(ctx context.Context, id int64) (*Account, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get account: %w", err) @@ -147,7 +146,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, } // List 获取账号列表 -func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { +func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { accounts, pagination, err := s.accountRepo.List(ctx, params) if err != nil { return nil, nil, fmt.Errorf("list accounts: %w", err) @@ -156,7 +155,7 @@ func (s *AccountService) List(ctx context.Context, params pagination.PaginationP } // ListByPlatform 根据平台获取账号列表 -func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) { +func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]Account, error) { accounts, err := s.accountRepo.ListByPlatform(ctx, platform) if err != nil { return nil, fmt.Errorf("list accounts by platform: %w", err) @@ -165,7 +164,7 @@ func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([ } // ListByGroup 根据分组获取账号列表 -func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) { +func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { accounts, err := s.accountRepo.ListByGroup(ctx, groupID) if err != nil { return nil, fmt.Errorf("list accounts by group: %w", err) @@ -174,7 +173,7 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode } // Update 更新账号 -func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) { +func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*Account, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get account: %w", err) @@ -290,13 +289,13 @@ func (s *AccountService) TestCredentials(ctx context.Context, id int64) error { // 根据平台执行不同的测试逻辑 switch account.Platform { - case model.PlatformAnthropic: + case PlatformAnthropic: // TODO: 测试Anthropic API凭证 return nil - case model.PlatformOpenAI: + case PlatformOpenAI: // TODO: 测试OpenAI API凭证 return nil - case model.PlatformGemini: + case PlatformGemini: // TODO: 测试Gemini API凭证 return nil default: diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 782aa95c..f4daa8b1 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -14,7 +14,6 @@ import ( "strings" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/gin-gonic/gin" @@ -127,7 +126,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int } // testClaudeAccountConnection tests an Anthropic Claude account's connection -func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *model.Account, modelID string) error { +func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *Account, modelID string) error { ctx := c.Request.Context() // Determine the model to use @@ -254,7 +253,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account } // testOpenAIAccountConnection tests an OpenAI account's connection -func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *model.Account, modelID string) error { +func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error { ctx := c.Request.Context() // Default to openai.DefaultTestModel for OpenAI testing diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 6f89b600..642c8e09 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -7,24 +7,23 @@ import ( "sync" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" ) type UsageLogRepository interface { - Create(ctx context.Context, log *model.UsageLog) error - GetByID(ctx context.Context, id int64) (*model.UsageLog, error) + Create(ctx context.Context, log *UsageLog) error + GetByID(ctx context.Context, id int64) (*UsageLog, error) Delete(ctx context.Context, id int64) error - ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) - ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) - ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) + ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) + ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) + ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) - ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) - ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) - ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) - ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) + ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) + ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) + ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) + ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) @@ -44,7 +43,7 @@ type UsageLogRepository interface { GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) // Admin usage listing/stats - ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) // Account stats @@ -163,7 +162,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U } // Setup Token账号:根据session_window推算(没有profile scope,无法调用usage API) - if account.Type == model.AccountTypeSetupToken { + if account.Type == AccountTypeSetupToken { usage := s.estimateSetupTokenUsage(account) // 添加窗口统计 s.addWindowStats(ctx, account, usage) @@ -175,7 +174,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U } // addWindowStats 为usage数据添加窗口期统计 -func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model.Account, usage *UsageInfo) { +func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) { if usage.FiveHour == nil { return } @@ -225,7 +224,7 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI } // fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量 -func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) { +func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Account) (*UsageInfo, error) { accessToken := account.GetCredential("access_token") if accessToken == "" { return nil, fmt.Errorf("no access token available") @@ -320,7 +319,7 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA } // estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量 -func (s *AccountUsageService) estimateSetupTokenUsage(account *model.Account) *UsageInfo { +func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageInfo { info := &UsageInfo{} // 如果有session_window信息 diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index d676ac8f..f1eb0fc6 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -7,62 +7,61 @@ import ( "log" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) // AdminService interface defines admin management operations type AdminService interface { // User management - ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) - GetUser(ctx context.Context, id int64) (*model.User, error) - CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) - UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) + ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) + GetUser(ctx context.Context, id int64) (*User, error) + CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) + UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) DeleteUser(ctx context.Context, id int64) error - UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) - GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) + UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) + GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) // Group management - ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) - GetAllGroups(ctx context.Context) ([]model.Group, error) - GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error) - GetGroup(ctx context.Context, id int64) (*model.Group, error) - CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error) - UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error) + ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) + GetAllGroups(ctx context.Context) ([]Group, error) + GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) + GetGroup(ctx context.Context, id int64) (*Group, error) + CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) + UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) DeleteGroup(ctx context.Context, id int64) error - GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) + GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) // Account management - ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) - GetAccount(ctx context.Context, id int64) (*model.Account, error) - CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error) - UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error) + ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) + GetAccount(ctx context.Context, id int64) (*Account, error) + CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) + UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) DeleteAccount(ctx context.Context, id int64) error - RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) - ClearAccountError(ctx context.Context, id int64) (*model.Account, error) - SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) + RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) + ClearAccountError(ctx context.Context, id int64) (*Account, error) + SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) // Proxy management - ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) - GetAllProxies(ctx context.Context) ([]model.Proxy, error) - GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) - GetProxy(ctx context.Context, id int64) (*model.Proxy, error) - CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error) - UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error) + ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) + GetAllProxies(ctx context.Context) ([]Proxy, error) + GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) + GetProxy(ctx context.Context, id int64) (*Proxy, error) + CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) + UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) DeleteProxy(ctx context.Context, id int64) error - GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error) + GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) // Redeem code management - ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) - GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) - GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error) + ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) + GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) + GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) DeleteRedeemCode(ctx context.Context, id int64) error BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) - ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) + ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) } // Input types for admin operations @@ -252,7 +251,7 @@ func NewAdminService( } // User management implementations -func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) { +func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search) if err != nil { @@ -261,20 +260,21 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, st return users, result.Total, nil } -func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User, error) { +func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { return s.userRepo.GetByID(ctx, id) } -func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) { - user := &model.User{ - Email: input.Email, - Username: input.Username, - Wechat: input.Wechat, - Notes: input.Notes, - Role: "user", // Always create as regular user, never admin - Balance: input.Balance, - Concurrency: input.Concurrency, - Status: model.StatusActive, +func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { + user := &User{ + Email: input.Email, + Username: input.Username, + Wechat: input.Wechat, + Notes: input.Notes, + Role: RoleUser, // Always create as regular user, never admin + Balance: input.Balance, + Concurrency: input.Concurrency, + Status: StatusActive, + AllowedGroups: input.AllowedGroups, } if err := user.SetPassword(input.Password); err != nil { return nil, err @@ -285,7 +285,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu return user, nil } -func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) { +func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { user, err := s.userRepo.GetByID(ctx, id) if err != nil { return nil, err @@ -335,16 +335,16 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda concurrencyDiff := user.Concurrency - oldConcurrency if concurrencyDiff != 0 { - code, err := model.GenerateRedeemCode() + code, err := GenerateRedeemCode() if err != nil { log.Printf("failed to generate adjustment redeem code: %v", err) return user, nil } - adjustmentRecord := &model.RedeemCode{ + adjustmentRecord := &RedeemCode{ Code: code, - Type: model.AdjustmentTypeAdminConcurrency, + Type: AdjustmentTypeAdminConcurrency, Value: float64(concurrencyDiff), - Status: model.StatusUsed, + Status: StatusUsed, UsedBy: &user.ID, } now := time.Now() @@ -369,7 +369,7 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { return s.userRepo.Delete(ctx, id) } -func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) { +func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { return nil, err @@ -406,17 +406,17 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balanceDiff := user.Balance - oldBalance if balanceDiff != 0 { - code, err := model.GenerateRedeemCode() + code, err := GenerateRedeemCode() if err != nil { log.Printf("failed to generate adjustment redeem code: %v", err) return user, nil } - adjustmentRecord := &model.RedeemCode{ + adjustmentRecord := &RedeemCode{ Code: code, - Type: model.AdjustmentTypeAdminBalance, + Type: AdjustmentTypeAdminBalance, Value: balanceDiff, - Status: model.StatusUsed, + Status: StatusUsed, UsedBy: &user.ID, Notes: notes, } @@ -431,7 +431,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, return user, nil } -func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) { +func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) if err != nil { @@ -452,7 +452,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, } // Group management implementations -func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) { +func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) if err != nil { @@ -461,36 +461,36 @@ func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, p return groups, result.Total, nil } -func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]model.Group, error) { +func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) { return s.groupRepo.ListActive(ctx) } -func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error) { +func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) { return s.groupRepo.ListActiveByPlatform(ctx, platform) } -func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*model.Group, error) { +func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) { return s.groupRepo.GetByID(ctx, id) } -func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error) { +func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { platform := input.Platform if platform == "" { - platform = model.PlatformAnthropic + platform = PlatformAnthropic } subscriptionType := input.SubscriptionType if subscriptionType == "" { - subscriptionType = model.SubscriptionTypeStandard + subscriptionType = SubscriptionTypeStandard } - group := &model.Group{ + group := &Group{ Name: input.Name, Description: input.Description, Platform: platform, RateMultiplier: input.RateMultiplier, IsExclusive: input.IsExclusive, - Status: model.StatusActive, + Status: StatusActive, SubscriptionType: subscriptionType, DailyLimitUSD: input.DailyLimitUSD, WeeklyLimitUSD: input.WeeklyLimitUSD, @@ -502,7 +502,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn return group, nil } -func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error) { +func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { return nil, err @@ -571,7 +571,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { return nil } -func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) { +func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) if err != nil { @@ -581,7 +581,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p } // Account management implementations -func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) { +func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) if err != nil { @@ -590,21 +590,21 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, return accounts, result.Total, nil } -func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*model.Account, error) { +func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) { return s.accountRepo.GetByID(ctx, id) } -func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error) { - account := &model.Account{ +func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) { + account := &Account{ Name: input.Name, Platform: input.Platform, Type: input.Type, - Credentials: model.JSONB(input.Credentials), - Extra: model.JSONB(input.Extra), + Credentials: input.Credentials, + Extra: input.Extra, ProxyID: input.ProxyID, Concurrency: input.Concurrency, Priority: input.Priority, - Status: model.StatusActive, + Status: StatusActive, } if err := s.accountRepo.Create(ctx, account); err != nil { return nil, err @@ -618,7 +618,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou return account, nil } -func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error) { +func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { return nil, err @@ -631,10 +631,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.Type = input.Type } if len(input.Credentials) > 0 { - account.Credentials = model.JSONB(input.Credentials) + account.Credentials = input.Credentials } if len(input.Extra) > 0 { - account.Extra = model.JSONB(input.Extra) + account.Extra = input.Extra } if input.ProxyID != nil { account.ProxyID = input.ProxyID @@ -730,7 +730,7 @@ func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { return s.accountRepo.Delete(ctx, id) } -func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) { +func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { return nil, err @@ -739,12 +739,12 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int return account, nil } -func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*model.Account, error) { +func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { return nil, err } - account.Status = model.StatusActive + account.Status = StatusActive account.ErrorMessage = "" if err := s.accountRepo.Update(ctx, account); err != nil { return nil, err @@ -752,7 +752,7 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*mo return account, nil } -func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) { +func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) { if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { return nil, err } @@ -760,7 +760,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, } // Proxy management implementations -func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) { +func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) if err != nil { @@ -769,27 +769,27 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, return proxies, result.Total, nil } -func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]model.Proxy, error) { +func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) { return s.proxyRepo.ListActive(ctx) } -func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) { +func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { return s.proxyRepo.ListActiveWithAccountCount(ctx) } -func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*model.Proxy, error) { +func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) { return s.proxyRepo.GetByID(ctx, id) } -func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error) { - proxy := &model.Proxy{ +func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) { + proxy := &Proxy{ Name: input.Name, Protocol: input.Protocol, Host: input.Host, Port: input.Port, Username: input.Username, Password: input.Password, - Status: model.StatusActive, + Status: StatusActive, } if err := s.proxyRepo.Create(ctx, proxy); err != nil { return nil, err @@ -797,7 +797,7 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn return proxy, nil } -func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error) { +func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) { proxy, err := s.proxyRepo.GetByID(ctx, id) if err != nil { return nil, err @@ -835,9 +835,9 @@ func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error { return s.proxyRepo.Delete(ctx, id) } -func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error) { +func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) { // Return mock data for now - would need a dedicated repository method - return []model.Account{}, 0, nil + return []Account{}, 0, nil } func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) { @@ -845,7 +845,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po } // Redeem code management implementations -func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) { +func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) if err != nil { @@ -854,13 +854,13 @@ func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize i return codes, result.Total, nil } -func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) { +func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { return s.redeemCodeRepo.GetByID(ctx, id) } -func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error) { +func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) { // 如果是订阅类型,验证必须有 GroupID - if input.Type == model.RedeemTypeSubscription { + if input.Type == RedeemTypeSubscription { if input.GroupID == nil { return nil, errors.New("group_id is required for subscription type") } @@ -874,20 +874,20 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener } } - codes := make([]model.RedeemCode, 0, input.Count) + codes := make([]RedeemCode, 0, input.Count) for i := 0; i < input.Count; i++ { - codeValue, err := model.GenerateRedeemCode() + codeValue, err := GenerateRedeemCode() if err != nil { return nil, err } - code := model.RedeemCode{ + code := RedeemCode{ Code: codeValue, Type: input.Type, Value: input.Value, - Status: model.StatusUnused, + Status: StatusUnused, } // 订阅类型专用字段 - if input.Type == model.RedeemTypeSubscription { + if input.Type == RedeemTypeSubscription { code.GroupID = input.GroupID code.ValidityDays = input.ValidityDays if code.ValidityDays <= 0 { @@ -916,12 +916,12 @@ func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int return deleted, nil } -func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) { +func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { code, err := s.redeemCodeRepo.GetByID(ctx, id) if err != nil { return nil, err } - code.Status = model.StatusExpired + code.Status = StatusExpired if err := s.redeemCodeRepo.Update(ctx, code); err != nil { return nil, err } diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go new file mode 100644 index 00000000..e76f0f8e --- /dev/null +++ b/backend/internal/service/api_key.go @@ -0,0 +1,20 @@ +package service + +import "time" + +type ApiKey struct { + ID int64 + UserID int64 + Key string + Name string + GroupID *int64 + Status string + CreatedAt time.Time + UpdatedAt time.Time + User *User + Group *Group +} + +func (k *ApiKey) IsActive() bool { + return k.Status == StatusActive +} diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 788d226e..ac236175 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -10,7 +10,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/redis/go-redis/v9" @@ -30,17 +29,17 @@ const ( ) type ApiKeyRepository interface { - Create(ctx context.Context, key *model.ApiKey) error - GetByID(ctx context.Context, id int64) (*model.ApiKey, error) - GetByKey(ctx context.Context, key string) (*model.ApiKey, error) - Update(ctx context.Context, key *model.ApiKey) error + Create(ctx context.Context, key *ApiKey) error + GetByID(ctx context.Context, id int64) (*ApiKey, error) + GetByKey(ctx context.Context, key string) (*ApiKey, error) + Update(ctx context.Context, key *ApiKey) error Delete(ctx context.Context, id int64) error - ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) + ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) CountByUserID(ctx context.Context, userID int64) (int64, error) ExistsByKey(ctx context.Context, key string) (bool, error) - ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) - SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) + ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) + SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) CountByGroupID(ctx context.Context, groupID int64) (int64, error) } @@ -168,7 +167,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in // canUserBindGroup 检查用户是否可以绑定指定分组 // 对于订阅类型分组:检查用户是否有有效订阅 // 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑 -func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User, group *model.Group) bool { +func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool { // 订阅类型分组:需要有效订阅 if group.IsSubscriptionType() { _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID) @@ -179,7 +178,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User, } // Create 创建API Key -func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*model.ApiKey, error) { +func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) { // 验证用户存在 user, err := s.userRepo.GetByID(ctx, userID) if err != nil { @@ -235,12 +234,12 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK } // 创建API Key记录 - apiKey := &model.ApiKey{ + apiKey := &ApiKey{ UserID: userID, Key: key, Name: req.Name, GroupID: req.GroupID, - Status: model.StatusActive, + Status: StatusActive, } if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil { @@ -251,7 +250,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK } // List 获取用户的API Key列表 -func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { +func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) { keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) if err != nil { return nil, nil, fmt.Errorf("list api keys: %w", err) @@ -260,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio } // GetByID 根据ID获取API Key -func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { +func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) { apiKey, err := s.apiKeyRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get api key: %w", err) @@ -269,7 +268,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, e } // GetByKey 根据Key字符串获取API Key(用于认证) -func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) { +func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) { // 尝试从Redis缓存获取 cacheKey := fmt.Sprintf("apikey:%s", key) @@ -289,7 +288,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey } // Update 更新API Key -func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) { +func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) { apiKey, err := s.apiKeyRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get api key: %w", err) @@ -364,7 +363,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro } // ValidateKey 验证API Key是否有效(用于认证中间件) -func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.ApiKey, *model.User, error) { +func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) { // 获取API Key apiKey, err := s.GetByKey(ctx, key) if err != nil { @@ -408,7 +407,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error { // 返回用户可以选择的分组: // - 标准类型分组:公开的(非专属)或用户被明确允许的 // - 订阅类型分组:用户有有效订阅的 -func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]model.Group, error) { +func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) { // 获取用户信息 user, err := s.userRepo.GetByID(ctx, userID) if err != nil { @@ -434,7 +433,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([ } // 过滤出用户有权限的分组 - availableGroups := make([]model.Group, 0) + availableGroups := make([]Group, 0) for _, group := range allGroups { if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) { availableGroups = append(availableGroups, group) @@ -445,7 +444,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([ } // canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据) -func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.Group, subscribedGroupIDs map[int64]bool) bool { +func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool { // 订阅类型分组:需要有效订阅 if group.IsSubscriptionType() { return subscribedGroupIDs[group.ID] @@ -454,7 +453,7 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model. return user.CanBindGroup(group.ID, group.IsExclusive) } -func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) { +func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) { keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit) if err != nil { return nil, fmt.Errorf("search api keys: %w", err) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index e6d29e09..e2d08f2c 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -9,7 +9,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" @@ -64,12 +63,12 @@ func NewAuthService( } // Register 用户注册,返回token和用户 -func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) { +func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { return s.RegisterWithVerification(ctx, email, password, "") } // RegisterWithVerification 用户注册(支持邮件验证),返回token和用户 -func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *model.User, error) { +func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) { // 检查是否开放注册 if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { return "", nil, ErrRegDisabled @@ -113,13 +112,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw } // 创建用户 - user := &model.User{ + user := &User{ Email: email, PasswordHash: hashedPassword, - Role: model.RoleUser, + Role: RoleUser, Balance: defaultBalance, Concurrency: defaultConcurrency, - Status: model.StatusActive, + Status: StatusActive, } if err := s.userRepo.Create(ctx, user); err != nil { @@ -251,7 +250,7 @@ func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool { } // Login 用户登录,返回JWT token -func (s *AuthService) Login(ctx context.Context, email, password string) (string, *model.User, error) { +func (s *AuthService) Login(ctx context.Context, email, password string) (string, *User, error) { // 查找用户 user, err := s.userRepo.GetByEmail(ctx, email) if err != nil { @@ -307,7 +306,7 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { } // GenerateToken 生成JWT token -func (s *AuthService) GenerateToken(user *model.User) (string, error) { +func (s *AuthService) GenerateToken(user *User) (string, error) { now := time.Now() expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour) diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 70741d56..18f125ca 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -7,7 +7,6 @@ import ( "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" - "github.com/Wei-Shaw/sub2api/internal/model" ) // 错误定义 @@ -224,7 +223,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID // CheckBillingEligibility 检查用户是否有资格发起请求 // 余额模式:检查缓存余额 > 0 // 订阅模式:检查缓存用量未超过限额(Group限额从参数传入) -func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *model.User, apiKey *model.ApiKey, group *model.Group, subscription *model.UserSubscription) error { +func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error { // 判断计费模式 isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil @@ -252,7 +251,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI } // checkSubscriptionEligibility 检查订阅模式资格 -func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *model.Group, subscription *model.UserSubscription) error { +func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *Group, subscription *UserSubscription) error { // 获取订阅缓存数据 subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID) if err != nil { @@ -262,7 +261,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, } // 检查订阅状态 - if subData.Status != model.SubscriptionStatusActive { + if subData.Status != SubscriptionStatusActive { return ErrSubscriptionInvalid } @@ -288,7 +287,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, } // checkSubscriptionLimitsFallback 降级检查订阅限额 -func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *model.UserSubscription, group *model.Group) error { +func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error { if subscription == nil { return ErrSubscriptionInvalid } diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index e1f9d252..90a63f10 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -11,8 +11,6 @@ import ( "net/url" "strings" "time" - - "github.com/Wei-Shaw/sub2api/internal/model" ) type CRSSyncService struct { @@ -180,7 +178,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ), } - var proxies []model.Proxy + var proxies []Proxy if input.SyncProxies { proxies, _ = s.proxyRepo.ListActive(ctx) } @@ -197,7 +195,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput if targetType == "" { targetType = "oauth" } - if targetType != model.AccountTypeOAuth && targetType != model.AccountTypeSetupToken { + if targetType != AccountTypeOAuth && targetType != AccountTypeSetupToken { item.Action = "skipped" item.Error = "unsupported authType: " + targetType result.Skipped++ @@ -268,12 +266,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { - account := &model.Account{ + account := &Account{ Name: defaultName(src.Name, src.ID), - Platform: model.PlatformAnthropic, + Platform: PlatformAnthropic, Type: targetType, - Credentials: model.JSONB(credentials), - Extra: model.JSONB(extra), + Credentials: credentials, + Extra: extra, ProxyID: proxyID, Concurrency: concurrency, Priority: priority, @@ -288,7 +286,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput continue } // 🔄 Refresh OAuth token after creation - if targetType == model.AccountTypeOAuth { + if targetType == AccountTypeOAuth { if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { account.Credentials = refreshedCreds _ = s.accountRepo.Update(ctx, account) @@ -301,11 +299,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } // Update existing - existing.Extra = mergeJSONB(existing.Extra, extra) + existing.Extra = mergeMap(existing.Extra, extra) existing.Name = defaultName(src.Name, src.ID) - existing.Platform = model.PlatformAnthropic + existing.Platform = PlatformAnthropic existing.Type = targetType - existing.Credentials = mergeJSONB(existing.Credentials, credentials) + existing.Credentials = mergeMap(existing.Credentials, credentials) if proxyID != nil { existing.ProxyID = proxyID } @@ -323,7 +321,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } // 🔄 Refresh OAuth token after update - if targetType == model.AccountTypeOAuth { + if targetType == AccountTypeOAuth { if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { existing.Credentials = refreshedCreds _ = s.accountRepo.Update(ctx, existing) @@ -385,12 +383,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { - account := &model.Account{ + account := &Account{ Name: defaultName(src.Name, src.ID), - Platform: model.PlatformAnthropic, - Type: model.AccountTypeApiKey, - Credentials: model.JSONB(credentials), - Extra: model.JSONB(extra), + Platform: PlatformAnthropic, + Type: AccountTypeApiKey, + Credentials: credentials, + Extra: extra, ProxyID: proxyID, Concurrency: concurrency, Priority: priority, @@ -410,11 +408,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput continue } - existing.Extra = mergeJSONB(existing.Extra, extra) + existing.Extra = mergeMap(existing.Extra, extra) existing.Name = defaultName(src.Name, src.ID) - existing.Platform = model.PlatformAnthropic - existing.Type = model.AccountTypeApiKey - existing.Credentials = mergeJSONB(existing.Credentials, credentials) + existing.Platform = PlatformAnthropic + existing.Type = AccountTypeApiKey + existing.Credentials = mergeMap(existing.Credentials, credentials) if proxyID != nil { existing.ProxyID = proxyID } @@ -508,12 +506,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { - account := &model.Account{ + account := &Account{ Name: defaultName(src.Name, src.ID), - Platform: model.PlatformOpenAI, - Type: model.AccountTypeOAuth, - Credentials: model.JSONB(credentials), - Extra: model.JSONB(extra), + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: credentials, + Extra: extra, ProxyID: proxyID, Concurrency: concurrency, Priority: priority, @@ -538,11 +536,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput continue } - existing.Extra = mergeJSONB(existing.Extra, extra) + existing.Extra = mergeMap(existing.Extra, extra) existing.Name = defaultName(src.Name, src.ID) - existing.Platform = model.PlatformOpenAI - existing.Type = model.AccountTypeOAuth - existing.Credentials = mergeJSONB(existing.Credentials, credentials) + existing.Platform = PlatformOpenAI + existing.Type = AccountTypeOAuth + existing.Credentials = mergeMap(existing.Credentials, credentials) if proxyID != nil { existing.ProxyID = proxyID } @@ -629,12 +627,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { - account := &model.Account{ + account := &Account{ Name: defaultName(src.Name, src.ID), - Platform: model.PlatformOpenAI, - Type: model.AccountTypeApiKey, - Credentials: model.JSONB(credentials), - Extra: model.JSONB(extra), + Platform: PlatformOpenAI, + Type: AccountTypeApiKey, + Credentials: credentials, + Extra: extra, ProxyID: proxyID, Concurrency: concurrency, Priority: priority, @@ -654,11 +652,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput continue } - existing.Extra = mergeJSONB(existing.Extra, extra) + existing.Extra = mergeMap(existing.Extra, extra) existing.Name = defaultName(src.Name, src.ID) - existing.Platform = model.PlatformOpenAI - existing.Type = model.AccountTypeApiKey - existing.Credentials = mergeJSONB(existing.Credentials, credentials) + existing.Platform = PlatformOpenAI + existing.Type = AccountTypeApiKey + existing.Credentials = mergeMap(existing.Credentials, credentials) if proxyID != nil { existing.ProxyID = proxyID } @@ -683,9 +681,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput return result, nil } -// mergeJSONB merges two JSONB maps without removing keys that are absent in updates. -func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB { - out := make(model.JSONB) +func mergeMap(existing map[string]any, updates map[string]any) map[string]any { + out := make(map[string]any, len(existing)+len(updates)) for k, v := range existing { out[k] = v } @@ -695,7 +692,7 @@ func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB { return out } -func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]model.Proxy, src *crsProxy, defaultName string) (*int64, error) { +func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]Proxy, src *crsProxy, defaultName string) (*int64, error) { if !enabled || src == nil { return nil, nil } @@ -731,14 +728,14 @@ func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cac } // Create new proxy - proxy := &model.Proxy{ + proxy := &Proxy{ Name: defaultProxyName(defaultName, protocol, host, port), Protocol: protocol, Host: host, Port: port, Username: username, Password: password, - Status: model.StatusActive, + Status: StatusActive, } if err := s.proxyRepo.Create(ctx, proxy); err != nil { return nil, err @@ -897,8 +894,8 @@ func crsExportAccounts(ctx context.Context, client *http.Client, baseURL, adminT // refreshOAuthToken attempts to refresh OAuth token for a synced account // Returns updated credentials or nil if refresh failed/not applicable -func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.Account) model.JSONB { - if account.Type != model.AccountTypeOAuth { +func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account) map[string]any { + if account.Type != AccountTypeOAuth { return nil } @@ -906,7 +903,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A var err error switch account.Platform { - case model.PlatformAnthropic: + case PlatformAnthropic: if s.oauthService == nil { return nil } @@ -931,7 +928,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A newCredentials["scope"] = tokenInfo.Scope } } - case model.PlatformOpenAI: + case PlatformOpenAI: if s.openaiOAuthService == nil { return nil } @@ -956,5 +953,5 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A return nil } - return model.JSONB(newCredentials) + return newCredentials } diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go new file mode 100644 index 00000000..b0f3fc9e --- /dev/null +++ b/backend/internal/service/domain_constants.go @@ -0,0 +1,96 @@ +package service + +// Status constants +const ( + StatusActive = "active" + StatusDisabled = "disabled" + StatusError = "error" + StatusUnused = "unused" + StatusUsed = "used" + StatusExpired = "expired" +) + +// Role constants +const ( + RoleAdmin = "admin" + RoleUser = "user" +) + +// Platform constants +const ( + PlatformAnthropic = "anthropic" + PlatformOpenAI = "openai" + PlatformGemini = "gemini" +) + +// Account type constants +const ( + AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) + AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) + AccountTypeApiKey = "apikey" // API Key类型账号 +) + +// Redeem type constants +const ( + RedeemTypeBalance = "balance" + RedeemTypeConcurrency = "concurrency" + RedeemTypeSubscription = "subscription" +) + +// Admin adjustment type constants +const ( + AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额 + AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数 +) + +// Group subscription type constants +const ( + SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费) + SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制) +) + +// Subscription status constants +const ( + SubscriptionStatusActive = "active" + SubscriptionStatusExpired = "expired" + SubscriptionStatusSuspended = "suspended" +) + +// Setting keys +const ( + // 注册设置 + SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 + SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 + + // 邮件服务设置 + SettingKeySmtpHost = "smtp_host" // SMTP服务器地址 + SettingKeySmtpPort = "smtp_port" // SMTP端口 + SettingKeySmtpUsername = "smtp_username" // SMTP用户名 + SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储) + SettingKeySmtpFrom = "smtp_from" // 发件人地址 + SettingKeySmtpFromName = "smtp_from_name" // 发件人名称 + SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS + + // Cloudflare Turnstile 设置 + SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证 + SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key + SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key + + // OEM设置 + SettingKeySiteName = "site_name" // 网站名称 + SettingKeySiteLogo = "site_logo" // 网站Logo (base64) + SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 + SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入) + SettingKeyContactInfo = "contact_info" // 客服联系方式 + SettingKeyDocUrl = "doc_url" // 文档链接 + + // 默认配置 + SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 + SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 + + // 管理员 API Key + SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) +) + +// Admin API Key prefix (distinct from user "sk-" keys) +const AdminApiKeyPrefix = "admin-" diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 27c68c52..7b4db611 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -11,7 +11,6 @@ import ( "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" - "github.com/Wei-Shaw/sub2api/internal/model" ) var ( @@ -69,13 +68,13 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ // GetSmtpConfig 从数据库获取SMTP配置 func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) { keys := []string{ - model.SettingKeySmtpHost, - model.SettingKeySmtpPort, - model.SettingKeySmtpUsername, - model.SettingKeySmtpPassword, - model.SettingKeySmtpFrom, - model.SettingKeySmtpFromName, - model.SettingKeySmtpUseTLS, + SettingKeySmtpHost, + SettingKeySmtpPort, + SettingKeySmtpUsername, + SettingKeySmtpPassword, + SettingKeySmtpFrom, + SettingKeySmtpFromName, + SettingKeySmtpUseTLS, } settings, err := s.settingRepo.GetMultiple(ctx, keys) @@ -83,27 +82,27 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) { return nil, fmt.Errorf("get smtp settings: %w", err) } - host := settings[model.SettingKeySmtpHost] + host := settings[SettingKeySmtpHost] if host == "" { return nil, ErrEmailNotConfigured } port := 587 // 默认端口 - if portStr := settings[model.SettingKeySmtpPort]; portStr != "" { + if portStr := settings[SettingKeySmtpPort]; portStr != "" { if p, err := strconv.Atoi(portStr); err == nil { port = p } } - useTLS := settings[model.SettingKeySmtpUseTLS] == "true" + useTLS := settings[SettingKeySmtpUseTLS] == "true" return &SmtpConfig{ Host: host, Port: port, - Username: settings[model.SettingKeySmtpUsername], - Password: settings[model.SettingKeySmtpPassword], - From: settings[model.SettingKeySmtpFrom], - FromName: settings[model.SettingKeySmtpFromName], + Username: settings[SettingKeySmtpUsername], + Password: settings[SettingKeySmtpPassword], + From: settings[SettingKeySmtpFrom], + FromName: settings[SettingKeySmtpFromName], UseTLS: useTLS, }, nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 5b5ad7c8..4da6bd1c 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -17,7 +17,6 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -265,12 +264,12 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte } // SelectAccount 选择账号(粘性会话+优先级) -func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) { +func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { return s.SelectAccountForModel(ctx, groupID, sessionHash, "") } // SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射) -func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) { +func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { // 1. 查询粘性会话 if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) @@ -289,19 +288,19 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int } // 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台) - var accounts []model.Account + var accounts []Account var err error if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformAnthropic) + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic) } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformAnthropic) + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic) } if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) } // 3. 按优先级+最久未用选择(考虑模型支持) - var selected *model.Account + var selected *Account for i := range accounts { acc := &accounts[i] // 检查模型支持 @@ -341,12 +340,12 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int } // GetAccessToken 获取账号凭证 -func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) { +func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { - case model.AccountTypeOAuth, model.AccountTypeSetupToken: + case AccountTypeOAuth, AccountTypeSetupToken: // Both oauth and setup-token use OAuth token flow return s.getOAuthToken(ctx, account) - case model.AccountTypeApiKey: + case AccountTypeApiKey: apiKey := account.GetCredential("api_key") if apiKey == "" { return "", "", errors.New("api_key not found in credentials") @@ -357,7 +356,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco } } -func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) { +func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) { accessToken := account.GetCredential("access_token") if accessToken == "" { return "", "", errors.New("access_token not found in credentials") @@ -372,10 +371,7 @@ const ( retryDelay = 3 * time.Second // 重试等待时间 ) -// shouldRetryUpstreamError 判断是否应该重试上游错误 -// OAuth/Setup Token 账号:仅 403 重试 -// API Key 账号:未配置的错误码重试 -func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, statusCode int) bool { +func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool { // OAuth/Setup Token 账号:仅 403 重试 if account.IsOAuth() { return statusCode == 403 @@ -386,7 +382,7 @@ func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, status } // Forward 转发请求到Claude API -func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) { +func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { startTime := time.Now() // 解析请求获取model和stream @@ -412,7 +408,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m // 应用模型映射(仅对apikey类型账号) originalModel := req.Model - if account.Type == model.AccountTypeApiKey { + if account.Type == AccountTypeApiKey { mappedModel := account.GetMappedModel(req.Model) if mappedModel != req.Model { // 替换请求体中的模型名 @@ -504,10 +500,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m }, nil } -func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) { +func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL - if account.Type == model.AccountTypeApiKey { + if account.Type == AccountTypeApiKey { baseURL := account.GetBaseURL() targetURL = baseURL + "/v1/messages" } @@ -631,7 +627,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str return claude.DefaultBetaHeader } -func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) { +func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(resp.Body) // 处理上游错误,标记账号状态 @@ -686,7 +682,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res // handleRetryExhaustedError 处理重试耗尽后的错误 // OAuth 403:标记账号异常 // API Key 未配置错误码:仅返回错误,不标记账号 -func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) { +func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(resp.Body) statusCode := resp.StatusCode @@ -717,7 +713,7 @@ type streamingResult struct { firstTokenMs *int } -func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) { +func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) @@ -856,7 +852,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { } } -func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*ClaudeUsage, error) { +func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) @@ -915,10 +911,10 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { Result *ForwardResult - ApiKey *model.ApiKey - User *model.User - Account *model.Account - Subscription *model.UserSubscription // 可选:订阅信息 + ApiKey *ApiKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 } // RecordUsage 记录使用量并扣费(或更新订阅用量) @@ -952,14 +948,14 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // 判断计费方式:订阅模式 vs 余额模式 isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() - billingType := model.BillingTypeBalance + billingType := BillingTypeBalance if isSubscriptionBilling { - billingType = model.BillingTypeSubscription + billingType = BillingTypeSubscription } // 创建使用日志 durationMs := int(result.Duration.Milliseconds()) - usageLog := &model.UsageLog{ + usageLog := &UsageLog{ UserID: user.ID, ApiKeyID: apiKey.ID, AccountID: account.ID, @@ -1038,9 +1034,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // ForwardCountTokens 转发 count_tokens 请求到上游 API // 特点:不记录使用量、仅支持非流式响应 -func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *model.Account, body []byte) error { +func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error { // 应用模型映射(仅对 apikey 类型账号) - if account.Type == model.AccountTypeApiKey { + if account.Type == AccountTypeApiKey { var req struct { Model string `json:"model"` } @@ -1113,10 +1109,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // buildCountTokensRequest 构建 count_tokens 上游请求 -func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) { +func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) { // 确定目标 URL targetURL := claudeAPICountTokensURL - if account.Type == model.AccountTypeApiKey { + if account.Type == AccountTypeApiKey { baseURL := account.GetBaseURL() targetURL = baseURL + "/v1/messages/count_tokens" } diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go new file mode 100644 index 00000000..f1e36b89 --- /dev/null +++ b/backend/internal/service/group.go @@ -0,0 +1,48 @@ +package service + +import "time" + +type Group struct { + ID int64 + Name string + Description string + Platform string + RateMultiplier float64 + IsExclusive bool + Status string + + SubscriptionType string + DailyLimitUSD *float64 + WeeklyLimitUSD *float64 + MonthlyLimitUSD *float64 + + CreatedAt time.Time + UpdatedAt time.Time + + AccountGroups []AccountGroup + AccountCount int64 +} + +func (g *Group) IsActive() bool { + return g.Status == StatusActive +} + +func (g *Group) IsSubscriptionType() bool { + return g.SubscriptionType == SubscriptionTypeSubscription +} + +func (g *Group) IsFreeSubscription() bool { + return g.IsSubscriptionType() && g.RateMultiplier == 0 +} + +func (g *Group) HasDailyLimit() bool { + return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0 +} + +func (g *Group) HasWeeklyLimit() bool { + return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0 +} + +func (g *Group) HasMonthlyLimit() bool { + return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0 +} diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index ea9bd24d..886c0a3a 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -5,7 +5,6 @@ import ( "fmt" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) @@ -15,16 +14,16 @@ var ( ) type GroupRepository interface { - Create(ctx context.Context, group *model.Group) error - GetByID(ctx context.Context, id int64) (*model.Group, error) - Update(ctx context.Context, group *model.Group) error + Create(ctx context.Context, group *Group) error + GetByID(ctx context.Context, id int64) (*Group, error) + Update(ctx context.Context, group *Group) error Delete(ctx context.Context, id int64) error DeleteCascade(ctx context.Context, id int64) ([]int64, error) - List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) - ListActive(ctx context.Context) ([]model.Group, error) - ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) + List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) + ListActive(ctx context.Context) ([]Group, error) + ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) ExistsByName(ctx context.Context, name string) (bool, error) GetAccountCount(ctx context.Context, groupID int64) (int64, error) @@ -61,7 +60,7 @@ func NewGroupService(groupRepo GroupRepository) *GroupService { } // Create 创建分组 -func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*model.Group, error) { +func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) { // 检查名称是否已存在 exists, err := s.groupRepo.ExistsByName(ctx, req.Name) if err != nil { @@ -72,12 +71,14 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod } // 创建分组 - group := &model.Group{ - Name: req.Name, - Description: req.Description, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - Status: model.StatusActive, + group := &Group{ + Name: req.Name, + Description: req.Description, + Platform: PlatformAnthropic, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + Status: StatusActive, + SubscriptionType: SubscriptionTypeStandard, } if err := s.groupRepo.Create(ctx, group); err != nil { @@ -88,7 +89,7 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod } // GetByID 根据ID获取分组 -func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) { +func (s *GroupService) GetByID(ctx context.Context, id int64) (*Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get group: %w", err) @@ -97,7 +98,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err } // List 获取分组列表 -func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { +func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { groups, pagination, err := s.groupRepo.List(ctx, params) if err != nil { return nil, nil, fmt.Errorf("list groups: %w", err) @@ -106,7 +107,7 @@ func (s *GroupService) List(ctx context.Context, params pagination.PaginationPar } // ListActive 获取活跃分组列表 -func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) { +func (s *GroupService) ListActive(ctx context.Context) ([]Group, error) { groups, err := s.groupRepo.ListActive(ctx) if err != nil { return nil, fmt.Errorf("list active groups: %w", err) @@ -115,7 +116,7 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) { } // Update 更新分组 -func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) { +func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get group: %w", err) diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index 8eb476c2..f4c149ac 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -6,7 +6,6 @@ import ( "log" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) @@ -274,7 +273,7 @@ func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, pr } // RefreshAccountToken refreshes token for an account -func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*TokenInfo, error) { +func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) { refreshToken := account.GetCredential("refresh_token") if refreshToken == "" { return nil, fmt.Errorf("no refresh token available") diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index ca3c2c36..aab7837b 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -16,7 +16,6 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/gin-gonic/gin" ) @@ -119,12 +118,12 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string { } // SelectAccount selects an OpenAI account with sticky session support -func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) { +func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { return s.SelectAccountForModel(ctx, groupID, sessionHash, "") } // SelectAccountForModel selects an account supporting the requested model -func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) { +func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { // 1. Check sticky session if sessionHash != "" { accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) @@ -139,19 +138,19 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI } // 2. Get schedulable OpenAI accounts - var accounts []model.Account + var accounts []Account var err error if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformOpenAI) + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformOpenAI) + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) } if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) } // 3. Select by priority + LRU - var selected *model.Account + var selected *Account for i := range accounts { acc := &accounts[i] // Check model support @@ -189,15 +188,15 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI } // GetAccessToken gets the access token for an OpenAI account -func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) { +func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { - case model.AccountTypeOAuth: + case AccountTypeOAuth: accessToken := account.GetOpenAIAccessToken() if accessToken == "" { return "", "", errors.New("access_token not found in credentials") } return accessToken, "oauth", nil - case model.AccountTypeApiKey: + case AccountTypeApiKey: apiKey := account.GetOpenAIApiKey() if apiKey == "" { return "", "", errors.New("api_key not found in credentials") @@ -209,7 +208,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *mode } // Forward forwards request to OpenAI API -func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*OpenAIForwardResult, error) { +func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) { startTime := time.Now() // Parse request body once (avoid multiple parse/serialize cycles) @@ -234,7 +233,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // For OAuth accounts using ChatGPT internal API, add store: false - if account.Type == model.AccountTypeOAuth { + if account.Type == AccountTypeOAuth { reqBody["store"] = false bodyModified = true } @@ -296,7 +295,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // Extract and save Codex usage snapshot from response headers (for OAuth accounts) - if account.Type == model.AccountTypeOAuth { + if account.Type == AccountTypeOAuth { if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil { s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) } @@ -312,14 +311,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco }, nil } -func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token string, isStream bool) (*http.Request, error) { +func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) { // Determine target URL based on account type var targetURL string switch account.Type { - case model.AccountTypeOAuth: + case AccountTypeOAuth: // OAuth accounts use ChatGPT internal API targetURL = chatgptCodexURL - case model.AccountTypeApiKey: + case AccountTypeApiKey: // API Key accounts use Platform API or custom base URL baseURL := account.GetOpenAIBaseURL() if baseURL != "" { @@ -340,7 +339,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. req.Header.Set("authorization", "Bearer "+token) // Set headers specific to OAuth accounts (ChatGPT internal API) - if account.Type == model.AccountTypeOAuth { + if account.Type == AccountTypeOAuth { // Required: set Host for ChatGPT API (must use req.Host, not Header.Set) req.Host = "chatgpt.com" // Required: set chatgpt-account-id header @@ -380,7 +379,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. return req, nil } -func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*OpenAIForwardResult, error) { +func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) { body, _ := io.ReadAll(resp.Body) // Check custom error codes @@ -436,7 +435,7 @@ type openaiStreamingResult struct { firstTokenMs *int } -func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { +func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { // Set SSE response headers c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") @@ -552,7 +551,7 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { } } -func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*OpenAIUsage, error) { +func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { body, err := io.ReadAll(resp.Body) if err != nil { return nil, err @@ -618,10 +617,10 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { Result *OpenAIForwardResult - ApiKey *model.ApiKey - User *model.User - Account *model.Account - Subscription *model.UserSubscription + ApiKey *ApiKey + User *User + Account *Account + Subscription *UserSubscription } // RecordUsage records usage and deducts balance @@ -660,14 +659,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // Determine billing type isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() - billingType := model.BillingTypeBalance + billingType := BillingTypeBalance if isSubscriptionBilling { - billingType = model.BillingTypeSubscription + billingType = BillingTypeSubscription } // Create usage log durationMs := int(result.Duration.Milliseconds()) - usageLog := &model.UsageLog{ + usageLog := &UsageLog{ UserID: user.ID, ApiKeyID: apiKey.ID, AccountID: account.ID, diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index 420c755c..182e08fe 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -5,7 +5,6 @@ import ( "fmt" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) @@ -200,7 +199,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri } // RefreshAccountToken refreshes token for an OpenAI account -func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*OpenAITokenInfo, error) { +func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { if !account.IsOpenAI() { return nil, fmt.Errorf("account is not an OpenAI account") } diff --git a/backend/internal/service/proxy.go b/backend/internal/service/proxy.go new file mode 100644 index 00000000..768e2a0a --- /dev/null +++ b/backend/internal/service/proxy.go @@ -0,0 +1,35 @@ +package service + +import ( + "fmt" + "time" +) + +type Proxy struct { + ID int64 + Name string + Protocol string + Host string + Port int + Username string + Password string + Status string + CreatedAt time.Time + UpdatedAt time.Time +} + +func (p *Proxy) IsActive() bool { + return p.Status == StatusActive +} + +func (p *Proxy) URL() string { + if p.Username != "" && p.Password != "" { + return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port) + } + return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port) +} + +type ProxyWithAccountCount struct { + Proxy + AccountCount int64 +} diff --git a/backend/internal/service/proxy_service.go b/backend/internal/service/proxy_service.go index 28ade11f..c074b13d 100644 --- a/backend/internal/service/proxy_service.go +++ b/backend/internal/service/proxy_service.go @@ -5,7 +5,6 @@ import ( "fmt" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) @@ -14,15 +13,15 @@ var ( ) type ProxyRepository interface { - Create(ctx context.Context, proxy *model.Proxy) error - GetByID(ctx context.Context, id int64) (*model.Proxy, error) - Update(ctx context.Context, proxy *model.Proxy) error + Create(ctx context.Context, proxy *Proxy) error + GetByID(ctx context.Context, id int64) (*Proxy, error) + Update(ctx context.Context, proxy *Proxy) error Delete(ctx context.Context, id int64) error - List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) - ListActive(ctx context.Context) ([]model.Proxy, error) - ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) + List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) + ListActive(ctx context.Context) ([]Proxy, error) + ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) @@ -62,16 +61,16 @@ func NewProxyService(proxyRepo ProxyRepository) *ProxyService { } // Create 创建代理 -func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*model.Proxy, error) { +func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*Proxy, error) { // 创建代理 - proxy := &model.Proxy{ + proxy := &Proxy{ Name: req.Name, Protocol: req.Protocol, Host: req.Host, Port: req.Port, Username: req.Username, Password: req.Password, - Status: model.StatusActive, + Status: StatusActive, } if err := s.proxyRepo.Create(ctx, proxy); err != nil { @@ -82,7 +81,7 @@ func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*mod } // GetByID 根据ID获取代理 -func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { +func (s *ProxyService) GetByID(ctx context.Context, id int64) (*Proxy, error) { proxy, err := s.proxyRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get proxy: %w", err) @@ -91,7 +90,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err } // List 获取代理列表 -func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { +func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) { proxies, pagination, err := s.proxyRepo.List(ctx, params) if err != nil { return nil, nil, fmt.Errorf("list proxies: %w", err) @@ -100,7 +99,7 @@ func (s *ProxyService) List(ctx context.Context, params pagination.PaginationPar } // ListActive 获取活跃代理列表 -func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) { +func (s *ProxyService) ListActive(ctx context.Context) ([]Proxy, error) { proxies, err := s.proxyRepo.ListActive(ctx) if err != nil { return nil, fmt.Errorf("list active proxies: %w", err) @@ -109,7 +108,7 @@ func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) { } // Update 更新代理 -func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) { +func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*Proxy, error) { proxy, err := s.proxyRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get proxy: %w", err) diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 7bee7907..27b6d3c8 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -8,7 +8,6 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/model" ) // RateLimitService 处理限流和过载状态管理 @@ -27,7 +26,7 @@ func NewRateLimitService(accountRepo AccountRepository, cfg *config.Config) *Rat // HandleUpstreamError 处理上游错误响应,标记账号状态 // 返回是否应该停止该账号的调度 -func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { +func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { // apikey 类型账号:检查自定义错误码配置 // 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载) if !account.ShouldHandleErrorCode(statusCode) { @@ -60,7 +59,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod } // handleAuthError 处理认证类错误(401/403),停止账号调度 -func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) { +func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) { if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { log.Printf("SetError failed for account %d: %v", account.ID, err) return @@ -70,7 +69,7 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.A // handle429 处理429限流错误 // 解析响应头获取重置时间,标记账号为限流状态 -func (s *RateLimitService) handle429(ctx context.Context, account *model.Account, headers http.Header) { +func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header) { // 解析重置时间戳 resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset") if resetTimestamp == "" { @@ -113,7 +112,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account // handle529 处理529过载错误 // 根据配置设置过载冷却时间 -func (s *RateLimitService) handle529(ctx context.Context, account *model.Account) { +func (s *RateLimitService) handle529(ctx context.Context, account *Account) { cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes if cooldownMinutes <= 0 { cooldownMinutes = 10 // 默认10分钟 @@ -129,7 +128,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account } // UpdateSessionWindow 从成功响应更新5h窗口状态 -func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *model.Account, headers http.Header) { +func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Account, headers http.Header) { status := headers.Get("anthropic-ratelimit-unified-5h-status") if status == "" { return diff --git a/backend/internal/service/redeem_code.go b/backend/internal/service/redeem_code.go new file mode 100644 index 00000000..a66b53ba --- /dev/null +++ b/backend/internal/service/redeem_code.go @@ -0,0 +1,41 @@ +package service + +import ( + "crypto/rand" + "encoding/hex" + "time" +) + +type RedeemCode struct { + ID int64 + Code string + Type string + Value float64 + Status string + UsedBy *int64 + UsedAt *time.Time + Notes string + CreatedAt time.Time + + GroupID *int64 + ValidityDays int + + User *User + Group *Group +} + +func (r *RedeemCode) IsUsed() bool { + return r.Status == StatusUsed +} + +func (r *RedeemCode) CanUse() bool { + return r.Status == StatusUnused +} + +func GenerateRedeemCode() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index 591d1555..144f2c50 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -10,7 +10,6 @@ import ( "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/redis/go-redis/v9" ) @@ -39,17 +38,17 @@ type RedeemCache interface { } type RedeemCodeRepository interface { - Create(ctx context.Context, code *model.RedeemCode) error - CreateBatch(ctx context.Context, codes []model.RedeemCode) error - GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) - GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) - Update(ctx context.Context, code *model.RedeemCode) error + Create(ctx context.Context, code *RedeemCode) error + CreateBatch(ctx context.Context, codes []RedeemCode) error + GetByID(ctx context.Context, id int64) (*RedeemCode, error) + GetByCode(ctx context.Context, code string) (*RedeemCode, error) + Update(ctx context.Context, code *RedeemCode) error Delete(ctx context.Context, id int64) error Use(ctx context.Context, id, userID int64) error - List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) - ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) + List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) + ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) } // GenerateCodesRequest 生成兑换码请求 @@ -116,7 +115,7 @@ func (s *RedeemService) GenerateRandomCode() (string, error) { } // GenerateCodes 批量生成兑换码 -func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]model.RedeemCode, error) { +func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]RedeemCode, error) { if req.Count <= 0 { return nil, errors.New("count must be greater than 0") } @@ -131,21 +130,21 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ codeType := req.Type if codeType == "" { - codeType = model.RedeemTypeBalance + codeType = RedeemTypeBalance } - codes := make([]model.RedeemCode, 0, req.Count) + codes := make([]RedeemCode, 0, req.Count) for i := 0; i < req.Count; i++ { code, err := s.GenerateRandomCode() if err != nil { return nil, fmt.Errorf("generate code: %w", err) } - codes = append(codes, model.RedeemCode{ + codes = append(codes, RedeemCode{ Code: code, Type: codeType, Value: req.Value, - Status: model.StatusUnused, + Status: StatusUnused, }) } @@ -210,7 +209,7 @@ func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) { } // Redeem 使用兑换码 -func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*model.RedeemCode, error) { +func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*RedeemCode, error) { // 检查限流 if err := s.checkRedeemRateLimit(ctx, userID); err != nil { return nil, err @@ -239,7 +238,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( } // 验证兑换码类型的前置条件 - if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil { + if redeemCode.Type == RedeemTypeSubscription && redeemCode.GroupID == nil { return nil, infraerrors.BadRequest("REDEEM_CODE_INVALID", "invalid subscription redeem code: missing group_id") } @@ -261,7 +260,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( // 执行兑换逻辑(兑换码已被锁定,此时可安全操作) switch redeemCode.Type { - case model.RedeemTypeBalance: + case RedeemTypeBalance: // 增加用户余额 if err := s.userRepo.UpdateBalance(ctx, userID, redeemCode.Value); err != nil { return nil, fmt.Errorf("update user balance: %w", err) @@ -275,13 +274,13 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( }() } - case model.RedeemTypeConcurrency: + case RedeemTypeConcurrency: // 增加用户并发数 if err := s.userRepo.UpdateConcurrency(ctx, userID, int(redeemCode.Value)); err != nil { return nil, fmt.Errorf("update user concurrency: %w", err) } - case model.RedeemTypeSubscription: + case RedeemTypeSubscription: validityDays := redeemCode.ValidityDays if validityDays <= 0 { validityDays = 30 @@ -320,7 +319,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( } // GetByID 根据ID获取兑换码 -func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) { +func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) { code, err := s.redeemRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get redeem code: %w", err) @@ -329,7 +328,7 @@ func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCod } // GetByCode 根据Code获取兑换码 -func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) { +func (s *RedeemService) GetByCode(ctx context.Context, code string) (*RedeemCode, error) { redeemCode, err := s.redeemRepo.GetByCode(ctx, code) if err != nil { return nil, fmt.Errorf("get redeem code: %w", err) @@ -338,7 +337,7 @@ func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.Rede } // List 获取兑换码列表(管理员功能) -func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) { +func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) { codes, pagination, err := s.redeemRepo.List(ctx, params) if err != nil { return nil, nil, fmt.Errorf("list redeem codes: %w", err) @@ -383,7 +382,7 @@ func (s *RedeemService) GetStats(ctx context.Context) (map[string]any, error) { } // GetUserHistory 获取用户的兑换历史 -func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) { +func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) { codes, err := s.redeemRepo.ListByUser(ctx, userID, limit) if err != nil { return nil, fmt.Errorf("get user redeem history: %w", err) diff --git a/backend/internal/service/setting.go b/backend/internal/service/setting.go new file mode 100644 index 00000000..eef6bcc5 --- /dev/null +++ b/backend/internal/service/setting.go @@ -0,0 +1,10 @@ +package service + +import "time" + +type Setting struct { + ID int64 + Key string + Value string + UpdatedAt time.Time +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index cb38203c..0ffe991d 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -10,7 +10,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" - "github.com/Wei-Shaw/sub2api/internal/model" ) var ( @@ -19,7 +18,7 @@ var ( ) type SettingRepository interface { - Get(ctx context.Context, key string) (*model.Setting, error) + Get(ctx context.Context, key string) (*Setting, error) GetValue(ctx context.Context, key string) (string, error) Set(ctx context.Context, key, value string) error GetMultiple(ctx context.Context, keys []string) (map[string]string, error) @@ -43,7 +42,7 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti } // GetAllSettings 获取所有系统设置 -func (s *SettingService) GetAllSettings(ctx context.Context) (*model.SystemSettings, error) { +func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) { settings, err := s.settingRepo.GetAll(ctx) if err != nil { return nil, fmt.Errorf("get all settings: %w", err) @@ -53,18 +52,18 @@ func (s *SettingService) GetAllSettings(ctx context.Context) (*model.SystemSetti } // GetPublicSettings 获取公开设置(无需登录) -func (s *SettingService) GetPublicSettings(ctx context.Context) (*model.PublicSettings, error) { +func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings, error) { keys := []string{ - model.SettingKeyRegistrationEnabled, - model.SettingKeyEmailVerifyEnabled, - model.SettingKeyTurnstileEnabled, - model.SettingKeyTurnstileSiteKey, - model.SettingKeySiteName, - model.SettingKeySiteLogo, - model.SettingKeySiteSubtitle, - model.SettingKeyApiBaseUrl, - model.SettingKeyContactInfo, - model.SettingKeyDocUrl, + SettingKeyRegistrationEnabled, + SettingKeyEmailVerifyEnabled, + SettingKeyTurnstileEnabled, + SettingKeyTurnstileSiteKey, + SettingKeySiteName, + SettingKeySiteLogo, + SettingKeySiteSubtitle, + SettingKeyApiBaseUrl, + SettingKeyContactInfo, + SettingKeyDocUrl, } settings, err := s.settingRepo.GetMultiple(ctx, keys) @@ -72,64 +71,64 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*model.PublicSe return nil, fmt.Errorf("get public settings: %w", err) } - return &model.PublicSettings{ - RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true", - EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true", - TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true", - TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey], - SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"), - SiteLogo: settings[model.SettingKeySiteLogo], - SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), - ApiBaseUrl: settings[model.SettingKeyApiBaseUrl], - ContactInfo: settings[model.SettingKeyContactInfo], - DocUrl: settings[model.SettingKeyDocUrl], + return &PublicSettings{ + RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", + EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", + TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", + TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], + SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), + SiteLogo: settings[SettingKeySiteLogo], + SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), + ApiBaseUrl: settings[SettingKeyApiBaseUrl], + ContactInfo: settings[SettingKeyContactInfo], + DocUrl: settings[SettingKeyDocUrl], }, nil } // UpdateSettings 更新系统设置 -func (s *SettingService) UpdateSettings(ctx context.Context, settings *model.SystemSettings) error { +func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { updates := make(map[string]string) // 注册设置 - updates[model.SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled) - updates[model.SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) + updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled) + updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) // 邮件服务设置(只有非空才更新密码) - updates[model.SettingKeySmtpHost] = settings.SmtpHost - updates[model.SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort) - updates[model.SettingKeySmtpUsername] = settings.SmtpUsername + updates[SettingKeySmtpHost] = settings.SmtpHost + updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort) + updates[SettingKeySmtpUsername] = settings.SmtpUsername if settings.SmtpPassword != "" { - updates[model.SettingKeySmtpPassword] = settings.SmtpPassword + updates[SettingKeySmtpPassword] = settings.SmtpPassword } - updates[model.SettingKeySmtpFrom] = settings.SmtpFrom - updates[model.SettingKeySmtpFromName] = settings.SmtpFromName - updates[model.SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS) + updates[SettingKeySmtpFrom] = settings.SmtpFrom + updates[SettingKeySmtpFromName] = settings.SmtpFromName + updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS) // Cloudflare Turnstile 设置(只有非空才更新密钥) - updates[model.SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled) - updates[model.SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey + updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled) + updates[SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey if settings.TurnstileSecretKey != "" { - updates[model.SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey + updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey } // OEM设置 - updates[model.SettingKeySiteName] = settings.SiteName - updates[model.SettingKeySiteLogo] = settings.SiteLogo - updates[model.SettingKeySiteSubtitle] = settings.SiteSubtitle - updates[model.SettingKeyApiBaseUrl] = settings.ApiBaseUrl - updates[model.SettingKeyContactInfo] = settings.ContactInfo - updates[model.SettingKeyDocUrl] = settings.DocUrl + updates[SettingKeySiteName] = settings.SiteName + updates[SettingKeySiteLogo] = settings.SiteLogo + updates[SettingKeySiteSubtitle] = settings.SiteSubtitle + updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl + updates[SettingKeyContactInfo] = settings.ContactInfo + updates[SettingKeyDocUrl] = settings.DocUrl // 默认配置 - updates[model.SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) - updates[model.SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) + updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) + updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) return s.settingRepo.SetMultiple(ctx, updates) } // IsRegistrationEnabled 检查是否开放注册 func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { - value, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled) + value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) if err != nil { // 默认开放注册 return true @@ -139,7 +138,7 @@ func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { // IsEmailVerifyEnabled 检查是否开启邮件验证 func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool { - value, err := s.settingRepo.GetValue(ctx, model.SettingKeyEmailVerifyEnabled) + value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled) if err != nil { return false } @@ -148,7 +147,7 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool { // GetSiteName 获取网站名称 func (s *SettingService) GetSiteName(ctx context.Context) string { - value, err := s.settingRepo.GetValue(ctx, model.SettingKeySiteName) + value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName) if err != nil || value == "" { return "Sub2API" } @@ -157,7 +156,7 @@ func (s *SettingService) GetSiteName(ctx context.Context) string { // GetDefaultConcurrency 获取默认并发量 func (s *SettingService) GetDefaultConcurrency(ctx context.Context) int { - value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultConcurrency) + value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultConcurrency) if err != nil { return s.cfg.Default.UserConcurrency } @@ -169,7 +168,7 @@ func (s *SettingService) GetDefaultConcurrency(ctx context.Context) int { // GetDefaultBalance 获取默认余额 func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 { - value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultBalance) + value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultBalance) if err != nil { return s.cfg.Default.UserBalance } @@ -182,7 +181,7 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 { // InitializeDefaultSettings 初始化默认设置 func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 检查是否已有设置 - _, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled) + _, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) if err == nil { // 已有设置,不需要初始化 return nil @@ -193,62 +192,62 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 初始化默认设置 defaults := map[string]string{ - model.SettingKeyRegistrationEnabled: "true", - model.SettingKeyEmailVerifyEnabled: "false", - model.SettingKeySiteName: "Sub2API", - model.SettingKeySiteLogo: "", - model.SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), - model.SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), - model.SettingKeySmtpPort: "587", - model.SettingKeySmtpUseTLS: "false", + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "false", + SettingKeySiteName: "Sub2API", + SettingKeySiteLogo: "", + SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), + SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeySmtpPort: "587", + SettingKeySmtpUseTLS: "false", } return s.settingRepo.SetMultiple(ctx, defaults) } // parseSettings 解析设置到结构体 -func (s *SettingService) parseSettings(settings map[string]string) *model.SystemSettings { - result := &model.SystemSettings{ - RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true", - EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true", - SmtpHost: settings[model.SettingKeySmtpHost], - SmtpUsername: settings[model.SettingKeySmtpUsername], - SmtpFrom: settings[model.SettingKeySmtpFrom], - SmtpFromName: settings[model.SettingKeySmtpFromName], - SmtpUseTLS: settings[model.SettingKeySmtpUseTLS] == "true", - TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true", - TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey], - SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"), - SiteLogo: settings[model.SettingKeySiteLogo], - SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), - ApiBaseUrl: settings[model.SettingKeyApiBaseUrl], - ContactInfo: settings[model.SettingKeyContactInfo], - DocUrl: settings[model.SettingKeyDocUrl], +func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings { + result := &SystemSettings{ + RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", + EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", + SmtpHost: settings[SettingKeySmtpHost], + SmtpUsername: settings[SettingKeySmtpUsername], + SmtpFrom: settings[SettingKeySmtpFrom], + SmtpFromName: settings[SettingKeySmtpFromName], + SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true", + TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", + TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], + SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), + SiteLogo: settings[SettingKeySiteLogo], + SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), + ApiBaseUrl: settings[SettingKeyApiBaseUrl], + ContactInfo: settings[SettingKeyContactInfo], + DocUrl: settings[SettingKeyDocUrl], } // 解析整数类型 - if port, err := strconv.Atoi(settings[model.SettingKeySmtpPort]); err == nil { + if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil { result.SmtpPort = port } else { result.SmtpPort = 587 } - if concurrency, err := strconv.Atoi(settings[model.SettingKeyDefaultConcurrency]); err == nil { + if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil { result.DefaultConcurrency = concurrency } else { result.DefaultConcurrency = s.cfg.Default.UserConcurrency } // 解析浮点数类型 - if balance, err := strconv.ParseFloat(settings[model.SettingKeyDefaultBalance], 64); err == nil { + if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil { result.DefaultBalance = balance } else { result.DefaultBalance = s.cfg.Default.UserBalance } // 敏感信息直接返回,方便测试连接时使用 - result.SmtpPassword = settings[model.SettingKeySmtpPassword] - result.TurnstileSecretKey = settings[model.SettingKeyTurnstileSecretKey] + result.SmtpPassword = settings[SettingKeySmtpPassword] + result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey] return result } @@ -263,7 +262,7 @@ func (s *SettingService) getStringOrDefault(settings map[string]string, key, def // IsTurnstileEnabled 检查是否启用 Turnstile 验证 func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool { - value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileEnabled) + value, err := s.settingRepo.GetValue(ctx, SettingKeyTurnstileEnabled) if err != nil { return false } @@ -272,7 +271,7 @@ func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool { // GetTurnstileSecretKey 获取 Turnstile Secret Key func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string { - value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileSecretKey) + value, err := s.settingRepo.GetValue(ctx, SettingKeyTurnstileSecretKey) if err != nil { return "" } @@ -287,10 +286,10 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error return "", fmt.Errorf("generate random bytes: %w", err) } - key := model.AdminApiKeyPrefix + hex.EncodeToString(bytes) + key := AdminApiKeyPrefix + hex.EncodeToString(bytes) // 存储到 settings 表 - if err := s.settingRepo.Set(ctx, model.SettingKeyAdminApiKey, key); err != nil { + if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil { return "", fmt.Errorf("save admin api key: %w", err) } @@ -300,7 +299,7 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error // GetAdminApiKeyStatus 获取管理员 API Key 状态 // 返回脱敏的 key、是否存在、错误 func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) { - key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey) + key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey) if err != nil { if errors.Is(err, ErrSettingNotFound) { return "", false, nil @@ -324,7 +323,7 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st // GetAdminApiKey 获取完整的管理员 API Key(仅供内部验证使用) // 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) { - key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey) + key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey) if err != nil { if errors.Is(err, ErrSettingNotFound) { return "", nil // 未配置,返回空字符串 @@ -336,5 +335,5 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) { // DeleteAdminApiKey 删除管理员 API Key func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error { - return s.settingRepo.Delete(ctx, model.SettingKeyAdminApiKey) + return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey) } diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go new file mode 100644 index 00000000..cb9751d1 --- /dev/null +++ b/backend/internal/service/settings_view.go @@ -0,0 +1,42 @@ +package service + +type SystemSettings struct { + RegistrationEnabled bool + EmailVerifyEnabled bool + + SmtpHost string + SmtpPort int + SmtpUsername string + SmtpPassword string + SmtpFrom string + SmtpFromName string + SmtpUseTLS bool + + TurnstileEnabled bool + TurnstileSiteKey string + TurnstileSecretKey string + + SiteName string + SiteLogo string + SiteSubtitle string + ApiBaseUrl string + ContactInfo string + DocUrl string + + DefaultConcurrency int + DefaultBalance float64 +} + +type PublicSettings struct { + RegistrationEnabled bool + EmailVerifyEnabled bool + TurnstileEnabled bool + TurnstileSiteKey string + SiteName string + SiteLogo string + SiteSubtitle string + ApiBaseUrl string + ContactInfo string + DocUrl string + Version string +} diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index f1ff6a2d..5b957094 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -7,7 +7,6 @@ import ( "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) @@ -48,7 +47,7 @@ type AssignSubscriptionInput struct { } // AssignSubscription 分配订阅给用户(不允许重复分配) -func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) { +func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) { // 检查分组是否存在且为订阅类型 group, err := s.groupRepo.GetByID(ctx, input.GroupID) if err != nil { @@ -91,7 +90,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass // - 已过期:从当前时间开始计算新的过期时间,并激活订阅 // // 如果没有订阅:创建新订阅 -func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) { +func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { // 检查分组是否存在且为订阅类型 group, err := s.groupRepo.GetByID(ctx, input.GroupID) if err != nil { @@ -132,8 +131,8 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in } // 如果订阅已过期或被暂停,恢复为active状态 - if existingSub.Status != model.SubscriptionStatusActive { - if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil { + if existingSub.Status != SubscriptionStatusActive { + if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, SubscriptionStatusActive); err != nil { return nil, false, fmt.Errorf("update subscription status: %w", err) } } @@ -185,19 +184,19 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in } // createSubscription 创建新订阅(内部方法) -func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) { +func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) { validityDays := input.ValidityDays if validityDays <= 0 { validityDays = 30 } now := time.Now() - sub := &model.UserSubscription{ + sub := &UserSubscription{ UserID: input.UserID, GroupID: input.GroupID, StartsAt: now, ExpiresAt: now.AddDate(0, 0, validityDays), - Status: model.SubscriptionStatusActive, + Status: SubscriptionStatusActive, AssignedAt: now, Notes: input.Notes, CreatedAt: now, @@ -229,14 +228,14 @@ type BulkAssignSubscriptionInput struct { type BulkAssignResult struct { SuccessCount int FailedCount int - Subscriptions []model.UserSubscription + Subscriptions []UserSubscription Errors []string } // BulkAssignSubscription 批量分配订阅 func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) { result := &BulkAssignResult{ - Subscriptions: make([]model.UserSubscription, 0), + Subscriptions: make([]UserSubscription, 0), Errors: make([]string, 0), } @@ -286,7 +285,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti } // ExtendSubscription 延长订阅 -func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) { +func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*UserSubscription, error) { sub, err := s.userSubRepo.GetByID(ctx, subscriptionID) if err != nil { return nil, ErrSubscriptionNotFound @@ -299,8 +298,8 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti } // 如果订阅已过期,恢复为active状态 - if sub.Status == model.SubscriptionStatusExpired { - if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil { + if sub.Status == SubscriptionStatusExpired { + if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, SubscriptionStatusActive); err != nil { return nil, err } } @@ -319,12 +318,12 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti } // GetByID 根据ID获取订阅 -func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) { +func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubscription, error) { return s.userSubRepo.GetByID(ctx, id) } // GetActiveSubscription 获取用户对特定分组的有效订阅 -func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { +func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) { sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID) if err != nil { return nil, ErrSubscriptionNotFound @@ -333,7 +332,7 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, } // ListUserSubscriptions 获取用户的所有订阅 -func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) { +func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]UserSubscription, error) { subs, err := s.userSubRepo.ListByUserID(ctx, userID) if err != nil { return nil, err @@ -343,7 +342,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID } // ListActiveUserSubscriptions 获取用户的所有有效订阅 -func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) { +func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]UserSubscription, error) { subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID) if err != nil { return nil, err @@ -353,7 +352,7 @@ func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, u } // ListGroupSubscriptions 获取分组的所有订阅 -func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *pagination.PaginationResult, error) { +func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]UserSubscription, *pagination.PaginationResult, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} subs, pag, err := s.userSubRepo.ListByGroupID(ctx, groupID, params) if err != nil { @@ -364,7 +363,7 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI } // List 获取所有订阅(分页,支持筛选) -func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) { +func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status) if err != nil { @@ -376,7 +375,7 @@ func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, user // normalizeExpiredWindows 将已过期窗口的数据清零(仅影响返回数据,不影响数据库) // 这确保前端显示正确的当前窗口状态,而不是过期窗口的历史数据 -func normalizeExpiredWindows(subs []model.UserSubscription) { +func normalizeExpiredWindows(subs []UserSubscription) { for i := range subs { sub := &subs[i] // 日窗口过期:清零展示数据 @@ -403,7 +402,7 @@ func startOfDay(t time.Time) time.Time { } // CheckAndActivateWindow 检查并激活窗口(首次使用时) -func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *model.UserSubscription) error { +func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *UserSubscription) error { if sub.IsWindowActivated() { return nil } @@ -414,7 +413,7 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *m } // CheckAndResetWindows 检查并重置过期的窗口 -func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error { +func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error { // 使用当天零点作为新窗口起始时间 windowStart := startOfDay(time.Now()) needsInvalidateCache := false @@ -458,7 +457,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod } // CheckUsageLimits 检查使用限额(返回错误如果超限) -func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.UserSubscription, group *model.Group, additionalCost float64) error { +func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error { if !sub.CheckDailyLimit(group, additionalCost) { return ErrDailyLimitExceeded } @@ -620,16 +619,16 @@ func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (i } // ValidateSubscription 验证订阅是否有效 -func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *model.UserSubscription) error { - if sub.Status == model.SubscriptionStatusExpired { +func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *UserSubscription) error { + if sub.Status == SubscriptionStatusExpired { return ErrSubscriptionExpired } - if sub.Status == model.SubscriptionStatusSuspended { + if sub.Status == SubscriptionStatusSuspended { return ErrSubscriptionSuspended } if sub.IsExpired() { // 更新状态 - _ = s.userSubRepo.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired) + _ = s.userSubRepo.UpdateStatus(ctx, sub.ID, SubscriptionStatusExpired) return ErrSubscriptionExpired } return nil diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 24ef7b8e..f93d09ea 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -8,7 +8,6 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/model" ) // TokenRefreshService OAuth token自动刷新服务 @@ -142,19 +141,19 @@ func (s *TokenRefreshService) processRefresh() { // listActiveAccounts 获取所有active状态的账号 // 使用ListActive确保刷新所有活跃账号的token(包括临时禁用的) -func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]model.Account, error) { +func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]Account, error) { return s.accountRepo.ListActive(ctx) } // refreshWithRetry 带重试的刷新 -func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *model.Account, refresher TokenRefresher) error { +func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher) error { var lastErr error for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ { newCredentials, err := refresher.Refresh(ctx, account) if err == nil { // 刷新成功,更新账号credentials - account.Credentials = model.JSONB(newCredentials) + account.Credentials = newCredentials if err := s.accountRepo.Update(ctx, account); err != nil { return fmt.Errorf("failed to save credentials: %w", err) } diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 61a25aac..8857a416 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -4,22 +4,20 @@ import ( "context" "strconv" "time" - - "github.com/Wei-Shaw/sub2api/internal/model" ) // TokenRefresher 定义平台特定的token刷新策略接口 // 通过此接口可以扩展支持不同平台(Anthropic/OpenAI/Gemini) type TokenRefresher interface { // CanRefresh 检查此刷新器是否能处理指定账号 - CanRefresh(account *model.Account) bool + CanRefresh(account *Account) bool // NeedsRefresh 检查账号的token是否需要刷新 - NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool + NeedsRefresh(account *Account, refreshWindow time.Duration) bool // Refresh 执行token刷新,返回更新后的credentials // 注意:返回的map应该保留原有credentials中的所有字段,只更新token相关字段 - Refresh(ctx context.Context, account *model.Account) (map[string]any, error) + Refresh(ctx context.Context, account *Account) (map[string]any, error) } // ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新 @@ -37,14 +35,14 @@ func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher { // CanRefresh 检查是否能处理此账号 // 只处理 anthropic 平台的 oauth 类型账号 // setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新 -func (r *ClaudeTokenRefresher) CanRefresh(account *model.Account) bool { - return account.Platform == model.PlatformAnthropic && - account.Type == model.AccountTypeOAuth +func (r *ClaudeTokenRefresher) CanRefresh(account *Account) bool { + return account.Platform == PlatformAnthropic && + account.Type == AccountTypeOAuth } // NeedsRefresh 检查token是否需要刷新 // 基于 expires_at 字段判断是否在刷新窗口内 -func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool { +func (r *ClaudeTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { expiresAtStr := account.GetCredential("expires_at") if expiresAtStr == "" { return false @@ -61,7 +59,7 @@ func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindo // Refresh 执行token刷新 // 保留原有credentials中的所有字段,只更新token相关字段 -func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) { +func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) { tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account) if err != nil { return nil, err @@ -103,14 +101,14 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService) *OpenAIToke // CanRefresh 检查是否能处理此账号 // 只处理 openai 平台的 oauth 类型账号 -func (r *OpenAITokenRefresher) CanRefresh(account *model.Account) bool { - return account.Platform == model.PlatformOpenAI && - account.Type == model.AccountTypeOAuth +func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { + return account.Platform == PlatformOpenAI && + account.Type == AccountTypeOAuth } // NeedsRefresh 检查token是否需要刷新 // 基于 expires_at 字段判断是否在刷新窗口内 -func (r *OpenAITokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool { +func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { expiresAt := account.GetOpenAITokenExpiresAt() if expiresAt == nil { return false @@ -121,7 +119,7 @@ func (r *OpenAITokenRefresher) NeedsRefresh(account *model.Account, refreshWindo // Refresh 执行token刷新 // 保留原有credentials中的所有字段,只更新token相关字段 -func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) { +func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) { tokenInfo, err := r.openaiOAuthService.RefreshAccountToken(ctx, account) if err != nil { return nil, err diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go new file mode 100644 index 00000000..e822cd95 --- /dev/null +++ b/backend/internal/service/usage_log.go @@ -0,0 +1,53 @@ +package service + +import "time" + +const ( + BillingTypeBalance int8 = 0 // 钱包余额 + BillingTypeSubscription int8 = 1 // 订阅套餐 +) + +type UsageLog struct { + ID int64 + UserID int64 + ApiKeyID int64 + AccountID int64 + RequestID string + Model string + + GroupID *int64 + SubscriptionID *int64 + + InputTokens int + OutputTokens int + CacheCreationTokens int + CacheReadTokens int + + CacheCreation5mTokens int + CacheCreation1hTokens int + + InputCost float64 + OutputCost float64 + CacheCreationCost float64 + CacheReadCost float64 + TotalCost float64 + ActualCost float64 + RateMultiplier float64 + + BillingType int8 + Stream bool + DurationMs *int + FirstTokenMs *int + + CreatedAt time.Time + + User *User + ApiKey *ApiKey + Account *Account + Group *Group + Subscription *UserSubscription +} + +func (u *UsageLog) TotalTokens() int { + return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens +} diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index c574981d..2ccad4ff 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -6,7 +6,6 @@ import ( "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" ) @@ -66,7 +65,7 @@ func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository) *Usa } // Create 创建使用日志 -func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*model.UsageLog, error) { +func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*UsageLog, error) { // 验证用户存在 _, err := s.userRepo.GetByID(ctx, req.UserID) if err != nil { @@ -74,7 +73,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* } // 创建使用日志 - usageLog := &model.UsageLog{ + usageLog := &UsageLog{ UserID: req.UserID, ApiKeyID: req.ApiKeyID, AccountID: req.AccountID, @@ -112,7 +111,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* } // GetByID 根据ID获取使用日志 -func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) { +func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) { log, err := s.usageRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get usage log: %w", err) @@ -121,7 +120,7 @@ func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, } // ListByUser 获取用户的使用日志列表 -func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { +func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) { logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params) if err != nil { return nil, nil, fmt.Errorf("list usage logs: %w", err) @@ -130,7 +129,7 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi } // ListByApiKey 获取API Key的使用日志列表 -func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { +func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) { logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params) if err != nil { return nil, nil, fmt.Errorf("list usage logs: %w", err) @@ -139,7 +138,7 @@ func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params } // ListByAccount 获取账号的使用日志列表 -func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { +func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) { logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params) if err != nil { return nil, nil, fmt.Errorf("list usage logs: %w", err) @@ -243,7 +242,7 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int } // calculateStats 计算统计数据 -func (s *UsageService) calculateStats(logs []model.UsageLog) *UsageStats { +func (s *UsageService) calculateStats(logs []UsageLog) *UsageStats { stats := &UsageStats{} for _, log := range logs { @@ -313,7 +312,7 @@ func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs [ } // ListWithFilters lists usage logs with admin filters. -func (s *UsageService) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) { +func (s *UsageService) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error) { logs, result, err := s.usageRepo.ListWithFilters(ctx, params, filters) if err != nil { return nil, nil, fmt.Errorf("list usage logs with filters: %w", err) diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go new file mode 100644 index 00000000..70995b5d --- /dev/null +++ b/backend/internal/service/user.go @@ -0,0 +1,63 @@ +package service + +import ( + "time" + + "golang.org/x/crypto/bcrypt" +) + +type User struct { + ID int64 + Email string + Username string + Wechat string + Notes string + PasswordHash string + Role string + Balance float64 + Concurrency int + Status string + AllowedGroups []int64 + CreatedAt time.Time + UpdatedAt time.Time + + ApiKeys []ApiKey + Subscriptions []UserSubscription +} + +func (u *User) IsAdmin() bool { + return u.Role == RoleAdmin +} + +func (u *User) IsActive() bool { + return u.Status == StatusActive +} + +// CanBindGroup checks whether a user can bind to a given group. +// For standard groups: +// - If AllowedGroups is non-empty, only allow binding to IDs in that list. +// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group. +func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool { + if len(u.AllowedGroups) > 0 { + for _, id := range u.AllowedGroups { + if id == groupID { + return true + } + } + return false + } + return !isExclusive +} + +func (u *User) SetPassword(password string) error { + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return err + } + u.PasswordHash = string(hash) + return nil +} + +func (u *User) CheckPassword(password string) bool { + return bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) == nil +} diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 3830bc67..3ff47e7d 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -5,9 +5,7 @@ import ( "fmt" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "golang.org/x/crypto/bcrypt" ) var ( @@ -17,15 +15,15 @@ var ( ) type UserRepository interface { - Create(ctx context.Context, user *model.User) error - GetByID(ctx context.Context, id int64) (*model.User, error) - GetByEmail(ctx context.Context, email string) (*model.User, error) - GetFirstAdmin(ctx context.Context) (*model.User, error) - Update(ctx context.Context, user *model.User) error + Create(ctx context.Context, user *User) error + GetByID(ctx context.Context, id int64) (*User, error) + GetByEmail(ctx context.Context, email string) (*User, error) + GetFirstAdmin(ctx context.Context) (*User, error) + Update(ctx context.Context, user *User) error Delete(ctx context.Context, id int64) error - List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) + List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]User, *pagination.PaginationResult, error) UpdateBalance(ctx context.Context, id int64, amount float64) error DeductBalance(ctx context.Context, id int64, amount float64) error @@ -61,7 +59,7 @@ func NewUserService(userRepo UserRepository) *UserService { } // GetFirstAdmin 获取首个管理员用户(用于 Admin API Key 认证) -func (s *UserService) GetFirstAdmin(ctx context.Context) (*model.User, error) { +func (s *UserService) GetFirstAdmin(ctx context.Context) (*User, error) { admin, err := s.userRepo.GetFirstAdmin(ctx) if err != nil { return nil, fmt.Errorf("get first admin: %w", err) @@ -70,7 +68,7 @@ func (s *UserService) GetFirstAdmin(ctx context.Context) (*model.User, error) { } // GetProfile 获取用户资料 -func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) { +func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, error) { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { return nil, fmt.Errorf("get user: %w", err) @@ -79,7 +77,7 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User } // UpdateProfile 更新用户资料 -func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) { +func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { return nil, fmt.Errorf("get user: %w", err) @@ -125,18 +123,14 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan } // 验证当前密码 - if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil { + if !user.CheckPassword(req.CurrentPassword) { return ErrPasswordIncorrect } - // 生成新密码哈希 - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost) - if err != nil { - return fmt.Errorf("hash password: %w", err) + if err := user.SetPassword(req.NewPassword); err != nil { + return fmt.Errorf("set password: %w", err) } - user.PasswordHash = string(hashedPassword) - if err := s.userRepo.Update(ctx, user); err != nil { return fmt.Errorf("update user: %w", err) } @@ -145,7 +139,7 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan } // GetByID 根据ID获取用户(管理员功能) -func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) { +func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) { user, err := s.userRepo.GetByID(ctx, id) if err != nil { return nil, fmt.Errorf("get user: %w", err) @@ -154,7 +148,7 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error } // List 获取用户列表(管理员功能) -func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) { +func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { users, pagination, err := s.userRepo.List(ctx, params) if err != nil { return nil, nil, fmt.Errorf("list users: %w", err) diff --git a/backend/internal/service/user_subscription.go b/backend/internal/service/user_subscription.go new file mode 100644 index 00000000..ec547d81 --- /dev/null +++ b/backend/internal/service/user_subscription.go @@ -0,0 +1,124 @@ +package service + +import "time" + +type UserSubscription struct { + ID int64 + UserID int64 + GroupID int64 + + StartsAt time.Time + ExpiresAt time.Time + Status string + + DailyWindowStart *time.Time + WeeklyWindowStart *time.Time + MonthlyWindowStart *time.Time + + DailyUsageUSD float64 + WeeklyUsageUSD float64 + MonthlyUsageUSD float64 + + AssignedBy *int64 + AssignedAt time.Time + Notes string + + CreatedAt time.Time + UpdatedAt time.Time + + User *User + Group *Group + AssignedByUser *User +} + +func (s *UserSubscription) IsActive() bool { + return s.Status == SubscriptionStatusActive && time.Now().Before(s.ExpiresAt) +} + +func (s *UserSubscription) IsExpired() bool { + return time.Now().After(s.ExpiresAt) +} + +func (s *UserSubscription) DaysRemaining() int { + if s.IsExpired() { + return 0 + } + return int(time.Until(s.ExpiresAt).Hours() / 24) +} + +func (s *UserSubscription) IsWindowActivated() bool { + return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil +} + +func (s *UserSubscription) NeedsDailyReset() bool { + if s.DailyWindowStart == nil { + return false + } + return time.Since(*s.DailyWindowStart) >= 24*time.Hour +} + +func (s *UserSubscription) NeedsWeeklyReset() bool { + if s.WeeklyWindowStart == nil { + return false + } + return time.Since(*s.WeeklyWindowStart) >= 7*24*time.Hour +} + +func (s *UserSubscription) NeedsMonthlyReset() bool { + if s.MonthlyWindowStart == nil { + return false + } + return time.Since(*s.MonthlyWindowStart) >= 30*24*time.Hour +} + +func (s *UserSubscription) DailyResetTime() *time.Time { + if s.DailyWindowStart == nil { + return nil + } + t := s.DailyWindowStart.Add(24 * time.Hour) + return &t +} + +func (s *UserSubscription) WeeklyResetTime() *time.Time { + if s.WeeklyWindowStart == nil { + return nil + } + t := s.WeeklyWindowStart.Add(7 * 24 * time.Hour) + return &t +} + +func (s *UserSubscription) MonthlyResetTime() *time.Time { + if s.MonthlyWindowStart == nil { + return nil + } + t := s.MonthlyWindowStart.Add(30 * 24 * time.Hour) + return &t +} + +func (s *UserSubscription) CheckDailyLimit(group *Group, additionalCost float64) bool { + if !group.HasDailyLimit() { + return true + } + return s.DailyUsageUSD+additionalCost <= *group.DailyLimitUSD +} + +func (s *UserSubscription) CheckWeeklyLimit(group *Group, additionalCost float64) bool { + if !group.HasWeeklyLimit() { + return true + } + return s.WeeklyUsageUSD+additionalCost <= *group.WeeklyLimitUSD +} + +func (s *UserSubscription) CheckMonthlyLimit(group *Group, additionalCost float64) bool { + if !group.HasMonthlyLimit() { + return true + } + return s.MonthlyUsageUSD+additionalCost <= *group.MonthlyLimitUSD +} + +func (s *UserSubscription) CheckAllLimits(group *Group, additionalCost float64) (daily, weekly, monthly bool) { + daily = s.CheckDailyLimit(group, additionalCost) + weekly = s.CheckWeeklyLimit(group, additionalCost) + monthly = s.CheckMonthlyLimit(group, additionalCost) + return +} diff --git a/backend/internal/service/user_subscription_port.go b/backend/internal/service/user_subscription_port.go index 615a4501..abf4dffd 100644 --- a/backend/internal/service/user_subscription_port.go +++ b/backend/internal/service/user_subscription_port.go @@ -4,22 +4,21 @@ import ( "context" "time" - "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) type UserSubscriptionRepository interface { - Create(ctx context.Context, sub *model.UserSubscription) error - GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) - GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) - GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) - Update(ctx context.Context, sub *model.UserSubscription) error + Create(ctx context.Context, sub *UserSubscription) error + GetByID(ctx context.Context, id int64) (*UserSubscription, error) + GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*UserSubscription, error) + GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*UserSubscription, error) + Update(ctx context.Context, sub *UserSubscription) error Delete(ctx context.Context, id int64) error - ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) - ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) - ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) - List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) + ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) + ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) + ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) + List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go index edb426b9..387077bb 100644 --- a/backend/internal/setup/setup.go +++ b/backend/internal/setup/setup.go @@ -10,10 +10,10 @@ import ( "strconv" "time" - "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" - "golang.org/x/crypto/bcrypt" "gopkg.in/yaml.v3" "gorm.io/driver/postgres" "gorm.io/gorm" @@ -271,8 +271,7 @@ func initializeDatabase(cfg *SetupConfig) error { } }() - // 使用 model 包的 AutoMigrate,确保模型定义统一 - return model.AutoMigrate(db) + return repository.AutoMigrate(db) } func createAdminUser(cfg *SetupConfig) error { @@ -299,29 +298,28 @@ func createAdminUser(cfg *SetupConfig) error { // Check if admin already exists var count int64 - db.Model(&model.User{}).Where("role = ?", "admin").Count(&count) + if err := db.Table("users").Where("role = ?", service.RoleAdmin).Count(&count).Error; err != nil { + return err + } if count > 0 { return nil // Admin already exists } - // Hash password - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(cfg.Admin.Password), bcrypt.DefaultCost) - if err != nil { + admin := &service.User{ + Email: cfg.Admin.Email, + Role: service.RoleAdmin, + Status: service.StatusActive, + Balance: 0, + Concurrency: 5, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := admin.SetPassword(cfg.Admin.Password); err != nil { return err } - // Create admin user - admin := &model.User{ - Email: cfg.Admin.Email, - PasswordHash: string(hashedPassword), - Role: model.RoleAdmin, - Status: model.StatusActive, - Balance: 0, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - - return db.Create(admin).Error + return repository.NewUserRepository(db).Create(context.Background(), admin) } func writeConfigFile(cfg *SetupConfig) error {