refactor: 调整项目结构为单向依赖

This commit is contained in:
Forest
2025-12-26 15:40:24 +08:00
parent b3463769dc
commit e5a77853b0
98 changed files with 5503 additions and 3352 deletions

View File

@@ -48,8 +48,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
authHandler := handler.NewAuthHandler(authService)
userService := service.NewUserService(userRepository)
authHandler := handler.NewAuthHandler(authService, userService)
userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(db)
groupRepository := repository.NewGroupRepository(db)

View File

@@ -22,12 +22,14 @@ require (
golang.org/x/net v0.47.0
golang.org/x/term v0.37.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/datatypes v1.2.0
gorm.io/driver/postgres v1.5.4
gorm.io/gorm v1.25.5
)
require (
dario.cat/mergo v1.0.2 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
@@ -57,6 +59,7 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
@@ -64,8 +67,8 @@ require (
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/icholy/digest v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.4 // indirect
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
@@ -132,4 +135,5 @@ require (
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gorm.io/driver/mysql v1.5.2 // indirect
)

View File

@@ -1,5 +1,7 @@
dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
@@ -77,10 +79,17 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@@ -104,10 +113,10 @@ github.com/imroc/req/v3 v3.56.0 h1:t6YdqqerYBXhZ9+VjqsQs5wlKxdUNEvsgBhxWc1AEEo=
github.com/imroc/req/v3 v3.56.0/go.mod h1:cUZSooE8hhzFNOrAbdxuemXDQxFXLQTnu3066jr7ZGk=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8=
github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
@@ -135,8 +144,12 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLgZiaenE=
github.com/microsoft/go-mssqldb v0.17.0/go.mod h1:OkoNGhGEs8EZqchVTtochlXruEhEOaO4S0d2sB5aeGQ=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
@@ -319,8 +332,17 @@ gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/datatypes v1.2.0 h1:5YT+eokWdIxhJgWHdrb2zYUimyk0+TaFth+7a0ybzco=
gorm.io/datatypes v1.2.0/go.mod h1:o1dh0ZvjIjhH/bngTpypG6lVRJ5chTBxE09FH/71k04=
gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs=
gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8=
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/driver/sqlserver v1.4.1 h1:t4r4r6Jam5E6ejqP7N82qAJIJAht27EGT41HyPfXRw0=
gorm.io/driver/sqlserver v1.4.1/go.mod h1:DJ4P+MeZbc5rvY58PnmN1Lnyvb5gw5NPzGshHDnJLig=
gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=

View File

@@ -3,7 +3,7 @@ package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -102,7 +102,7 @@ type BulkUpdateAccountsRequest struct {
// AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct {
*model.Account
*dto.Account
CurrentConcurrency int `json:"current_concurrency"`
}
@@ -137,7 +137,7 @@ func (h *AccountHandler) List(c *gin.Context) {
result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts {
result[i] = AccountWithConcurrency{
Account: &accounts[i],
Account: dto.AccountFromService(&accounts[i]),
CurrentConcurrency: concurrencyCounts[accounts[i].ID],
}
}
@@ -160,7 +160,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// Create handles creating a new account
@@ -188,7 +188,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// Update handles updating an account
@@ -222,7 +222,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// Delete handles deleting an account
@@ -425,7 +425,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// BatchCreate handles batch creating accounts
@@ -801,7 +801,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
return
}
response.Success(c, account)
response.Success(c, dto.AccountFromService(account))
}
// GetAvailableModels handles getting available models for an account

View File

@@ -1,11 +1,12 @@
package admin
import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"strconv"
"time"
"github.com/gin-gonic/gin"
)

View File

@@ -3,7 +3,7 @@ package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -69,7 +69,11 @@ func (h *GroupHandler) List(c *gin.Context) {
return
}
response.Paginated(c, groups, total, page, pageSize)
outGroups := make([]dto.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
}
response.Paginated(c, outGroups, total, page, pageSize)
}
// GetAll handles getting all active groups without pagination
@@ -77,7 +81,7 @@ func (h *GroupHandler) List(c *gin.Context) {
func (h *GroupHandler) GetAll(c *gin.Context) {
platform := c.Query("platform")
var groups []model.Group
var groups []service.Group
var err error
if platform != "" {
@@ -91,7 +95,11 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
return
}
response.Success(c, groups)
outGroups := make([]dto.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
}
response.Success(c, outGroups)
}
// GetByID handles getting a group by ID
@@ -109,7 +117,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, group)
response.Success(c, dto.GroupFromService(group))
}
// Create handles creating a new group
@@ -137,7 +145,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
return
}
response.Success(c, group)
response.Success(c, dto.GroupFromService(group))
}
// Update handles updating a group
@@ -172,7 +180,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
return
}
response.Success(c, group)
response.Success(c, dto.GroupFromService(group))
}
// Delete handles deleting a group
@@ -229,5 +237,9 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
return
}
response.Paginated(c, keys, total, page, pageSize)
outKeys := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, outKeys, total, page, pageSize)
}

View File

@@ -1,7 +1,7 @@
package admin
import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -31,7 +31,28 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
return
}
response.Success(c, settings)
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SmtpHost: settings.SmtpHost,
SmtpPort: settings.SmtpPort,
SmtpUsername: settings.SmtpUsername,
SmtpPassword: settings.SmtpPassword,
SmtpFrom: settings.SmtpFrom,
SmtpFromName: settings.SmtpFromName,
SmtpUseTLS: settings.SmtpUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
})
}
// UpdateSettingsRequest 更新设置请求
@@ -87,7 +108,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SmtpPort = 587
}
settings := &model.SystemSettings{
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
SmtpHost: req.SmtpHost,
@@ -122,7 +143,28 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
response.Success(c, updatedSettings)
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SmtpHost: updatedSettings.SmtpHost,
SmtpPort: updatedSettings.SmtpPort,
SmtpUsername: updatedSettings.SmtpUsername,
SmtpPassword: updatedSettings.SmtpPassword,
SmtpFrom: updatedSettings.SmtpFrom,
SmtpFromName: updatedSettings.SmtpFromName,
SmtpUseTLS: updatedSettings.SmtpUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
ApiBaseUrl: updatedSettings.ApiBaseUrl,
ContactInfo: updatedSettings.ContactInfo,
DocUrl: updatedSettings.DocUrl,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
})
}
// TestSmtpRequest 测试SMTP连接请求

View File

@@ -3,9 +3,10 @@ package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -82,7 +83,11 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
return
}
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
// GetByID handles getting a subscription by ID
@@ -100,7 +105,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, subscription)
response.Success(c, dto.UserSubscriptionFromService(subscription))
}
// GetProgress handles getting subscription usage progress
@@ -145,7 +150,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
return
}
response.Success(c, subscription)
response.Success(c, dto.UserSubscriptionFromService(subscription))
}
// BulkAssign handles bulk assigning subscriptions to multiple users
@@ -196,7 +201,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return
}
response.Success(c, subscription)
response.Success(c, dto.UserSubscriptionFromService(subscription))
}
// Revoke handles revoking a subscription
@@ -234,7 +239,11 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
return
}
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
}
// ListByUser handles listing subscriptions for a specific user
@@ -252,15 +261,18 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
return
}
response.Success(c, subscriptions)
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
}
// Helper function to get admin ID from context
func getAdminIDFromContext(c *gin.Context) int64 {
if user, exists := c.Get("user"); exists {
if u, ok := user.(*model.User); ok && u != nil {
return u.ID
}
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
return 0
}
return 0
return subject.UserID
}

View File

@@ -3,9 +3,10 @@ package handler
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -40,42 +41,34 @@ type UpdateAPIKeyRequest struct {
// List handles listing user's API keys with pagination
// GET /api/v1/api-keys
func (h *APIKeyHandler) List(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, keys, result.Total, page, pageSize)
out := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
}
// GetByID handles getting a single API key
// GET /api/v1/api-keys/:id
func (h *APIKeyHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -92,26 +85,20 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
}
// 验证所有权
if key.UserID != user.ID {
if key.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this key")
return
}
response.Success(c, key)
response.Success(c, dto.ApiKeyFromService(key))
}
// Create handles creating a new API key
// POST /api/v1/api-keys
func (h *APIKeyHandler) Create(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -126,27 +113,21 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
GroupID: req.GroupID,
CustomKey: req.CustomKey,
}
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq)
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, key)
response.Success(c, dto.ApiKeyFromService(key))
}
// Update handles updating an API key
// PUT /api/v1/api-keys/:id
func (h *APIKeyHandler) Update(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -171,27 +152,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
svcReq.Status = &req.Status
}
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq)
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, key)
response.Success(c, dto.ApiKeyFromService(key))
}
// Delete handles deleting an API key
// DELETE /api/v1/api-keys/:id
func (h *APIKeyHandler) Delete(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -201,7 +176,7 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
return
}
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID)
err = h.apiKeyService.Delete(c.Request.Context(), keyID, subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -213,23 +188,21 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
// GetAvailableGroups 获取用户可以绑定的分组列表
// GET /api/v1/groups/available
func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, groups)
out := make([]dto.Group, 0, len(groups))
for i := range groups {
out = append(out, *dto.GroupFromService(&groups[i]))
}
response.Success(c, out)
}

View File

@@ -1,8 +1,9 @@
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -11,12 +12,14 @@ import (
// AuthHandler handles authentication-related requests
type AuthHandler struct {
authService *service.AuthService
userService *service.UserService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
func NewAuthHandler(authService *service.AuthService, userService *service.UserService) *AuthHandler {
return &AuthHandler{
authService: authService,
userService: userService,
}
}
@@ -49,9 +52,9 @@ type LoginRequest struct {
// AuthResponse 认证响应格式(匹配前端期望)
type AuthResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
User *model.User `json:"user"`
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
User *dto.User `json:"user"`
}
// Register handles user registration
@@ -80,7 +83,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: user,
User: dto.UserFromService(user),
})
}
@@ -135,24 +138,24 @@ func (h *AuthHandler) Login(c *gin.Context) {
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: user,
User: dto.UserFromService(user),
})
}
// GetCurrentUser handles getting current authenticated user
// GET /api/v1/auth/me
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
user, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, user)
response.Success(c, dto.UserFromService(user))
}

View File

@@ -0,0 +1,294 @@
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
func UserFromServiceShallow(u *service.User) *User {
if u == nil {
return nil
}
return &User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: u.AllowedGroups,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func UserFromService(u *service.User) *User {
if u == nil {
return nil
}
out := UserFromServiceShallow(u)
if len(u.ApiKeys) > 0 {
out.ApiKeys = make([]ApiKey, 0, len(u.ApiKeys))
for i := range u.ApiKeys {
k := u.ApiKeys[i]
out.ApiKeys = append(out.ApiKeys, *ApiKeyFromService(&k))
}
}
if len(u.Subscriptions) > 0 {
out.Subscriptions = make([]UserSubscription, 0, len(u.Subscriptions))
for i := range u.Subscriptions {
s := u.Subscriptions[i]
out.Subscriptions = append(out.Subscriptions, *UserSubscriptionFromService(&s))
}
}
return out
}
func ApiKeyFromService(k *service.ApiKey) *ApiKey {
if k == nil {
return nil
}
return &ApiKey{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User),
Group: GroupFromServiceShallow(k.Group),
}
}
func GroupFromServiceShallow(g *service.Group) *Group {
if g == nil {
return nil
}
return &Group{
ID: g.ID,
Name: g.Name,
Description: g.Description,
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
AccountCount: g.AccountCount,
}
}
func GroupFromService(g *service.Group) *Group {
if g == nil {
return nil
}
out := GroupFromServiceShallow(g)
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
for i := range g.AccountGroups {
ag := g.AccountGroups[i]
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
}
}
return out
}
func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
}
return &Account{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: a.Credentials,
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs,
}
}
func AccountFromService(a *service.Account) *Account {
if a == nil {
return nil
}
out := AccountFromServiceShallow(a)
out.Proxy = ProxyFromService(a.Proxy)
if len(a.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(a.AccountGroups))
for i := range a.AccountGroups {
ag := a.AccountGroups[i]
out.AccountGroups = append(out.AccountGroups, *AccountGroupFromService(&ag))
}
}
if len(a.Groups) > 0 {
out.Groups = make([]*Group, 0, len(a.Groups))
for _, g := range a.Groups {
out.Groups = append(out.Groups, GroupFromServiceShallow(g))
}
}
return out
}
func AccountGroupFromService(ag *service.AccountGroup) *AccountGroup {
if ag == nil {
return nil
}
return &AccountGroup{
AccountID: ag.AccountID,
GroupID: ag.GroupID,
Priority: ag.Priority,
CreatedAt: ag.CreatedAt,
Account: AccountFromServiceShallow(ag.Account),
Group: GroupFromServiceShallow(ag.Group),
}
}
func ProxyFromService(p *service.Proxy) *Proxy {
if p == nil {
return nil
}
return &Proxy{
ID: p.ID,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWithAccountCount {
if p == nil {
return nil
}
return &ProxyWithAccountCount{
Proxy: *ProxyFromService(&p.Proxy),
AccountCount: p.AccountCount,
}
}
func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
if rc == nil {
return nil
}
return &RedeemCode{
ID: rc.ID,
Code: rc.Code,
Type: rc.Type,
Value: rc.Value,
Status: rc.Status,
UsedBy: rc.UsedBy,
UsedAt: rc.UsedAt,
Notes: rc.Notes,
CreatedAt: rc.CreatedAt,
GroupID: rc.GroupID,
ValidityDays: rc.ValidityDays,
User: UserFromServiceShallow(rc.User),
Group: GroupFromServiceShallow(rc.Group),
}
}
func UsageLogFromService(l *service.UsageLog) *UsageLog {
if l == nil {
return nil
}
return &UsageLog{
ID: l.ID,
UserID: l.UserID,
ApiKeyID: l.ApiKeyID,
AccountID: l.AccountID,
RequestID: l.RequestID,
Model: l.Model,
GroupID: l.GroupID,
SubscriptionID: l.SubscriptionID,
InputTokens: l.InputTokens,
OutputTokens: l.OutputTokens,
CacheCreationTokens: l.CacheCreationTokens,
CacheReadTokens: l.CacheReadTokens,
CacheCreation5mTokens: l.CacheCreation5mTokens,
CacheCreation1hTokens: l.CacheCreation1hTokens,
InputCost: l.InputCost,
OutputCost: l.OutputCost,
CacheCreationCost: l.CacheCreationCost,
CacheReadCost: l.CacheReadCost,
TotalCost: l.TotalCost,
ActualCost: l.ActualCost,
RateMultiplier: l.RateMultiplier,
BillingType: l.BillingType,
Stream: l.Stream,
DurationMs: l.DurationMs,
FirstTokenMs: l.FirstTokenMs,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
ApiKey: ApiKeyFromService(l.ApiKey),
Account: AccountFromService(l.Account),
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),
}
}
func SettingFromService(s *service.Setting) *Setting {
if s == nil {
return nil
}
return &Setting{
ID: s.ID,
Key: s.Key,
Value: s.Value,
UpdatedAt: s.UpdatedAt,
}
}
func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscription {
if sub == nil {
return nil
}
return &UserSubscription{
ID: sub.ID,
UserID: sub.UserID,
GroupID: sub.GroupID,
StartsAt: sub.StartsAt,
ExpiresAt: sub.ExpiresAt,
Status: sub.Status,
DailyWindowStart: sub.DailyWindowStart,
WeeklyWindowStart: sub.WeeklyWindowStart,
MonthlyWindowStart: sub.MonthlyWindowStart,
DailyUsageUSD: sub.DailyUsageUSD,
WeeklyUsageUSD: sub.WeeklyUsageUSD,
MonthlyUsageUSD: sub.MonthlyUsageUSD,
AssignedBy: sub.AssignedBy,
AssignedAt: sub.AssignedAt,
Notes: sub.Notes,
CreatedAt: sub.CreatedAt,
UpdatedAt: sub.UpdatedAt,
User: UserFromServiceShallow(sub.User),
Group: GroupFromServiceShallow(sub.Group),
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
}
}

View File

@@ -0,0 +1,43 @@
package dto
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
Version string `json:"version"`
}

View File

@@ -0,0 +1,212 @@
package dto
import "time"
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
Username string `json:"username"`
Wechat string `json:"wechat"`
Notes string `json:"notes"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
Status string `json:"status"`
AllowedGroups []int64 `json:"allowed_groups"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ApiKeys []ApiKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
}
type ApiKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Key string `json:"key"`
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
type Group struct {
ID int64 `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
Status string `json:"status"`
SubscriptionType string `json:"subscription_type"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
AccountCount int64 `json:"account_count,omitempty"`
}
type Account struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Schedulable bool `json:"schedulable"`
RateLimitedAt *time.Time `json:"rate_limited_at"`
RateLimitResetAt *time.Time `json:"rate_limit_reset_at"`
OverloadUntil *time.Time `json:"overload_until"`
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `json:"session_window_status"`
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
GroupIDs []int64 `json:"group_ids,omitempty"`
Groups []*Group `json:"groups,omitempty"`
}
type AccountGroup struct {
AccountID int64 `json:"account_id"`
GroupID int64 `json:"group_id"`
Priority int `json:"priority"`
CreatedAt time.Time `json:"created_at"`
Account *Account `json:"account,omitempty"`
Group *Group `json:"group,omitempty"`
}
type Proxy struct {
ID int64 `json:"id"`
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"-"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
}
type RedeemCode struct {
ID int64 `json:"id"`
Code string `json:"code"`
Type string `json:"type"`
Value float64 `json:"value"`
Status string `json:"status"`
UsedBy *int64 `json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
type UsageLog struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
GroupID *int64 `json:"group_id"`
SubscriptionID *int64 `json:"subscription_id"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationTokens int `json:"cache_creation_tokens"`
CacheReadTokens int `json:"cache_read_tokens"`
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
BillingType int8 `json:"billing_type"`
Stream bool `json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
ApiKey *ApiKey `json:"api_key,omitempty"`
Account *Account `json:"account,omitempty"`
Group *Group `json:"group,omitempty"`
Subscription *UserSubscription `json:"subscription,omitempty"`
}
type Setting struct {
ID int64 `json:"id"`
Key string `json:"key"`
Value string `json:"value"`
UpdatedAt time.Time `json:"updated_at"`
}
type UserSubscription struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
GroupID int64 `json:"group_id"`
StartsAt time.Time `json:"starts_at"`
ExpiresAt time.Time `json:"expires_at"`
Status string `json:"status"`
DailyWindowStart *time.Time `json:"daily_window_start"`
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
DailyUsageUSD float64 `json:"daily_usage_usd"`
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
AssignedBy *int64 `json:"assigned_by"`
AssignedAt time.Time `json:"assigned_at"`
Notes string `json:"notes"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
AssignedByUser *User `json:"assigned_by_user,omitempty"`
}

View File

@@ -10,7 +10,6 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -47,7 +46,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
user, ok := middleware2.GetUserFromContext(c)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
@@ -82,8 +81,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
@@ -92,10 +91,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 确保在函数退出时减少wait计数
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
@@ -106,7 +105,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return
@@ -133,7 +132,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
@@ -158,7 +157,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: user,
User: apiKey.User,
Account: account,
Subscription: subscription,
}); err != nil {
@@ -198,7 +197,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
return
}
user, ok := middleware2.GetUserFromContext(c)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
@@ -223,7 +222,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
}
// 余额模式:返回钱包余额
latestUser, err := h.userService.GetByID(c.Request.Context(), user.ID)
latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return
@@ -241,7 +240,7 @@ func (h *GatewayHandler) Usage(c *gin.Context) {
// 逻辑:
// 1. 如果日/周/月任一限额达到100%返回0
// 2. 否则返回所有已配置周期中剩余额度的最小值
func (h *GatewayHandler) calculateSubscriptionRemaining(group *model.Group, sub *model.UserSubscription) float64 {
func (h *GatewayHandler) calculateSubscriptionRemaining(group *service.Group, sub *service.UserSubscription) float64 {
var remainingValues []float64
// 检查日限额
@@ -334,7 +333,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
user, ok := middleware2.GetUserFromContext(c)
_, ok = middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
@@ -366,7 +365,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
h.errorResponse(c, http.StatusForbidden, "billing_error", err.Error())
return
}

View File

@@ -6,7 +6,6 @@ import (
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -69,11 +68,11 @@ func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return nil, err
}
@@ -83,17 +82,17 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
return h.waitForSlotWithPing(c, "user", userID, maxConcurrency, isStream, streamStarted)
}
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, err
}
@@ -103,7 +102,7 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
return h.waitForSlotWithPing(c, "account", accountID, maxConcurrency, isStream, streamStarted)
}
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.

View File

@@ -46,7 +46,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
user, ok := middleware2.GetUserFromContext(c)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
@@ -94,8 +94,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
@@ -104,10 +104,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
// Ensure wait count is decremented when function exits
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), user.ID)
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
// 1. First acquire user concurrency slot
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, user, reqStream, &streamStarted)
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
@@ -118,7 +118,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return
@@ -138,7 +138,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name)
// 3. Acquire account concurrency slot
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account, reqStream, &streamStarted)
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWait(c, account.ID, account.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
@@ -163,7 +163,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
ApiKey: apiKey,
User: user,
User: apiKey.User,
Account: account,
Subscription: subscription,
}); err != nil {

View File

@@ -1,8 +1,9 @@
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -37,15 +38,9 @@ type RedeemResponse struct {
// Redeem handles redeeming a code
// POST /api/v1/redeem
func (h *RedeemHandler) Redeem(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -55,38 +50,36 @@ func (h *RedeemHandler) Redeem(c *gin.Context) {
return
}
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code)
result, err := h.redeemService.Redeem(c.Request.Context(), subject.UserID, req.Code)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
response.Success(c, dto.RedeemCodeFromService(result))
}
// GetHistory returns the user's redemption history
// GET /api/v1/redeem/history
func (h *RedeemHandler) GetHistory(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
// Default limit is 25
limit := 25
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit)
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), subject.UserID, limit)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, codes)
out := make([]dto.RedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
}
response.Success(c, out)
}

View File

@@ -1,6 +1,7 @@
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -30,6 +31,17 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
return
}
settings.Version = h.version
response.Success(c, settings)
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocUrl: settings.DocUrl,
Version: h.version,
})
}

View File

@@ -1,8 +1,9 @@
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -25,7 +26,7 @@ type SubscriptionSummaryItem struct {
// SubscriptionProgressInfo represents subscription with progress info
type SubscriptionProgressInfo struct {
Subscription *model.UserSubscription `json:"subscription"`
Subscription *dto.UserSubscription `json:"subscription"`
Progress *service.SubscriptionProgress `json:"progress"`
}
@@ -44,68 +45,58 @@ func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *S
// List handles listing current user's subscriptions
// GET /api/v1/subscriptions
func (h *SubscriptionHandler) List(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, subscriptions)
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
}
// GetActive handles getting current user's active subscriptions
// GET /api/v1/subscriptions/active
func (h *SubscriptionHandler) GetActive(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, subscriptions)
out := make([]dto.UserSubscription, 0, len(subscriptions))
for i := range subscriptions {
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
}
response.Success(c, out)
}
// GetProgress handles getting subscription progress for current user
// GET /api/v1/subscriptions/progress
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions with progress
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -120,7 +111,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
continue
}
result = append(result, SubscriptionProgressInfo{
Subscription: sub,
Subscription: dto.UserSubscriptionFromService(sub),
Progress: progress,
})
}
@@ -131,20 +122,14 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
// GetSummary handles getting a summary of current user's subscription status
// GET /api/v1/subscriptions/summary
func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -4,10 +4,11 @@ import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -30,15 +31,9 @@ func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.
// List handles listing usage records with pagination
// GET /api/v1/usage
func (h *UsageHandler) List(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -58,7 +53,7 @@ func (h *UsageHandler) List(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if apiKey.UserID != user.ID {
if apiKey.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this API key's usage records")
return
}
@@ -67,35 +62,33 @@ func (h *UsageHandler) List(c *gin.Context) {
}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
var records []model.UsageLog
var records []service.UsageLog
var result *pagination.PaginationResult
var err error
if apiKeyID > 0 {
records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params)
} else {
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params)
records, result, err = h.usageService.ListByUser(c.Request.Context(), subject.UserID, params)
}
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, records, result.Total, page, pageSize)
out := make([]dto.UsageLog, 0, len(records))
for i := range records {
out = append(out, *dto.UsageLogFromService(&records[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
}
// GetByID handles getting a single usage record
// GET /api/v1/usage/:id
func (h *UsageHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -112,26 +105,20 @@ func (h *UsageHandler) GetByID(c *gin.Context) {
}
// 验证所有权
if record.UserID != user.ID {
if record.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this record")
return
}
response.Success(c, record)
response.Success(c, dto.UsageLogFromService(record))
}
// Stats handles getting usage statistics
// GET /api/v1/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -149,7 +136,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
response.NotFound(c, "API key not found")
return
}
if apiKey.UserID != user.ID {
if apiKey.UserID != subject.UserID {
response.Forbidden(c, "Not authorized to access this API key's statistics")
return
}
@@ -201,7 +188,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if apiKeyID > 0 {
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
} else {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime)
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
}
if err != nil {
response.ErrorFrom(c, err)
@@ -245,19 +232,13 @@ func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
// DashboardStats handles getting user dashboard statistics
// GET /api/v1/usage/dashboard/stats
func (h *UsageHandler) DashboardStats(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -269,22 +250,16 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
// DashboardTrend handles getting user usage trend data
// GET /api/v1/usage/dashboard/trend
func (h *UsageHandler) DashboardTrend(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
startTime, endTime := parseUserTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), subject.UserID, startTime, endTime, granularity)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -301,21 +276,15 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
// DashboardModels handles getting user model usage statistics
// GET /api/v1/usage/dashboard/models
func (h *UsageHandler) DashboardModels(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
startTime, endTime := parseUserTimeRange(c)
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), subject.UserID, startTime, endTime)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -336,15 +305,9 @@ type BatchApiKeysUsageRequest struct {
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
// POST /api/v1/usage/dashboard/api-keys-usage
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -360,7 +323,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
}
// Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, pagination.PaginationParams{Page: 1, PageSize: 1000})
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -1,8 +1,9 @@
package handler
import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -35,19 +36,13 @@ type UpdateProfileRequest struct {
// GetProfile handles getting user profile
// GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -56,21 +51,15 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
// 清空notes字段普通用户不应看到备注
userData.Notes = ""
response.Success(c, userData)
response.Success(c, dto.UserFromService(userData))
}
// ChangePassword handles changing user password
// POST /api/v1/users/me/password
func (h *UserHandler) ChangePassword(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -84,7 +73,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
CurrentPassword: req.OldPassword,
NewPassword: req.NewPassword,
}
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq)
err := h.userService.ChangePassword(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -96,15 +85,9 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
// UpdateProfile handles updating user profile
// PUT /api/v1/users/me
func (h *UserHandler) UpdateProfile(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.InternalError(c, "Invalid user context")
response.Unauthorized(c, "User not authenticated")
return
}
@@ -118,7 +101,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
Username: req.Username,
Wechat: req.Wechat,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq)
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -127,5 +110,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
// 清空notes字段普通用户不应看到备注
updatedUser.Notes = ""
response.Success(c, updatedUser)
response.Success(c, dto.UserFromService(updatedUser))
}

View File

@@ -2,8 +2,8 @@ package infrastructure
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/repository"
"gorm.io/driver/postgres"
"gorm.io/gorm"
@@ -30,7 +30,7 @@ func InitDB(cfg *config.Config) (*gorm.DB, error) {
// 自动迁移(始终执行,确保数据库结构与代码同步)
// GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的
if err := model.AutoMigrate(db); err != nil {
if err := repository.AutoMigrate(db); err != nil {
return nil, err
}

View File

@@ -1,415 +0,0 @@
package model
import (
"database/sql/driver"
"encoding/json"
"errors"
"time"
"gorm.io/gorm"
)
// JSONB 用于存储JSONB数据
type JSONB map[string]any
func (j JSONB) Value() (driver.Value, error) {
if j == nil {
return nil, nil
}
return json.Marshal(j)
}
func (j *JSONB) Scan(value any) error {
if value == nil {
*j = nil
return nil
}
bytes, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(bytes, j)
}
type Account struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Platform string `gorm:"size:50;not null" json:"platform"` // anthropic/openai/gemini
Type string `gorm:"size:20;not null" json:"type"` // oauth/apikey
Credentials JSONB `gorm:"type:jsonb;default:'{}'" json:"credentials"` // 凭证(加密存储)
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
ProxyID *int64 `gorm:"index" json:"proxy_id"`
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100越小越高
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
ErrorMessage string `gorm:"type:text" json:"error_message"`
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 调度控制
Schedulable bool `gorm:"default:true;not null" json:"schedulable"`
// 限流状态 (429)
RateLimitedAt *time.Time `gorm:"index" json:"rate_limited_at"`
RateLimitResetAt *time.Time `gorm:"index" json:"rate_limit_reset_at"`
// 过载状态 (529)
OverloadUntil *time.Time `gorm:"index" json:"overload_until"`
// 5小时时间窗口
SessionWindowStart *time.Time `json:"session_window_start"`
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `gorm:"size:20" json:"session_window_status"` // allowed/allowed_warning/rejected
// 关联
Proxy *Proxy `gorm:"foreignKey:ProxyID" json:"proxy,omitempty"`
AccountGroups []AccountGroup `gorm:"foreignKey:AccountID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
GroupIDs []int64 `gorm:"-" json:"group_ids,omitempty"`
Groups []*Group `gorm:"-" json:"groups,omitempty"`
}
func (Account) TableName() string {
return "accounts"
}
// IsActive 检查是否激活
func (a *Account) IsActive() bool {
return a.Status == "active"
}
// IsSchedulable 检查账号是否可调度
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
}
now := time.Now()
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false
}
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
return true
}
// IsRateLimited 检查是否处于限流状态
func (a *Account) IsRateLimited() bool {
if a.RateLimitResetAt == nil {
return false
}
return time.Now().Before(*a.RateLimitResetAt)
}
// IsOverloaded 检查是否处于过载状态
func (a *Account) IsOverloaded() bool {
if a.OverloadUntil == nil {
return false
}
return time.Now().Before(*a.OverloadUntil)
}
// IsOAuth 检查是否为OAuth类型账号包括oauth和setup-token
func (a *Account) IsOAuth() bool {
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
}
// CanGetUsage 检查账号是否可以获取usage信息只有oauth类型可以setup-token没有profile权限
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}
// GetCredential 获取凭证字段
func (a *Account) GetCredential(key string) string {
if a.Credentials == nil {
return ""
}
if v, ok := a.Credentials[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// GetModelMapping 获取模型映射配置
// 返回格式: map[请求模型名]实际模型名
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["model_mapping"]
if !ok || raw == nil {
return nil
}
// 处理map[string]interface{}类型
if m, ok := raw.(map[string]any); ok {
result := make(map[string]string)
for k, v := range m {
if s, ok := v.(string); ok {
result[k] = s
}
}
if len(result) > 0 {
return result
}
}
return nil
}
// IsModelSupported 检查请求的模型是否被该账号支持
// 如果没有设置模型映射,则支持所有模型
func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
return true // 没有映射配置,支持所有模型
}
_, exists := mapping[requestedModel]
return exists
}
// GetMappedModel 获取映射后的实际模型名
// 如果没有映射,返回原始模型名
func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
return requestedModel
}
if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel
}
return requestedModel
}
// GetBaseURL 获取API基础URL用于apikey类型账号
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey {
return ""
}
baseURL := a.GetCredential("base_url")
if baseURL == "" {
return "https://api.anthropic.com" // 默认URL
}
return baseURL
}
// GetExtraString 从Extra字段获取字符串值
func (a *Account) GetExtraString(key string) string {
if a.Extra == nil {
return ""
}
if v, ok := a.Extra[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
// IsCustomErrorCodesEnabled 检查是否启用自定义错误码功能(仅适用于 apikey 类型)
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// GetCustomErrorCodes 获取自定义错误码列表
func (a *Account) GetCustomErrorCodes() []int {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["custom_error_codes"]
if !ok || raw == nil {
return nil
}
// 处理 []interface{} 类型JSON反序列化后的格式
if arr, ok := raw.([]any); ok {
result := make([]int, 0, len(arr))
for _, v := range arr {
// JSON 数字默认解析为 float64
if f, ok := v.(float64); ok {
result = append(result, int(f))
}
}
return result
}
return nil
}
// ShouldHandleErrorCode 检查指定错误码是否应该被处理(停止调度/标记限流等)
// 如果未启用自定义错误码或列表为空,返回 true使用默认策略
// 如果启用且列表非空,只有在列表中的错误码才返回 true
func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
if !a.IsCustomErrorCodesEnabled() {
return true // 未启用,使用默认策略
}
codes := a.GetCustomErrorCodes()
if len(codes) == 0 {
return true // 启用但列表为空fallback到默认策略
}
// 检查是否在自定义列表中
for _, code := range codes {
if code == statusCode {
return true
}
}
return false
}
// IsInterceptWarmupEnabled 检查是否启用预热请求拦截
// 启用后标题生成、Warmup等预热请求将返回mock响应不消耗上游token
func (a *Account) IsInterceptWarmupEnabled() bool {
if a.Credentials == nil {
return false
}
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
// =============== OpenAI 相关方法 ===============
// IsOpenAI 检查是否为 OpenAI 平台账号
func (a *Account) IsOpenAI() bool {
return a.Platform == PlatformOpenAI
}
// IsAnthropic 检查是否为 Anthropic 平台账号
func (a *Account) IsAnthropic() bool {
return a.Platform == PlatformAnthropic
}
// IsOpenAIOAuth 检查是否为 OpenAI OAuth 类型账号
func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth
}
// IsOpenAIApiKey 检查是否为 OpenAI API Key 类型账号Response 账号)
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
}
// GetOpenAIBaseURL 获取 OpenAI API 基础 URL
// 对于 API Key 类型账号,从 credentials 中获取 base_url
// 对于 OAuth 类型账号,返回默认的 OpenAI API URL
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeApiKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL
}
}
return "https://api.openai.com" // OpenAI 默认 API URL
}
// GetOpenAIAccessToken 获取 OpenAI 访问令牌
func (a *Account) GetOpenAIAccessToken() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("access_token")
}
// GetOpenAIRefreshToken 获取 OpenAI 刷新令牌
func (a *Account) GetOpenAIRefreshToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("refresh_token")
}
// GetOpenAIIDToken 获取 OpenAI ID TokenJWT包含用户信息
func (a *Account) GetOpenAIIDToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("id_token")
}
// GetOpenAIApiKey 获取 OpenAI API Key用于 Response 账号)
func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIApiKey() {
return ""
}
return a.GetCredential("api_key")
}
// GetOpenAIUserAgent 获取 OpenAI 自定义 User-Agent
// 返回空字符串表示透传原始 User-Agent
func (a *Account) GetOpenAIUserAgent() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("user_agent")
}
// GetChatGPTAccountID 获取 ChatGPT 账号 ID从 ID Token 解析)
func (a *Account) GetChatGPTAccountID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_account_id")
}
// GetChatGPTUserID 获取 ChatGPT 用户 ID从 ID Token 解析)
func (a *Account) GetChatGPTUserID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_user_id")
}
// GetOpenAIOrganizationID 获取 OpenAI 组织 ID
func (a *Account) GetOpenAIOrganizationID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("organization_id")
}
// GetOpenAITokenExpiresAt 获取 OpenAI Token 过期时间
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if !a.IsOpenAIOAuth() {
return nil
}
expiresAtStr := a.GetCredential("expires_at")
if expiresAtStr == "" {
return nil
}
// 尝试解析时间
t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil {
// 尝试解析为 Unix 时间戳
if v, ok := a.Credentials["expires_at"].(float64); ok {
t = time.Unix(int64(v), 0)
return &t
}
return nil
}
return &t
}
// IsOpenAITokenExpired 检查 OpenAI Token 是否过期
func (a *Account) IsOpenAITokenExpired() bool {
expiresAt := a.GetOpenAITokenExpiresAt()
if expiresAt == nil {
return false // 没有过期时间信息,假设未过期
}
// 提前 60 秒认为过期,便于刷新
return time.Now().Add(60 * time.Second).After(*expiresAt)
}

View File

@@ -1,20 +0,0 @@
package model
import (
"time"
)
type AccountGroup struct {
AccountID int64 `gorm:"primaryKey" json:"account_id"`
GroupID int64 `gorm:"primaryKey" json:"group_id"`
Priority int `gorm:"default:50;not null" json:"priority"` // 分组内优先级
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 关联
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (AccountGroup) TableName() string {
return "account_groups"
}

View File

@@ -1,32 +0,0 @@
package model
import (
"time"
"gorm.io/gorm"
)
type ApiKey struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
Key string `gorm:"uniqueIndex;size:128;not null" json:"key"` // sk-xxx
Name string `gorm:"size:100;not null" json:"name"`
GroupID *int64 `gorm:"index" json:"group_id"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (ApiKey) TableName() string {
return "api_keys"
}
// IsActive 检查是否激活
func (k *ApiKey) IsActive() bool {
return k.Status == "active"
}

View File

@@ -1,73 +0,0 @@
package model
import (
"time"
"gorm.io/gorm"
)
// 订阅类型常量
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
type Group struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
Description string `gorm:"type:text" json:"description"`
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
// 订阅功能字段
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)" json:"monthly_limit_usd"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
AccountGroups []AccountGroup `gorm:"foreignKey:GroupID" json:"account_groups,omitempty"`
// 虚拟字段 (不存储到数据库)
AccountCount int64 `gorm:"-" json:"account_count,omitempty"`
}
func (Group) TableName() string {
return "groups"
}
// IsActive 检查是否激活
func (g *Group) IsActive() bool {
return g.Status == "active"
}
// IsSubscriptionType 检查是否为订阅类型分组
func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
// IsFreeSubscription 检查是否为免费订阅(不扣余额但有限额)
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
// HasDailyLimit 检查是否有日限额
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
// HasWeeklyLimit 检查是否有周限额
func (g *Group) HasWeeklyLimit() bool {
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
}
// HasMonthlyLimit 检查是否有月限额
func (g *Group) HasMonthlyLimit() bool {
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
}

View File

@@ -1,64 +0,0 @@
package model
import (
"gorm.io/gorm"
)
// AutoMigrate 自动迁移所有模型
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&User{},
&ApiKey{},
&Group{},
&Account{},
&AccountGroup{},
&Proxy{},
&RedeemCode{},
&UsageLog{},
&Setting{},
&UserSubscription{},
)
}
// 状态常量
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// 角色常量
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// 平台常量
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
)
// 账号类型常量
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeApiKey = "apikey" // API Key类型账号
)
// 卡密类型常量
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
)
// 管理员调整类型常量
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)

View File

@@ -1,45 +0,0 @@
package model
import (
"fmt"
"time"
"gorm.io/gorm"
)
type Proxy struct {
ID int64 `gorm:"primaryKey" json:"id"`
Name string `gorm:"size:100;not null" json:"name"`
Protocol string `gorm:"size:20;not null" json:"protocol"` // http/https/socks5
Host string `gorm:"size:255;not null" json:"host"`
Port int `gorm:"not null" json:"port"`
Username string `gorm:"size:100" json:"username"`
Password string `gorm:"size:100" json:"-"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
func (Proxy) TableName() string {
return "proxies"
}
// IsActive 检查是否激活
func (p *Proxy) IsActive() bool {
return p.Status == "active"
}
// URL 返回代理URL
func (p *Proxy) URL() string {
if p.Username != "" && p.Password != "" {
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
}
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
}
// ProxyWithAccountCount extends Proxy with account count information
type ProxyWithAccountCount struct {
Proxy
AccountCount int64 `json:"account_count"`
}

View File

@@ -1,50 +0,0 @@
package model
import (
"crypto/rand"
"encoding/hex"
"time"
)
type RedeemCode struct {
ID int64 `gorm:"primaryKey" json:"id"`
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
UsedBy *int64 `gorm:"index" json:"used_by"`
UsedAt *time.Time `json:"used_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// 订阅类型专用字段
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
// 关联
User *User `gorm:"foreignKey:UsedBy" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
}
func (RedeemCode) TableName() string {
return "redeem_codes"
}
// IsUsed 检查是否已使用
func (r *RedeemCode) IsUsed() bool {
return r.Status == "used"
}
// CanUse 检查是否可以使用
func (r *RedeemCode) CanUse() bool {
return r.Status == "unused"
}
// GenerateRedeemCode 生成唯一的兑换码
func GenerateRedeemCode() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}

View File

@@ -1,104 +0,0 @@
package model
import (
"time"
)
// Setting 系统设置模型Key-Value存储
type Setting struct {
ID int64 `gorm:"primaryKey" json:"id"`
Key string `gorm:"uniqueIndex;size:100;not null" json:"key"`
Value string `gorm:"type:text;not null" json:"value"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
}
func (Setting) TableName() string {
return "settings"
}
// 设置Key常量
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
)
// 管理员 API Key 前缀(与用户 sk- 前缀区分)
const AdminApiKeyPrefix = "admin-"
// SystemSettings 系统设置结构体用于API响应
type SystemSettings struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
// 邮件服务设置
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"` // 不返回明文密码
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"` // 不返回明文密钥
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
// PublicSettings 公开设置(无需登录即可获取)
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocUrl string `json:"doc_url"`
Version string `json:"version"`
}

View File

@@ -1,67 +0,0 @@
package model
import (
"time"
)
// 消费类型常量
const (
BillingTypeBalance int8 = 0 // 钱包余额
BillingTypeSubscription int8 = 1 // 订阅套餐
)
type UsageLog struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
ApiKeyID int64 `gorm:"index;not null" json:"api_key_id"`
AccountID int64 `gorm:"index;not null" json:"account_id"`
RequestID string `gorm:"size:64" json:"request_id"`
Model string `gorm:"size:100;index;not null" json:"model"`
// 订阅关联(可选)
GroupID *int64 `gorm:"index" json:"group_id"`
SubscriptionID *int64 `gorm:"index" json:"subscription_id"`
// Token使用量4类
InputTokens int `gorm:"default:0;not null" json:"input_tokens"`
OutputTokens int `gorm:"default:0;not null" json:"output_tokens"`
CacheCreationTokens int `gorm:"default:0;not null" json:"cache_creation_tokens"`
CacheReadTokens int `gorm:"default:0;not null" json:"cache_read_tokens"`
// 详细的缓存创建分类
CacheCreation5mTokens int `gorm:"default:0;not null" json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `gorm:"default:0;not null" json:"cache_creation_1h_tokens"`
// 费用USD
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"input_cost"`
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
// 元数据
BillingType int8 `gorm:"type:smallint;default:0;not null" json:"billing_type"` // 0=余额 1=订阅
Stream bool `gorm:"default:false;not null" json:"stream"`
DurationMs *int `json:"duration_ms"`
FirstTokenMs *int `json:"first_token_ms"` // 首字时间(流式请求)
CreatedAt time.Time `gorm:"index;not null" json:"created_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
ApiKey *ApiKey `gorm:"foreignKey:ApiKeyID" json:"api_key,omitempty"`
Account *Account `gorm:"foreignKey:AccountID" json:"account,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
Subscription *UserSubscription `gorm:"foreignKey:SubscriptionID" json:"subscription,omitempty"`
}
func (UsageLog) TableName() string {
return "usage_logs"
}
// TotalTokens 总token数
func (u *UsageLog) TotalTokens() int {
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
}

View File

@@ -1,78 +0,0 @@
package model
import (
"time"
"github.com/lib/pq"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type User struct {
ID int64 `gorm:"primaryKey" json:"id"`
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
Username string `gorm:"size:100;default:''" json:"username"`
Wechat string `gorm:"size:100;default:''" json:"wechat"`
Notes string `gorm:"type:text;default:''" json:"notes"`
PasswordHash string `gorm:"size:255;not null" json:"-"`
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
Concurrency int `gorm:"default:5;not null" json:"concurrency"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
AllowedGroups pq.Int64Array `gorm:"type:bigint[]" json:"allowed_groups"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 关联
ApiKeys []ApiKey `gorm:"foreignKey:UserID" json:"api_keys,omitempty"`
Subscriptions []UserSubscription `gorm:"foreignKey:UserID" json:"subscriptions,omitempty"`
}
func (User) TableName() string {
return "users"
}
// IsAdmin 检查是否管理员
func (u *User) IsAdmin() bool {
return u.Role == "admin"
}
// IsActive 检查是否激活
func (u *User) IsActive() bool {
return u.Status == "active"
}
// CanBindGroup 检查是否可以绑定指定分组
// 对于标准类型分组:
// - 如果 AllowedGroups 设置了值(非空数组),只能绑定列表中的分组
// - 如果 AllowedGroups 为 nil 或空数组,可以绑定所有非专属分组
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
// 如果设置了 allowed_groups 且不为空,只能绑定指定的分组
if len(u.AllowedGroups) > 0 {
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
}
// 如果没有设置 allowed_groups 或为空数组,可以绑定所有非专属分组
return !isExclusive
}
// SetPassword 设置密码(哈希存储)
func (u *User) SetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.PasswordHash = string(hash)
return nil
}
// CheckPassword 验证密码
func (u *User) CheckPassword(password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password))
return err == nil
}

View File

@@ -1,157 +0,0 @@
package model
import (
"time"
)
// 订阅状态常量
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)
// UserSubscription 用户订阅模型
type UserSubscription struct {
ID int64 `gorm:"primaryKey" json:"id"`
UserID int64 `gorm:"index;not null" json:"user_id"`
GroupID int64 `gorm:"index;not null" json:"group_id"`
// 订阅有效期
StartsAt time.Time `gorm:"not null" json:"starts_at"`
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/expired/suspended
// 滑动窗口起始时间nil = 未激活)
DailyWindowStart *time.Time `json:"daily_window_start"`
WeeklyWindowStart *time.Time `json:"weekly_window_start"`
MonthlyWindowStart *time.Time `json:"monthly_window_start"`
// 当前窗口已用额度USD基于 total_cost 计算)
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"daily_usage_usd"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"weekly_usage_usd"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null" json:"monthly_usage_usd"`
// 管理员分配信息
AssignedBy *int64 `gorm:"index" json:"assigned_by"`
AssignedAt time.Time `gorm:"not null" json:"assigned_at"`
Notes string `gorm:"type:text" json:"notes"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`
// 关联
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
Group *Group `gorm:"foreignKey:GroupID" json:"group,omitempty"`
AssignedByUser *User `gorm:"foreignKey:AssignedBy" json:"assigned_by_user,omitempty"`
}
func (UserSubscription) TableName() string {
return "user_subscriptions"
}
// IsActive 检查订阅是否有效状态为active且未过期
func (s *UserSubscription) IsActive() bool {
return s.Status == SubscriptionStatusActive && time.Now().Before(s.ExpiresAt)
}
// IsExpired 检查订阅是否已过期
func (s *UserSubscription) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
}
// DaysRemaining 返回订阅剩余天数
func (s *UserSubscription) DaysRemaining() int {
if s.IsExpired() {
return 0
}
return int(time.Until(s.ExpiresAt).Hours() / 24)
}
// IsWindowActivated 检查窗口是否已激活
func (s *UserSubscription) IsWindowActivated() bool {
return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil
}
// NeedsDailyReset 检查日窗口是否需要重置
func (s *UserSubscription) NeedsDailyReset() bool {
if s.DailyWindowStart == nil {
return false
}
return time.Since(*s.DailyWindowStart) >= 24*time.Hour
}
// NeedsWeeklyReset 检查周窗口是否需要重置
func (s *UserSubscription) NeedsWeeklyReset() bool {
if s.WeeklyWindowStart == nil {
return false
}
return time.Since(*s.WeeklyWindowStart) >= 7*24*time.Hour
}
// NeedsMonthlyReset 检查月窗口是否需要重置
func (s *UserSubscription) NeedsMonthlyReset() bool {
if s.MonthlyWindowStart == nil {
return false
}
return time.Since(*s.MonthlyWindowStart) >= 30*24*time.Hour
}
// DailyResetTime 返回日窗口重置时间
func (s *UserSubscription) DailyResetTime() *time.Time {
if s.DailyWindowStart == nil {
return nil
}
t := s.DailyWindowStart.Add(24 * time.Hour)
return &t
}
// WeeklyResetTime 返回周窗口重置时间
func (s *UserSubscription) WeeklyResetTime() *time.Time {
if s.WeeklyWindowStart == nil {
return nil
}
t := s.WeeklyWindowStart.Add(7 * 24 * time.Hour)
return &t
}
// MonthlyResetTime 返回月窗口重置时间
func (s *UserSubscription) MonthlyResetTime() *time.Time {
if s.MonthlyWindowStart == nil {
return nil
}
t := s.MonthlyWindowStart.Add(30 * 24 * time.Hour)
return &t
}
// CheckDailyLimit 检查是否超出日限额
func (s *UserSubscription) CheckDailyLimit(group *Group, additionalCost float64) bool {
if !group.HasDailyLimit() {
return true // 无限制
}
return s.DailyUsageUSD+additionalCost <= *group.DailyLimitUSD
}
// CheckWeeklyLimit 检查是否超出周限额
func (s *UserSubscription) CheckWeeklyLimit(group *Group, additionalCost float64) bool {
if !group.HasWeeklyLimit() {
return true // 无限制
}
return s.WeeklyUsageUSD+additionalCost <= *group.WeeklyLimitUSD
}
// CheckMonthlyLimit 检查是否超出月限额
func (s *UserSubscription) CheckMonthlyLimit(group *Group, additionalCost float64) bool {
if !group.HasMonthlyLimit() {
return true // 无限制
}
return s.MonthlyUsageUSD+additionalCost <= *group.MonthlyLimitUSD
}
// CheckAllLimits 检查所有限额
func (s *UserSubscription) CheckAllLimits(group *Group, additionalCost float64) (daily, weekly, monthly bool) {
daily = s.CheckDailyLimit(group, additionalCost)
weekly = s.CheckWeeklyLimit(group, additionalCost)
monthly = s.CheckMonthlyLimit(group, additionalCost)
return
}

View File

@@ -5,10 +5,10 @@ import (
"errors"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
@@ -21,69 +21,66 @@ func NewAccountRepository(db *gorm.DB) service.AccountRepository {
return &accountRepository{db: db}
}
func (r *accountRepository) Create(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Create(account).Error
func (r *accountRepository) Create(ctx context.Context, account *service.Account) error {
m := accountModelFromService(account)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyAccountModelToService(account, m)
}
return err
}
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
var account model.Account
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Account, error) {
var m accountModel
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
// 填充 GroupIDs 和 Groups 虚拟字段
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
account.Groups = make([]*model.Group, 0, len(account.AccountGroups))
for _, ag := range account.AccountGroups {
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
return &account, nil
return accountModelToService(&m), nil
}
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
if crsAccountID == "" {
return nil, nil
}
var account model.Account
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&account).Error
var m accountModel
err := r.db.WithContext(ctx).Where("extra->>'crs_account_id' = ?", crsAccountID).First(&m).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &account, nil
return accountModelToService(&m), nil
}
func (r *accountRepository) Update(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Save(account).Error
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
m := accountModelFromService(account)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyAccountModelToService(account, m)
}
return err
}
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
// 先删除账号与分组的绑定关系
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&accountGroupModel{}).Error; err != nil {
return err
}
// 再删除账号
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
return r.db.WithContext(ctx).Delete(&accountModel{}, id).Error
}
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "")
}
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
var accounts []model.Account
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
var accounts []accountModel
var total int64
db := r.db.WithContext(ctx).Model(&model.Account{})
db := r.db.WithContext(ctx).Model(&accountModel{})
// Apply filters
if platform != "" {
db = db.Where("platform = ?", platform)
}
@@ -106,67 +103,84 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return nil, nil, err
}
// 填充每个 Account 的虚拟字段GroupIDs 和 Groups
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
accounts[i].GroupIDs = make([]int64, 0, len(accounts[i].AccountGroups))
accounts[i].Groups = make([]*model.Group, 0, len(accounts[i].AccountGroups))
for _, ag := range accounts[i].AccountGroups {
accounts[i].GroupIDs = append(accounts[i].GroupIDs, ag.GroupID)
if ag.Group != nil {
accounts[i].Groups = append(accounts[i].Groups, ag.Group)
}
}
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return accounts, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outAccounts, paginationResultFromTotal(total, params), nil
}
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, model.StatusActive).
Where("account_groups.group_id = ? AND accounts.status = ?", groupID, service.StatusActive).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListActive(ctx context.Context) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive).
Where("status = ?", service.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, service.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Update("last_used_at", now).Error
}
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"status": model.StatusError,
"status": service.StatusError,
"error_message": errorMsg,
}).Error
}
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &model.AccountGroup{
ag := &accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: priority,
@@ -176,131 +190,148 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
Delete(&model.AccountGroup{}).Error
Delete(&accountGroupModel{}).Error
}
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
var groups []model.Group
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
var groups []groupModel
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.group_id = groups.id").
Where("account_groups.account_id = ?", accountID).
Find(&groups).Error
return groups, err
}
if err != nil {
return nil, err
}
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, model.StatusActive).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
outGroups = append(outGroups, *groupModelToService(&groups[i]))
}
return outGroups, nil
}
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
// 删除现有绑定
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&accountGroupModel{}).Error; err != nil {
return err
}
// 添加新绑定
if len(groupIDs) > 0 {
accountGroups := make([]model.AccountGroup, 0, len(groupIDs))
for i, groupID := range groupIDs {
accountGroups = append(accountGroups, model.AccountGroup{
AccountID: accountID,
GroupID: groupID,
Priority: i + 1, // 使用索引作为优先级
})
}
return r.db.WithContext(ctx).Create(&accountGroups).Error
if len(groupIDs) == 0 {
return nil
}
return nil
accountGroups := make([]accountGroupModel, 0, len(groupIDs))
for i, groupID := range groupIDs {
accountGroups = append(accountGroups, accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: i + 1,
})
}
return r.db.WithContext(ctx).Create(&accountGroups).Error
}
// ListSchedulable 获取所有可调度的账号
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Where("status = ? AND schedulable = ?", model.StatusActive, true).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// ListSchedulableByGroupID 按组获取可调度的账号
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// ListSchedulableByPlatform 按平台获取可调度的账号
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Where("platform = ?", platform).
Where("status = ? AND schedulable = ?", model.StatusActive, true).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Where("(overload_until IS NULL OR overload_until <= ?)", now).
Where("(rate_limit_reset_at IS NULL OR rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
var accounts []model.Account
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
var accounts []accountModel
now := time.Now()
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
Where("account_groups.group_id = ?", groupID).
Where("accounts.platform = ?", platform).
Where("accounts.status = ? AND accounts.schedulable = ?", model.StatusActive, true).
Where("accounts.status = ? AND accounts.schedulable = ?", service.StatusActive, true).
Where("(accounts.overload_until IS NULL OR accounts.overload_until <= ?)", now).
Where("(accounts.rate_limit_reset_at IS NULL OR accounts.rate_limit_reset_at <= ?)", now).
Preload("Proxy").
Order("account_groups.priority ASC, accounts.priority ASC").
Find(&accounts).Error
return accounts, err
if err != nil {
return nil, err
}
outAccounts := make([]service.Account, 0, len(accounts))
for i := range accounts {
outAccounts = append(outAccounts, *accountModelToService(&accounts[i]))
}
return outAccounts, nil
}
// SetRateLimited 标记账号为限流状态(429)
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"rate_limited_at": now,
"rate_limit_reset_at": resetAt,
}).Error
}
// SetOverloaded 标记账号为过载状态(529)
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("overload_until", until).Error
}
// ClearRateLimit 清除账号的限流状态
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Updates(map[string]any{
"rate_limited_at": nil,
"rate_limit_reset_at": nil,
@@ -308,7 +339,6 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
}).Error
}
// UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]any{
"session_window_status": status,
@@ -319,45 +349,35 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
if end != nil {
updates["session_window_end"] = end
}
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Updates(updates).Error
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).Updates(updates).Error
}
// SetSchedulable 设置账号的调度开关
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("schedulable", schedulable).Error
}
// UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 {
return nil
}
// Get current account to preserve existing Extra data
var account model.Account
var account accountModel
if err := r.db.WithContext(ctx).Select("extra").Where("id = ?", id).First(&account).Error; err != nil {
return err
}
// Initialize Extra if nil
if account.Extra == nil {
account.Extra = make(model.JSONB)
account.Extra = datatypes.JSONMap{}
}
// Merge updates into existing Extra
for k, v := range updates {
account.Extra[k] = v
}
// Save updated Extra
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&accountModel{}).Where("id = ?", id).
Update("extra", account.Extra).Error
}
// BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them.
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 {
return 0, nil
@@ -381,10 +401,10 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
updateMap["status"] = *updates.Status
}
if len(updates.Credentials) > 0 {
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", updates.Credentials)
updateMap["credentials"] = gorm.Expr("COALESCE(credentials,'{}') || ?", datatypes.JSONMap(updates.Credentials))
}
if len(updates.Extra) > 0 {
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", updates.Extra)
updateMap["extra"] = gorm.Expr("COALESCE(extra,'{}') || ?", datatypes.JSONMap(updates.Extra))
}
if len(updateMap) == 0 {
@@ -392,10 +412,178 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
}
result := r.db.WithContext(ctx).
Model(&model.Account{}).
Model(&accountModel{}).
Where("id IN ?", ids).
Clauses(clause.Returning{}).
Updates(updateMap)
return result.RowsAffected, result.Error
}
type accountModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"size:100;not null"`
Platform string `gorm:"size:50;not null"`
Type string `gorm:"size:20;not null"`
Credentials datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
Extra datatypes.JSONMap `gorm:"type:jsonb;default:'{}'"`
ProxyID *int64 `gorm:"index"`
Concurrency int `gorm:"default:3;not null"`
Priority int `gorm:"default:50;not null"`
Status string `gorm:"size:20;default:active;not null"`
ErrorMessage string `gorm:"type:text"`
LastUsedAt *time.Time `gorm:"index"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
Schedulable bool `gorm:"default:true;not null"`
RateLimitedAt *time.Time `gorm:"index"`
RateLimitResetAt *time.Time `gorm:"index"`
OverloadUntil *time.Time `gorm:"index"`
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string `gorm:"size:20"`
Proxy *proxyModel `gorm:"foreignKey:ProxyID"`
AccountGroups []accountGroupModel `gorm:"foreignKey:AccountID"`
}
func (accountModel) TableName() string { return "accounts" }
type accountGroupModel struct {
AccountID int64 `gorm:"primaryKey"`
GroupID int64 `gorm:"primaryKey"`
Priority int `gorm:"default:50;not null"`
CreatedAt time.Time `gorm:"not null"`
Account *accountModel `gorm:"foreignKey:AccountID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (accountGroupModel) TableName() string { return "account_groups" }
func accountGroupModelToService(m *accountGroupModel) *service.AccountGroup {
if m == nil {
return nil
}
return &service.AccountGroup{
AccountID: m.AccountID,
GroupID: m.GroupID,
Priority: m.Priority,
CreatedAt: m.CreatedAt,
Account: accountModelToService(m.Account),
Group: groupModelToService(m.Group),
}
}
func accountModelToService(m *accountModel) *service.Account {
if m == nil {
return nil
}
var credentials map[string]any
if m.Credentials != nil {
credentials = map[string]any(m.Credentials)
}
var extra map[string]any
if m.Extra != nil {
extra = map[string]any(m.Extra)
}
account := &service.Account{
ID: m.ID,
Name: m.Name,
Platform: m.Platform,
Type: m.Type,
Credentials: credentials,
Extra: extra,
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
Status: m.Status,
ErrorMessage: m.ErrorMessage,
LastUsedAt: m.LastUsedAt,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,
RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil,
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: m.SessionWindowStatus,
Proxy: proxyModelToService(m.Proxy),
}
if len(m.AccountGroups) > 0 {
account.AccountGroups = make([]service.AccountGroup, 0, len(m.AccountGroups))
account.GroupIDs = make([]int64, 0, len(m.AccountGroups))
account.Groups = make([]*service.Group, 0, len(m.AccountGroups))
for i := range m.AccountGroups {
ag := accountGroupModelToService(&m.AccountGroups[i])
if ag == nil {
continue
}
account.AccountGroups = append(account.AccountGroups, *ag)
account.GroupIDs = append(account.GroupIDs, ag.GroupID)
if ag.Group != nil {
account.Groups = append(account.Groups, ag.Group)
}
}
}
return account
}
func accountModelFromService(a *service.Account) *accountModel {
if a == nil {
return nil
}
var credentials datatypes.JSONMap
if a.Credentials != nil {
credentials = datatypes.JSONMap(a.Credentials)
}
var extra datatypes.JSONMap
if a.Extra != nil {
extra = datatypes.JSONMap(a.Extra)
}
return &accountModel{
ID: a.ID,
Name: a.Name,
Platform: a.Platform,
Type: a.Type,
Credentials: credentials,
Extra: extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
Schedulable: a.Schedulable,
RateLimitedAt: a.RateLimitedAt,
RateLimitResetAt: a.RateLimitResetAt,
OverloadUntil: a.OverloadUntil,
SessionWindowStart: a.SessionWindowStart,
SessionWindowEnd: a.SessionWindowEnd,
SessionWindowStatus: a.SessionWindowStatus,
}
}
func applyAccountModelToService(account *service.Account, m *accountModel) {
if account == nil || m == nil {
return
}
account.ID = m.ID
account.CreatedAt = m.CreatedAt
account.UpdatedAt = m.UpdatedAt
}

View File

@@ -7,10 +7,10 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/datatypes"
"gorm.io/gorm"
)
@@ -34,11 +34,16 @@ func TestAccountRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete ---
func (s *AccountRepoSuite) TestCreate() {
account := &model.Account{
Name: "test-create",
Platform: model.PlatformAnthropic,
Type: model.AccountTypeOAuth,
Status: model.StatusActive,
account := &service.Account{
Name: "test-create",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{},
Extra: map[string]any{},
Concurrency: 3,
Priority: 50,
Schedulable: true,
}
err := s.repo.Create(s.ctx, account)
@@ -56,7 +61,7 @@ func (s *AccountRepoSuite) TestGetByID_NotFound() {
}
func (s *AccountRepoSuite) TestUpdate() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "original"})
account := accountModelToService(mustCreateAccount(s.T(), s.db, &accountModel{Name: "original"}))
account.Name = "updated"
err := s.repo.Update(s.ctx, account)
@@ -68,7 +73,7 @@ func (s *AccountRepoSuite) TestUpdate() {
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "to-delete"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete")
@@ -78,23 +83,23 @@ func (s *AccountRepoSuite) TestDelete() {
}
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete should cascade remove bindings")
var count int64
s.db.Model(&model.AccountGroup{}).Where("account_id = ?", account.ID).Count(&count)
s.db.Model(&accountGroupModel{}).Where("account_id = ?", account.ID).Count(&count)
s.Require().Zero(count, "expected bindings to be removed")
}
// --- List / ListWithFilters ---
func (s *AccountRepoSuite) TestList() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc2"})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc2"})
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -111,53 +116,53 @@ func (s *AccountRepoSuite) TestListWithFilters() {
status string
search string
wantCount int
validate func(accounts []model.Account)
validate func(accounts []service.Account)
}{
{
name: "filter_by_platform",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic})
mustCreateAccount(s.T(), db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI})
mustCreateAccount(s.T(), db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic})
mustCreateAccount(s.T(), db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI})
},
platform: model.PlatformOpenAI,
platform: service.PlatformOpenAI,
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Equal(model.PlatformOpenAI, accounts[0].Platform)
validate: func(accounts []service.Account) {
s.Require().Equal(service.PlatformOpenAI, accounts[0].Platform)
},
},
{
name: "filter_by_type",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "t1", Type: model.AccountTypeOAuth})
mustCreateAccount(s.T(), db, &model.Account{Name: "t2", Type: model.AccountTypeApiKey})
mustCreateAccount(s.T(), db, &accountModel{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), db, &accountModel{Name: "t2", Type: service.AccountTypeApiKey})
},
accType: model.AccountTypeApiKey,
accType: service.AccountTypeApiKey,
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Equal(model.AccountTypeApiKey, accounts[0].Type)
validate: func(accounts []service.Account) {
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
},
},
{
name: "filter_by_status",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "s1", Status: model.StatusActive})
mustCreateAccount(s.T(), db, &model.Account{Name: "s2", Status: model.StatusDisabled})
mustCreateAccount(s.T(), db, &accountModel{Name: "s1", Status: service.StatusActive})
mustCreateAccount(s.T(), db, &accountModel{Name: "s2", Status: service.StatusDisabled})
},
status: model.StatusDisabled,
status: service.StatusDisabled,
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Equal(model.StatusDisabled, accounts[0].Status)
validate: func(accounts []service.Account) {
s.Require().Equal(service.StatusDisabled, accounts[0].Status)
},
},
{
name: "filter_by_search",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "alpha-account"})
mustCreateAccount(s.T(), db, &model.Account{Name: "beta-account"})
mustCreateAccount(s.T(), db, &accountModel{Name: "alpha-account"})
mustCreateAccount(s.T(), db, &accountModel{Name: "beta-account"})
},
search: "alpha",
wantCount: 1,
validate: func(accounts []model.Account) {
validate: func(accounts []service.Account) {
s.Require().Contains(accounts[0].Name, "alpha")
},
},
@@ -185,9 +190,9 @@ func (s *AccountRepoSuite) TestListWithFilters() {
// --- ListByGroup / ListActive / ListByPlatform ---
func (s *AccountRepoSuite) TestListByGroup() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Status: model.StatusActive})
acc2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Status: model.StatusActive})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Status: service.StatusActive})
acc2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Status: service.StatusActive})
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
@@ -199,8 +204,8 @@ func (s *AccountRepoSuite) TestListByGroup() {
}
func (s *AccountRepoSuite) TestListActive() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "active1", Status: model.StatusActive})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "inactive1", Status: model.StatusDisabled})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "active1", Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "inactive1", Status: service.StatusDisabled})
accounts, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
@@ -209,22 +214,22 @@ func (s *AccountRepoSuite) TestListActive() {
}
func (s *AccountRepoSuite) TestListByPlatform() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
accounts, err := s.repo.ListByPlatform(s.ctx, model.PlatformAnthropic)
accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListByPlatform")
s.Require().Len(accounts, 1)
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
}
// --- Preload and VirtualFields ---
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
account := mustCreateAccount(s.T(), s.db, &model.Account{
account := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc1",
ProxyID: &proxy.ID,
})
@@ -252,9 +257,9 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc"})
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc"})
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
@@ -274,8 +279,8 @@ func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
}
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
@@ -289,13 +294,13 @@ func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
func (s *AccountRepoSuite) TestListSchedulable() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
sched, err := s.repo.ListSchedulable(s.ctx)
@@ -307,16 +312,16 @@ func (s *AccountRepoSuite) TestListSchedulable() {
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
okAcc := mustCreateAccount(s.T(), s.db, &accountModel{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
overloaded := mustCreateAccount(s.T(), s.db, &accountModel{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.db, &model.Account{Name: "rl", Schedulable: true})
rateLimited := mustCreateAccount(s.T(), s.db, &accountModel{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
@@ -334,30 +339,30 @@ func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_Statu
}
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, model.PlatformAnthropic)
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
}
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, model.PlatformAnthropic)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(a1.ID, accounts[0].ID)
}
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-sched", Schedulable: true})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-sched", Schedulable: true})
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
@@ -369,7 +374,7 @@ func (s *AccountRepoSuite) TestSetSchedulable() {
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
func (s *AccountRepoSuite) TestSetOverloaded() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-over"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-over"})
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
@@ -381,7 +386,7 @@ func (s *AccountRepoSuite) TestSetOverloaded() {
}
func (s *AccountRepoSuite) TestSetRateLimited() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-rl"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-rl"})
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
@@ -394,7 +399,7 @@ func (s *AccountRepoSuite) TestSetRateLimited() {
}
func (s *AccountRepoSuite) TestClearRateLimit() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-clear"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-clear"})
until := time.Now().Add(1 * time.Hour)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
@@ -411,7 +416,7 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-used"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-used"})
s.Require().Nil(account.LastUsedAt)
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
@@ -424,20 +429,20 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() {
// --- SetError ---
func (s *AccountRepoSuite) TestSetError() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-err", Status: model.StatusActive})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-err", Status: service.StatusActive})
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal(model.StatusError, got.Status)
s.Require().Equal(service.StatusError, got.Status)
s.Require().Equal("something went wrong", got.ErrorMessage)
}
// --- UpdateSessionWindow ---
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-win"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-win"})
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
@@ -453,9 +458,9 @@ func (s *AccountRepoSuite) TestUpdateSessionWindow() {
// --- UpdateExtra ---
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
account := mustCreateAccount(s.T(), s.db, &model.Account{
account := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc-extra",
Extra: model.JSONB{"a": "1"},
Extra: datatypes.JSONMap{"a": "1"},
})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
@@ -466,12 +471,12 @@ func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
}
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-extra-empty"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-extra-empty"})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
}
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-nil-extra", Extra: nil})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-nil-extra", Extra: nil})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
got, err := s.repo.GetByID(s.ctx, account.ID)
@@ -483,9 +488,9 @@ func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
crsID := "crs-12345"
mustCreateAccount(s.T(), s.db, &model.Account{
mustCreateAccount(s.T(), s.db, &accountModel{
Name: "acc-crs",
Extra: model.JSONB{"crs_account_id": crsID},
Extra: datatypes.JSONMap{"crs_account_id": crsID},
})
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
@@ -509,8 +514,8 @@ func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
// --- BulkUpdate ---
func (s *AccountRepoSuite) TestBulkUpdate() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk2", Priority: 1})
newPriority := 99
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
@@ -526,13 +531,13 @@ func (s *AccountRepoSuite) TestBulkUpdate() {
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bulk-cred",
Credentials: model.JSONB{"existing": "value"},
Credentials: datatypes.JSONMap{"existing": "value"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Credentials: model.JSONB{"new_key": "new_value"},
Credentials: datatypes.JSONMap{"new_key": "new_value"},
})
s.Require().NoError(err)
@@ -542,13 +547,13 @@ func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
a1 := mustCreateAccount(s.T(), s.db, &accountModel{
Name: "bulk-extra",
Extra: model.JSONB{"existing": "val"},
Extra: datatypes.JSONMap{"existing": "val"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Extra: model.JSONB{"new_key": "new_val"},
Extra: datatypes.JSONMap{"new_key": "new_val"},
})
s.Require().NoError(err)
@@ -564,14 +569,14 @@ func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "bulk-empty"})
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
s.Require().NoError(err)
s.Require().Zero(affected)
}
func idsOfAccounts(accounts []model.Account) []int64 {
func idsOfAccounts(accounts []service.Account) []int64 {
out := make([]int64, 0, len(accounts))
for i := range accounts {
out = append(out, accounts[i].ID)

View File

@@ -2,10 +2,10 @@ package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
@@ -19,42 +19,51 @@ func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return &apiKeyRepository{db: db}
}
func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
err := r.db.WithContext(ctx).Create(key).Error
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
}
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
var key model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
}
return &key, nil
return apiKeyModelToService(&m), nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
var apiKey model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
var m apiKeyModel
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
}
return &apiKey, nil
return apiKeyModelToService(&m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
m := apiKeyModelFromService(key)
err := r.db.WithContext(ctx).Model(m).Select("name", "group_id", "status", "updated_at").Updates(m).Error
if err == nil {
applyApiKeyModelToService(key, m)
}
return err
}
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
return r.db.WithContext(ctx).Delete(&apiKeyModel{}, id).Error
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel
var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID)
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -64,36 +73,31 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return keys, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outKeys, paginationResultFromTotal(total, params), nil
}
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("user_id = ?", userID).Count(&count).Error
return count, err
}
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("key = ?", key).Count(&count).Error
return count > 0, err
}
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
var keys []apiKeyModel
var total int64
db := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID)
db := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -103,24 +107,19 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return keys, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outKeys, paginationResultFromTotal(total, params), nil
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
var keys []model.ApiKey
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
var keys []apiKeyModel
db := r.db.WithContext(ctx).Model(&model.ApiKey{})
db := r.db.WithContext(ctx).Model(&apiKeyModel{})
if userID > 0 {
db = db.Where("user_id = ?", userID)
@@ -135,12 +134,16 @@ func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
return nil, err
}
return keys, nil
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyModelToService(&keys[i]))
}
return outKeys, nil
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
result := r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("group_id = ?", groupID).
Update("group_id", nil)
return result.RowsAffected, result.Error
@@ -149,6 +152,66 @@ func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
// CountByGroupID 获取分组的 API Key 数量
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
err := r.db.WithContext(ctx).Model(&apiKeyModel{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err
}
type apiKeyModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
Key string `gorm:"uniqueIndex;size:128;not null"`
Name string `gorm:"size:100;not null"`
GroupID *int64 `gorm:"index"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (apiKeyModel) TableName() string { return "api_keys" }
func apiKeyModelToService(m *apiKeyModel) *service.ApiKey {
if m == nil {
return nil
}
return &service.ApiKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
GroupID: m.GroupID,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
}
}
func apiKeyModelFromService(k *service.ApiKey) *apiKeyModel {
if k == nil {
return nil
}
return &apiKeyModel{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
}
}
func applyApiKeyModelToService(key *service.ApiKey, m *apiKeyModel) {
if key == nil || m == nil {
return
}
key.ID = m.ID
key.CreatedAt = m.CreatedAt
key.UpdatedAt = m.UpdatedAt
}

View File

@@ -6,8 +6,8 @@ import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
@@ -32,13 +32,13 @@ func TestApiKeyRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
key := &model.ApiKey{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-create-test",
Name: "Test Key",
Status: model.StatusActive,
Status: service.StatusActive,
}
err := s.repo.Create(s.ctx, key)
@@ -56,15 +56,15 @@ func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
}
func (s *ApiKeyRepoSuite) TestGetByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbykey@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-key"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbykey@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-key"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
GroupID: &group.ID,
Status: model.StatusActive,
Status: service.StatusActive,
})
got, err := s.repo.GetByKey(s.ctx, key.Key)
@@ -84,16 +84,16 @@ func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
// --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
Status: model.StatusActive,
})
Status: service.StatusActive,
}))
key.Name = "Renamed"
key.Status = model.StatusDisabled
key.Status = service.StatusDisabled
err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update")
@@ -102,18 +102,18 @@ func (s *ApiKeyRepoSuite) TestUpdate() {
s.Require().Equal("sk-update", got.Key, "Update should not change key")
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got.Name)
s.Require().Equal(model.StatusDisabled, got.Status)
s.Require().Equal(service.StatusDisabled, got.Status)
}
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear"})
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
GroupID: &group.ID,
})
}))
key.GroupID = nil
err := s.repo.Update(s.ctx, key)
@@ -127,8 +127,8 @@ func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
key := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
@@ -144,9 +144,9 @@ func (s *ApiKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
@@ -155,9 +155,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID() {
}
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "paging@test.com"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "paging@test.com"})
for i := 0; i < 5; i++ {
mustCreateApiKey(s.T(), s.db, &model.ApiKey{
mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-page-" + string(rune('a'+i)),
Name: "Key",
@@ -172,9 +172,9 @@ func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
}
func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "count@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "count@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID")
@@ -184,12 +184,12 @@ func (s *ApiKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbygroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbygroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
@@ -200,10 +200,10 @@ func (s *ApiKeyRepoSuite) TestListByGroupID() {
}
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "countgroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "countgroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
@@ -213,8 +213,8 @@ func (s *ApiKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-exists", Name: "K"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-exists", Name: "K"})
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey")
@@ -228,9 +228,9 @@ func (s *ApiKeyRepoSuite) TestExistsByKey() {
// --- SearchApiKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "search@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "search@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys")
@@ -239,9 +239,9 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnokw@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnokw@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
@@ -249,8 +249,8 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnouid@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "searchnouid@test.com"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
@@ -260,12 +260,12 @@ func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
// --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear-bulk"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "cleargrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-clear-bulk"})
k1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
k1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
@@ -283,16 +283,16 @@ func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "k@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-k"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "k@example.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-k"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
key := apiKeyModelToService(mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-test-1",
Name: "My Key",
GroupID: &group.ID,
Status: model.StatusActive,
})
Status: service.StatusActive,
}))
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
@@ -303,7 +303,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(group.ID, got.Group.ID)
key.Name = "Renamed"
key.Status = model.StatusDisabled
key.Status = service.StatusDisabled
key.GroupID = nil
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
@@ -312,7 +312,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got2.Name)
s.Require().Equal(model.StatusDisabled, got2.Status)
s.Require().Equal(service.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
@@ -330,7 +330,7 @@ func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
k2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{
UserID: user.ID,
Key: "sk-test-2",
Name: "Group Key",

View File

@@ -0,0 +1,20 @@
package repository
import "gorm.io/gorm"
// AutoMigrate runs schema migrations for all repository persistence models.
// Persistence models are defined within individual `*_repo.go` files.
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&userModel{},
&apiKeyModel{},
&groupModel{},
&accountModel{},
&accountGroupModel{},
&proxyModel{},
&redeemCodeModel{},
&usageLogModel{},
&settingModel{},
&userSubscriptionModel{},
)
}

View File

@@ -6,21 +6,25 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"gorm.io/datatypes"
"gorm.io/gorm"
)
func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
func mustCreateUser(t *testing.T, db *gorm.DB, u *userModel) *userModel {
t.Helper()
if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash"
}
if u.Role == "" {
u.Role = model.RoleUser
u.Role = service.RoleUser
}
if u.Status == "" {
u.Status = model.StatusActive
u.Status = service.StatusActive
}
if u.Concurrency == 0 {
u.Concurrency = 5
}
if u.CreatedAt.IsZero() {
u.CreatedAt = time.Now()
@@ -32,16 +36,16 @@ func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
return u
}
func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
func mustCreateGroup(t *testing.T, db *gorm.DB, g *groupModel) *groupModel {
t.Helper()
if g.Platform == "" {
g.Platform = model.PlatformAnthropic
g.Platform = service.PlatformAnthropic
}
if g.Status == "" {
g.Status = model.StatusActive
g.Status = service.StatusActive
}
if g.SubscriptionType == "" {
g.SubscriptionType = model.SubscriptionTypeStandard
g.SubscriptionType = service.SubscriptionTypeStandard
}
if g.CreatedAt.IsZero() {
g.CreatedAt = time.Now()
@@ -53,7 +57,7 @@ func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
return g
}
func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
func mustCreateProxy(t *testing.T, db *gorm.DB, p *proxyModel) *proxyModel {
t.Helper()
if p.Protocol == "" {
p.Protocol = "http"
@@ -65,7 +69,7 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
p.Port = 8080
}
if p.Status == "" {
p.Status = model.StatusActive
p.Status = service.StatusActive
}
if p.CreatedAt.IsZero() {
p.CreatedAt = time.Now()
@@ -77,25 +81,25 @@ func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
return p
}
func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Account {
func mustCreateAccount(t *testing.T, db *gorm.DB, a *accountModel) *accountModel {
t.Helper()
if a.Platform == "" {
a.Platform = model.PlatformAnthropic
a.Platform = service.PlatformAnthropic
}
if a.Type == "" {
a.Type = model.AccountTypeOAuth
a.Type = service.AccountTypeOAuth
}
if a.Status == "" {
a.Status = model.StatusActive
a.Status = service.StatusActive
}
if !a.Schedulable {
a.Schedulable = true
}
if a.Credentials == nil {
a.Credentials = model.JSONB{}
a.Credentials = datatypes.JSONMap{}
}
if a.Extra == nil {
a.Extra = model.JSONB{}
a.Extra = datatypes.JSONMap{}
}
if a.CreatedAt.IsZero() {
a.CreatedAt = time.Now()
@@ -107,10 +111,10 @@ func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Accou
return a
}
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey {
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *apiKeyModel) *apiKeyModel {
t.Helper()
if k.Status == "" {
k.Status = model.StatusActive
k.Status = service.StatusActive
}
if k.CreatedAt.IsZero() {
k.CreatedAt = time.Now()
@@ -122,13 +126,13 @@ func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey
return k
}
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model.RedeemCode {
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *redeemCodeModel) *redeemCodeModel {
t.Helper()
if c.Status == "" {
c.Status = model.StatusUnused
c.Status = service.StatusUnused
}
if c.Type == "" {
c.Type = model.RedeemTypeBalance
c.Type = service.RedeemTypeBalance
}
if c.CreatedAt.IsZero() {
c.CreatedAt = time.Now()
@@ -137,10 +141,10 @@ func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model
return c
}
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription) *model.UserSubscription {
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *userSubscriptionModel) *userSubscriptionModel {
t.Helper()
if s.Status == "" {
s.Status = model.SubscriptionStatusActive
s.Status = service.SubscriptionStatusActive
}
now := time.Now()
if s.StartsAt.IsZero() {
@@ -164,9 +168,10 @@ func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
t.Helper()
require.NoError(t, db.Create(&model.AccountGroup{
require.NoError(t, db.Create(&accountGroupModel{
AccountID: accountID,
GroupID: groupID,
Priority: priority,
CreatedAt: time.Now(),
}).Error, "create account_group")
}

View File

@@ -2,10 +2,10 @@ package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
@@ -20,38 +20,50 @@ func NewGroupRepository(db *gorm.DB) service.GroupRepository {
return &groupRepository{db: db}
}
func (r *groupRepository) Create(ctx context.Context, group *model.Group) error {
err := r.db.WithContext(ctx).Create(group).Error
func (r *groupRepository) Create(ctx context.Context, group *service.Group) error {
m := groupModelFromService(group)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyGroupModelToService(group, m)
}
return translatePersistenceError(err, nil, service.ErrGroupExists)
}
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
var group model.Group
err := r.db.WithContext(ctx).First(&group, id).Error
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
var m groupModel
err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
return &group, nil
group := groupModelToService(&m)
count, _ := r.GetAccountCount(ctx, group.ID)
group.AccountCount = count
return group, nil
}
func (r *groupRepository) Update(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Save(group).Error
func (r *groupRepository) Update(ctx context.Context, group *service.Group) error {
m := groupModelFromService(group)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyGroupModelToService(group, m)
}
return err
}
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
return r.db.WithContext(ctx).Delete(&groupModel{}, id).Error
}
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil)
}
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
var groups []model.Group
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
var groups []groupModel
var total int64
db := r.db.WithContext(ctx).Model(&model.Group{})
db := r.db.WithContext(ctx).Model(&groupModel{})
// Apply filters
if platform != "" {
@@ -72,68 +84,71 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err
}
// 获取每个分组的账号数量
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
outGroups = append(outGroups, *groupModelToService(&groups[i]))
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
}
return groups, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outGroups, paginationResultFromTotal(total, params), nil
}
func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Order("id ASC").Find(&groups).Error
if err != nil {
return nil, err
}
// 获取每个分组的账号数量
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
outGroups = append(outGroups, *groupModelToService(&groups[i]))
}
return groups, nil
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
}
return outGroups, nil
}
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
var groups []groupModel
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", service.StatusActive, platform).Order("id ASC").Find(&groups).Error
if err != nil {
return nil, err
}
// 获取每个分组的账号数量
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
count, _ := r.GetAccountCount(ctx, groups[i].ID)
groups[i].AccountCount = count
outGroups = append(outGroups, *groupModelToService(&groups[i]))
}
return groups, nil
// 获取每个分组的账号数量
for i := range outGroups {
count, _ := r.GetAccountCount(ctx, outGroups[i].ID)
outGroups[i].AccountCount = count
}
return outGroups, nil
}
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error
err := r.db.WithContext(ctx).Model(&groupModel{}).Where("name = ?", name).Count(&count).Error
return count > 0, err
}
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error
err := r.db.WithContext(ctx).Table("account_groups").Where("group_id = ?", groupID).Count(&count).Error
return count, err
}
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{})
result := r.db.WithContext(ctx).Exec("DELETE FROM account_groups WHERE group_id = ?", groupID)
return result.RowsAffected, result.Error
}
@@ -145,46 +160,42 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
var affectedUserIDs []int64
if group.IsSubscriptionType() {
var subscriptions []model.UserSubscription
if err := r.db.WithContext(ctx).
Model(&model.UserSubscription{}).
Table("user_subscriptions").
Where("group_id = ?", id).
Select("user_id").
Find(&subscriptions).Error; err != nil {
Pluck("user_id", &affectedUserIDs).Error; err != nil {
return nil, err
}
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
}
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 删除订阅类型分组的订阅记录
if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
if err := tx.Exec("DELETE FROM user_subscriptions WHERE group_id = ?", id).Error; err != nil {
return err
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
if err := tx.Exec("UPDATE api_keys SET group_id = NULL WHERE group_id = ?", id).Error; err != nil {
return err
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}).
Where("? = ANY(allowed_groups)", id).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
if err := tx.Exec(
"UPDATE users SET allowed_groups = array_remove(allowed_groups, ?) WHERE ? = ANY(allowed_groups)",
id, id,
).Error; err != nil {
return err
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
if err := tx.Exec("DELETE FROM account_groups WHERE group_id = ?", id).Error; err != nil {
return err
}
// 5. 删除分组本身(带锁,避免并发写)
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil {
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&groupModel{}, id).Error; err != nil {
return err
}
@@ -196,3 +207,75 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return affectedUserIDs, nil
}
type groupModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"uniqueIndex;size:100;not null"`
Description string `gorm:"type:text"`
Platform string `gorm:"size:50;default:anthropic;not null"`
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null"`
IsExclusive bool `gorm:"default:false;not null"`
Status string `gorm:"size:20;default:active;not null"`
SubscriptionType string `gorm:"size:20;default:standard;not null"`
DailyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
WeeklyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
MonthlyLimitUSD *float64 `gorm:"type:decimal(20,8)"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (groupModel) TableName() string { return "groups" }
func groupModelToService(m *groupModel) *service.Group {
if m == nil {
return nil
}
return &service.Group{
ID: m.ID,
Name: m.Name,
Description: m.Description,
Platform: m.Platform,
RateMultiplier: m.RateMultiplier,
IsExclusive: m.IsExclusive,
Status: m.Status,
SubscriptionType: m.SubscriptionType,
DailyLimitUSD: m.DailyLimitUSD,
WeeklyLimitUSD: m.WeeklyLimitUSD,
MonthlyLimitUSD: m.MonthlyLimitUSD,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func groupModelFromService(sg *service.Group) *groupModel {
if sg == nil {
return nil
}
return &groupModel{
ID: sg.ID,
Name: sg.Name,
Description: sg.Description,
Platform: sg.Platform,
RateMultiplier: sg.RateMultiplier,
IsExclusive: sg.IsExclusive,
Status: sg.Status,
SubscriptionType: sg.SubscriptionType,
DailyLimitUSD: sg.DailyLimitUSD,
WeeklyLimitUSD: sg.WeeklyLimitUSD,
MonthlyLimitUSD: sg.MonthlyLimitUSD,
CreatedAt: sg.CreatedAt,
UpdatedAt: sg.UpdatedAt,
}
}
func applyGroupModelToService(group *service.Group, m *groupModel) {
if group == nil || m == nil {
return
}
group.ID = m.ID
group.CreatedAt = m.CreatedAt
group.UpdatedAt = m.UpdatedAt
}

View File

@@ -6,8 +6,8 @@ import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
@@ -32,10 +32,10 @@ func TestGroupRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete ---
func (s *GroupRepoSuite) TestCreate() {
group := &model.Group{
group := &service.Group{
Name: "test-create",
Platform: model.PlatformAnthropic,
Status: model.StatusActive,
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
}
err := s.repo.Create(s.ctx, group)
@@ -53,7 +53,7 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
}
func (s *GroupRepoSuite) TestUpdate() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "original"})
group := groupModelToService(mustCreateGroup(s.T(), s.db, &groupModel{Name: "original"}))
group.Name = "updated"
err := s.repo.Update(s.ctx, group)
@@ -65,7 +65,7 @@ func (s *GroupRepoSuite) TestUpdate() {
}
func (s *GroupRepoSuite) TestDelete() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "to-delete"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete")
@@ -77,8 +77,8 @@ func (s *GroupRepoSuite) TestDelete() {
// --- List / ListWithFilters ---
func (s *GroupRepoSuite) TestList() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1"})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2"})
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -87,28 +87,28 @@ func (s *GroupRepoSuite) TestList() {
}
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformOpenAI, "", nil)
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(model.PlatformOpenAI, groups[0].Platform)
s.Require().Equal(service.PlatformOpenAI, groups[0].Platform)
}
func (s *GroupRepoSuite) TestListWithFilters_Status() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Status: model.StatusDisabled})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Status: service.StatusDisabled})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, nil)
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(model.StatusDisabled, groups[0].Status)
s.Require().Equal(service.StatusDisabled, groups[0].Status)
}
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", IsExclusive: false})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", IsExclusive: true})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", IsExclusive: false})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", IsExclusive: true})
isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
@@ -118,24 +118,24 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := mustCreateGroup(s.T(), s.db, &model.Group{
g1 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g1",
Platform: model.PlatformAnthropic,
Status: model.StatusActive,
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{
g2 := mustCreateGroup(s.T(), s.db, &groupModel{
Name: "g2",
Platform: model.PlatformAnthropic,
Status: model.StatusActive,
Platform: service.PlatformAnthropic,
Status: service.StatusActive,
IsExclusive: true,
})
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc1"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformAnthropic, model.StatusActive, &isExclusive)
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1)
@@ -146,8 +146,8 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
// --- ListActive / ListActiveByPlatform ---
func (s *GroupRepoSuite) TestListActive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "active1", Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "inactive1", Status: model.StatusDisabled})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "active1", Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "inactive1", Status: service.StatusDisabled})
groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
@@ -156,11 +156,11 @@ func (s *GroupRepoSuite) TestListActive() {
}
func (s *GroupRepoSuite) TestListActiveByPlatform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g3", Platform: model.PlatformAnthropic, Status: model.StatusDisabled})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "g3", Platform: service.PlatformAnthropic, Status: service.StatusDisabled})
groups, err := s.repo.ListActiveByPlatform(s.ctx, model.PlatformAnthropic)
groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListActiveByPlatform")
s.Require().Len(groups, 1)
s.Require().Equal("g1", groups[0].Name)
@@ -169,7 +169,7 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() {
// --- ExistsByName ---
func (s *GroupRepoSuite) TestExistsByName() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "existing-group"})
mustCreateGroup(s.T(), s.db, &groupModel{Name: "existing-group"})
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
s.Require().NoError(err, "ExistsByName")
@@ -183,9 +183,9 @@ func (s *GroupRepoSuite) TestExistsByName() {
// --- GetAccountCount ---
func (s *GroupRepoSuite) TestGetAccountCount() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
@@ -195,7 +195,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() {
}
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-empty"})
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err)
@@ -205,8 +205,8 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
// --- DeleteAccountGroupsByGroupID ---
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-del"})
a := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
@@ -219,10 +219,10 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
}
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-multi"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
a3 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"})
g := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-multi"})
a1 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2"})
a3 := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)

View File

@@ -15,7 +15,6 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@@ -94,7 +93,7 @@ func TestMain(m *testing.M) {
log.Printf("failed to open gorm db: %v", err)
os.Exit(1)
}
if err := model.AutoMigrate(integrationDB); err != nil {
if err := AutoMigrate(integrationDB); err != nil {
log.Printf("failed to automigrate db: %v", err)
os.Exit(1)
}

View File

@@ -0,0 +1,16 @@
package repository
import "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
func paginationResultFromTotal(total int64, params pagination.PaginationParams) *pagination.PaginationResult {
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}
}

View File

@@ -2,10 +2,10 @@ package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
@@ -19,37 +19,47 @@ func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
return &proxyRepository{db: db}
}
func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Create(proxy).Error
func (r *proxyRepository) Create(ctx context.Context, proxy *service.Proxy) error {
m := proxyModelFromService(proxy)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyProxyModelToService(proxy, m)
}
return err
}
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
var proxy model.Proxy
err := r.db.WithContext(ctx).First(&proxy, id).Error
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
var m proxyModel
err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
}
return &proxy, nil
return proxyModelToService(&m), nil
}
func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Save(proxy).Error
func (r *proxyRepository) Update(ctx context.Context, proxy *service.Proxy) error {
m := proxyModelFromService(proxy)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyProxyModelToService(proxy, m)
}
return err
}
func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
return r.db.WithContext(ctx).Delete(&proxyModel{}, id).Error
}
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
var proxies []model.Proxy
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
var proxies []proxyModel
var total int64
db := r.db.WithContext(ctx).Model(&model.Proxy{})
db := r.db.WithContext(ctx).Model(&proxyModel{})
// Apply filters
if protocol != "" {
@@ -71,29 +81,31 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outProxies := make([]service.Proxy, 0, len(proxies))
for i := range proxies {
outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
}
return proxies, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outProxies, paginationResultFromTotal(total, params), nil
}
func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
var proxies []model.Proxy
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error
return proxies, err
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
var proxies []proxyModel
err := r.db.WithContext(ctx).Where("status = ?", service.StatusActive).Find(&proxies).Error
if err != nil {
return nil, err
}
outProxies := make([]service.Proxy, 0, len(proxies))
for i := range proxies {
outProxies = append(outProxies, *proxyModelToService(&proxies[i]))
}
return outProxies, nil
}
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Proxy{}).
err := r.db.WithContext(ctx).Model(&proxyModel{}).
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
Count(&count).Error
if err != nil {
@@ -105,7 +117,7 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Account{}).
err := r.db.WithContext(ctx).Table("accounts").
Where("proxy_id = ?", proxyID).
Count(&count).Error
return count, err
@@ -119,7 +131,7 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
}
var results []result
err := r.db.WithContext(ctx).
Model(&model.Account{}).
Table("accounts").
Select("proxy_id, COUNT(*) as count").
Where("proxy_id IS NOT NULL").
Group("proxy_id").
@@ -136,10 +148,10 @@ func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
}
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
var proxies []model.Proxy
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
var proxies []proxyModel
err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive).
Where("status = ?", service.StatusActive).
Order("created_at DESC").
Find(&proxies).Error
if err != nil {
@@ -153,13 +165,78 @@ func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]mod
}
// Build result with account counts
result := make([]model.ProxyWithAccountCount, len(proxies))
for i, proxy := range proxies {
result[i] = model.ProxyWithAccountCount{
Proxy: proxy,
AccountCount: counts[proxy.ID],
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
proxy := proxyModelToService(&proxies[i])
if proxy == nil {
continue
}
result = append(result, service.ProxyWithAccountCount{
Proxy: *proxy,
AccountCount: counts[proxy.ID],
})
}
return result, nil
}
type proxyModel struct {
ID int64 `gorm:"primaryKey"`
Name string `gorm:"size:100;not null"`
Protocol string `gorm:"size:20;not null"`
Host string `gorm:"size:255;not null"`
Port int `gorm:"not null"`
Username string `gorm:"size:100"`
Password string `gorm:"size:100"`
Status string `gorm:"size:20;default:active;not null"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (proxyModel) TableName() string { return "proxies" }
func proxyModelToService(m *proxyModel) *service.Proxy {
if m == nil {
return nil
}
return &service.Proxy{
ID: m.ID,
Name: m.Name,
Protocol: m.Protocol,
Host: m.Host,
Port: m.Port,
Username: m.Username,
Password: m.Password,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func proxyModelFromService(p *service.Proxy) *proxyModel {
if p == nil {
return nil
}
return &proxyModel{
ID: p.ID,
Name: p.Name,
Protocol: p.Protocol,
Host: p.Host,
Port: p.Port,
Username: p.Username,
Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
}
}
func applyProxyModelToService(proxy *service.Proxy, m *proxyModel) {
if proxy == nil || m == nil {
return
}
proxy.ID = m.ID
proxy.CreatedAt = m.CreatedAt
proxy.UpdatedAt = m.UpdatedAt
}

View File

@@ -7,8 +7,8 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
@@ -33,12 +33,12 @@ func TestProxyRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete ---
func (s *ProxyRepoSuite) TestCreate() {
proxy := &model.Proxy{
proxy := &service.Proxy{
Name: "test-create",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: model.StatusActive,
Status: service.StatusActive,
}
err := s.repo.Create(s.ctx, proxy)
@@ -56,7 +56,7 @@ func (s *ProxyRepoSuite) TestGetByID_NotFound() {
}
func (s *ProxyRepoSuite) TestUpdate() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"})
proxy := proxyModelToService(mustCreateProxy(s.T(), s.db, &proxyModel{Name: "original"}))
proxy.Name = "updated"
err := s.repo.Update(s.ctx, proxy)
@@ -68,7 +68,7 @@ func (s *ProxyRepoSuite) TestUpdate() {
}
func (s *ProxyRepoSuite) TestDelete() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"})
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "to-delete"})
err := s.repo.Delete(s.ctx, proxy.ID)
s.Require().NoError(err, "Delete")
@@ -80,8 +80,8 @@ func (s *ProxyRepoSuite) TestDelete() {
// --- List / ListWithFilters ---
func (s *ProxyRepoSuite) TestList() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -90,8 +90,8 @@ func (s *ProxyRepoSuite) TestList() {
}
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Protocol: "http"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Protocol: "socks5"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
s.Require().NoError(err)
@@ -100,18 +100,18 @@ func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
}
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1", Status: service.StatusActive})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2", Status: service.StatusDisabled})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "")
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Equal(model.StatusDisabled, proxies[0].Status)
s.Require().Equal(service.StatusDisabled, proxies[0].Status)
}
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "production-proxy"})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "dev-proxy"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
s.Require().NoError(err)
@@ -122,8 +122,8 @@ func (s *ProxyRepoSuite) TestListWithFilters_Search() {
// --- ListActive ---
func (s *ProxyRepoSuite) TestListActive() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "active1", Status: service.StatusActive})
mustCreateProxy(s.T(), s.db, &proxyModel{Name: "inactive1", Status: service.StatusDisabled})
proxies, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
@@ -134,7 +134,7 @@ func (s *ProxyRepoSuite) TestListActive() {
// --- ExistsByHostPortAuth ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{
mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
@@ -153,7 +153,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
}
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{
mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p-noauth",
Protocol: "http",
Host: "5.6.7.8",
@@ -170,10 +170,10 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
// --- CountAccountsByProxyID ---
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-count"})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3"}) // no proxy
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
@@ -181,7 +181,7 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
}
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"})
proxy := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p-zero"})
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err)
@@ -191,12 +191,12 @@ func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
// --- GetAccountCountsForProxies ---
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p1"})
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{Name: "p2"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies")
@@ -215,24 +215,24 @@ func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1",
Status: model.StatusActive,
Status: service.StatusActive,
CreatedAt: base.Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p2",
Status: model.StatusActive,
Status: service.StatusActive,
CreatedAt: base,
})
mustCreateProxy(s.T(), s.db, &model.Proxy{
mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p3-inactive",
Status: model.StatusDisabled,
Status: service.StatusDisabled,
})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount")
@@ -248,7 +248,7 @@ func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
// --- Combined original test ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
p1 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
@@ -258,7 +258,7 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
CreatedAt: time.Now().Add(-1 * time.Hour),
UpdatedAt: time.Now().Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
p2 := mustCreateProxy(s.T(), s.db, &proxyModel{
Name: "p2",
Protocol: "http",
Host: "5.6.7.8",
@@ -273,9 +273,9 @@ func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists, "expected proxy to exist")
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a3", ProxyID: &p2.ID})
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
s.Require().NoError(err, "CountAccountsByProxyID")

View File

@@ -4,10 +4,8 @@ import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm"
)
@@ -20,48 +18,61 @@ func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository {
return &redeemCodeRepository{db: db}
}
func (r *redeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Create(code).Error
func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemCode) error {
m := redeemCodeModelFromService(code)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyRedeemCodeModelToService(code, m)
}
return err
}
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error {
return r.db.WithContext(ctx).Create(&codes).Error
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.RedeemCode) error {
if len(codes) == 0 {
return nil
}
models := make([]redeemCodeModel, 0, len(codes))
for i := range codes {
m := redeemCodeModelFromService(&codes[i])
if m != nil {
models = append(models, *m)
}
}
return r.db.WithContext(ctx).Create(&models).Error
}
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
var code model.RedeemCode
err := r.db.WithContext(ctx).First(&code, id).Error
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
var m redeemCodeModel
err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
}
return &code, nil
return redeemCodeModelToService(&m), nil
}
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
var redeemCode model.RedeemCode
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
var m redeemCodeModel
err := r.db.WithContext(ctx).Where("code = ?", code).First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
}
return &redeemCode, nil
return redeemCodeModelToService(&m), nil
}
func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
return r.db.WithContext(ctx).Delete(&redeemCodeModel{}, id).Error
}
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
var codes []model.RedeemCode
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
var codes []redeemCodeModel
var total int64
db := r.db.WithContext(ctx).Model(&model.RedeemCode{})
db := r.db.WithContext(ctx).Model(&redeemCodeModel{})
// Apply filters
if codeType != "" {
db = db.Where("type = ?", codeType)
}
@@ -81,29 +92,29 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outCodes := make([]service.RedeemCode, 0, len(codes))
for i := range codes {
outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
}
return codes, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outCodes, paginationResultFromTotal(total, params), nil
}
func (r *redeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Save(code).Error
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
m := redeemCodeModelFromService(code)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyRedeemCodeModelToService(code, m)
}
return err
}
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now()
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
Where("id = ? AND status = ?", id, model.StatusUnused).
result := r.db.WithContext(ctx).Model(&redeemCodeModel{}).
Where("id = ? AND status = ?", id, service.StatusUnused).
Updates(map[string]any{
"status": model.StatusUsed,
"status": service.StatusUsed,
"used_by": userID,
"used_at": now,
})
@@ -116,22 +127,93 @@ func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error
return nil
}
// ListByUser returns all redeem codes used by a specific user
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
var codes []model.RedeemCode
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
if limit <= 0 {
limit = 10
}
var codes []redeemCodeModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("used_by = ?", userID).
Order("used_at DESC").
Limit(limit).
Find(&codes).Error
if err != nil {
return nil, err
}
return codes, nil
outCodes := make([]service.RedeemCode, 0, len(codes))
for i := range codes {
outCodes = append(outCodes, *redeemCodeModelToService(&codes[i]))
}
return outCodes, nil
}
type redeemCodeModel struct {
ID int64 `gorm:"primaryKey"`
Code string `gorm:"uniqueIndex;size:32;not null"`
Type string `gorm:"size:20;default:balance;not null"`
Value float64 `gorm:"type:decimal(20,8);not null"`
Status string `gorm:"size:20;default:unused;not null"`
UsedBy *int64 `gorm:"index"`
UsedAt *time.Time
Notes string `gorm:"type:text"`
CreatedAt time.Time `gorm:"not null"`
GroupID *int64 `gorm:"index"`
ValidityDays int `gorm:"default:30"`
User *userModel `gorm:"foreignKey:UsedBy"`
Group *groupModel `gorm:"foreignKey:GroupID"`
}
func (redeemCodeModel) TableName() string { return "redeem_codes" }
func redeemCodeModelToService(m *redeemCodeModel) *service.RedeemCode {
if m == nil {
return nil
}
return &service.RedeemCode{
ID: m.ID,
Code: m.Code,
Type: m.Type,
Value: m.Value,
Status: m.Status,
UsedBy: m.UsedBy,
UsedAt: m.UsedAt,
Notes: m.Notes,
CreatedAt: m.CreatedAt,
GroupID: m.GroupID,
ValidityDays: m.ValidityDays,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
}
}
func redeemCodeModelFromService(r *service.RedeemCode) *redeemCodeModel {
if r == nil {
return nil
}
return &redeemCodeModel{
ID: r.ID,
Code: r.Code,
Type: r.Type,
Value: r.Value,
Status: r.Status,
UsedBy: r.UsedBy,
UsedAt: r.UsedAt,
Notes: r.Notes,
CreatedAt: r.CreatedAt,
GroupID: r.GroupID,
ValidityDays: r.ValidityDays,
}
}
func applyRedeemCodeModelToService(code *service.RedeemCode, m *redeemCodeModel) {
if code == nil || m == nil {
return
}
code.ID = m.ID
code.CreatedAt = m.CreatedAt
}

View File

@@ -7,7 +7,6 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
@@ -34,11 +33,11 @@ func TestRedeemCodeRepoSuite(t *testing.T) {
// --- Create / CreateBatch / GetByID / GetByCode ---
func (s *RedeemCodeRepoSuite) TestCreate() {
code := &model.RedeemCode{
code := &service.RedeemCode{
Code: "TEST-CREATE",
Type: model.RedeemTypeBalance,
Type: service.RedeemTypeBalance,
Value: 100,
Status: model.StatusUnused,
Status: service.StatusUnused,
}
err := s.repo.Create(s.ctx, code)
@@ -51,9 +50,9 @@ func (s *RedeemCodeRepoSuite) TestCreate() {
}
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
codes := []model.RedeemCode{
{Code: "BATCH-1", Type: model.RedeemTypeBalance, Value: 10, Status: model.StatusUnused},
{Code: "BATCH-2", Type: model.RedeemTypeBalance, Value: 20, Status: model.StatusUnused},
codes := []service.RedeemCode{
{Code: "BATCH-1", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused},
{Code: "BATCH-2", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused},
}
err := s.repo.CreateBatch(s.ctx, codes)
@@ -74,7 +73,7 @@ func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
}
func (s *RedeemCodeRepoSuite) TestGetByCode() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "GET-BY-CODE", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "GET-BY-CODE", Type: service.RedeemTypeBalance})
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
s.Require().NoError(err, "GetByCode")
@@ -89,7 +88,7 @@ func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
// --- Delete ---
func (s *RedeemCodeRepoSuite) TestDelete() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TO-DELETE", Type: model.RedeemTypeBalance})
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TO-DELETE", Type: service.RedeemTypeBalance})
err := s.repo.Delete(s.ctx, code.ID)
s.Require().NoError(err, "Delete")
@@ -101,8 +100,8 @@ func (s *RedeemCodeRepoSuite) TestDelete() {
// --- List / ListWithFilters ---
func (s *RedeemCodeRepoSuite) TestList() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-1", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-2", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-1", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "LIST-2", Type: service.RedeemTypeBalance})
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -111,28 +110,28 @@ func (s *RedeemCodeRepoSuite) TestList() {
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-BAL", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-SUB", Type: model.RedeemTypeSubscription})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-BAL", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, "", "")
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, "", "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Equal(model.RedeemTypeSubscription, codes[0].Type)
s.Require().Equal(service.RedeemTypeSubscription, codes[0].Type)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-UNUSED", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "STAT-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusUsed, "")
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUsed, "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Equal(model.StatusUsed, codes[0].Status)
s.Require().Equal(service.StatusUsed, codes[0].Status)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALPHA-CODE", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "BETA-CODE", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "BETA-CODE", Type: service.RedeemTypeBalance})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
s.Require().NoError(err)
@@ -141,10 +140,10 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "WITH-GROUP",
Type: model.RedeemTypeSubscription,
Type: service.RedeemTypeSubscription,
GroupID: &group.ID,
})
@@ -158,7 +157,7 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
// --- Update ---
func (s *RedeemCodeRepoSuite) TestUpdate() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "UPDATE-ME", Type: model.RedeemTypeBalance, Value: 10})
code := redeemCodeModelToService(mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "UPDATE-ME", Type: service.RedeemTypeBalance, Value: 10}))
code.Value = 50
err := s.repo.Update(s.ctx, code)
@@ -172,23 +171,23 @@ func (s *RedeemCodeRepoSuite) TestUpdate() {
// --- Use ---
func (s *RedeemCodeRepoSuite) TestUse() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "use@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "USE-ME", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "use@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "USE-ME", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err)
s.Require().Equal(model.StatusUsed, got.Status)
s.Require().Equal(service.StatusUsed, got.Status)
s.Require().NotNil(got.UsedBy)
s.Require().Equal(user.ID, *got.UsedBy)
s.Require().NotNil(got.UsedAt)
}
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "idem@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "IDEM-CODE", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "idem@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Status: service.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use first time")
@@ -200,8 +199,8 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
}
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "already@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALREADY-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "already@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Status: service.StatusUsed})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code")
@@ -211,22 +210,22 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
// --- ListByUser ---
func (s *RedeemCodeRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
// Create codes with explicit used_at for ordering
c1 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
c1 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "USER-1",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
Type: service.RedeemTypeBalance,
Status: service.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c1).Update("used_at", base)
c2 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
c2 := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "USER-2",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
Type: service.RedeemTypeBalance,
Status: service.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c2).Update("used_at", base.Add(1*time.Hour))
@@ -240,13 +239,13 @@ func (s *RedeemCodeRepoSuite) TestListByUser() {
}
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listby"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "grp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listby"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "WITH-GRP",
Type: model.RedeemTypeSubscription,
Status: model.StatusUsed,
Type: service.RedeemTypeSubscription,
Status: service.StatusUsed,
UsedBy: &user.ID,
GroupID: &group.ID,
})
@@ -260,11 +259,11 @@ func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
}
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deflimit@test.com"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "deflimit@test.com"})
c := mustCreateRedeemCode(s.T(), s.db, &redeemCodeModel{
Code: "DEF-LIM",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
Type: service.RedeemTypeBalance,
Status: service.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c).Update("used_at", time.Now())
@@ -278,16 +277,16 @@ func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
// --- Combined original test ---
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "rc@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-rc"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "rc@example.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-rc"})
codes := []model.RedeemCode{
{Code: "CODEA", Type: model.RedeemTypeBalance, Value: 1, Status: model.StatusUnused, CreatedAt: time.Now()},
{Code: "CODEB", Type: model.RedeemTypeSubscription, Value: 0, Status: model.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
codes := []service.RedeemCode{
{Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, CreatedAt: time.Now()},
{Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
}
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, model.StatusUnused, "code")
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, service.StatusUnused, "code")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(list, 1)
@@ -305,9 +304,9 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
s.Require().NoError(err, "GetByCode")
// Use fixed time instead of time.Sleep for deterministic ordering
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
s.db.Model(&redeemCodeModel{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
s.db.Model(&redeemCodeModel{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser")

View File

@@ -6,33 +6,27 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// SettingRepository 系统设置数据访问层
type settingRepository struct {
db *gorm.DB
}
// NewSettingRepository 创建系统设置仓库实例
func NewSettingRepository(db *gorm.DB) service.SettingRepository {
return &settingRepository{db: db}
}
// Get 根据Key获取设置值
func (r *settingRepository) Get(ctx context.Context, key string) (*model.Setting, error) {
var setting model.Setting
err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error
func (r *settingRepository) Get(ctx context.Context, key string) (*service.Setting, error) {
var m settingModel
err := r.db.WithContext(ctx).Where("key = ?", key).First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil)
}
return &setting, nil
return settingModelToService(&m), nil
}
// GetValue 获取设置值字符串
func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
setting, err := r.Get(ctx, key)
if err != nil {
@@ -41,9 +35,8 @@ func (r *settingRepository) GetValue(ctx context.Context, key string) (string, e
return setting.Value, nil
}
// Set 设置值(存在则更新,不存在则创建)
func (r *settingRepository) Set(ctx context.Context, key, value string) error {
setting := &model.Setting{
m := &settingModel{
Key: key,
Value: value,
UpdatedAt: time.Now(),
@@ -52,12 +45,11 @@ func (r *settingRepository) Set(ctx context.Context, key, value string) error {
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(setting).Error
}).Create(m).Error
}
// GetMultiple 批量获取设置
func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
var settings []model.Setting
var settings []settingModel
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
if err != nil {
return nil, err
@@ -70,11 +62,10 @@ func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map
return result, nil
}
// SetMultiple 批量设置值
func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for key, value := range settings {
setting := &model.Setting{
m := &settingModel{
Key: key,
Value: value,
UpdatedAt: time.Now(),
@@ -82,7 +73,7 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string
if err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value", "updated_at"}),
}).Create(setting).Error; err != nil {
}).Create(m).Error; err != nil {
return err
}
}
@@ -90,9 +81,8 @@ func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string
})
}
// GetAll 获取所有设置
func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
var settings []model.Setting
var settings []settingModel
err := r.db.WithContext(ctx).Find(&settings).Error
if err != nil {
return nil, err
@@ -105,7 +95,27 @@ func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, erro
return result, nil
}
// Delete 删除设置
func (r *settingRepository) Delete(ctx context.Context, key string) error {
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&settingModel{}).Error
}
type settingModel struct {
ID int64 `gorm:"primaryKey"`
Key string `gorm:"uniqueIndex;size:100;not null"`
Value string `gorm:"type:text;not null"`
UpdatedAt time.Time `gorm:"not null"`
}
func (settingModel) TableName() string { return "settings" }
func settingModelToService(m *settingModel) *service.Setting {
if m == nil {
return nil
}
return &service.Setting{
ID: m.ID,
Key: m.Key,
Value: m.Value,
UpdatedAt: m.UpdatedAt,
}
}

View File

@@ -6,7 +6,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
@@ -30,7 +29,7 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
TokenCount int64 `gorm:"column:token_count"`
}
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as request_count,
COALESCE(SUM(input_tokens + output_tokens), 0) as token_count
@@ -46,24 +45,29 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
return perfStats.RequestCount / 5, perfStats.TokenCount / 5
}
func (r *usageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
return r.db.WithContext(ctx).Create(log).Error
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error {
m := usageLogModelFromService(log)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUsageLogModelToService(log, m)
}
return err
}
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
var log model.UsageLog
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*service.UsageLog, error) {
var log usageLogModel
err := r.db.WithContext(ctx).First(&log, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil)
}
return &log, nil
return usageLogModelToService(&log), nil
}
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("user_id = ?", userID)
db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("user_id = ?", userID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -73,24 +77,14 @@ func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, param
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
}
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("api_key_id = ?", apiKeyID)
db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("api_key_id = ?", apiKeyID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -100,17 +94,7 @@ func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
}
// UserStats 用户使用统计
@@ -125,7 +109,7 @@ type UserStats struct {
func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
var stats UserStats
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
@@ -147,47 +131,47 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
today := timezone.Today()
// 总用户数
r.db.WithContext(ctx).Model(&model.User{}).Count(&stats.TotalUsers)
r.db.WithContext(ctx).Model(&userModel{}).Count(&stats.TotalUsers)
// 今日新增用户数
r.db.WithContext(ctx).Model(&model.User{}).
r.db.WithContext(ctx).Model(&userModel{}).
Where("created_at >= ?", today).
Count(&stats.TodayNewUsers)
// 今日活跃用户数 (今日有请求的用户)
r.db.WithContext(ctx).Model(&model.UsageLog{}).
r.db.WithContext(ctx).Model(&usageLogModel{}).
Distinct("user_id").
Where("created_at >= ?", today).
Count(&stats.ActiveUsers)
// 总 API Key 数
r.db.WithContext(ctx).Model(&model.ApiKey{}).Count(&stats.TotalApiKeys)
r.db.WithContext(ctx).Model(&apiKeyModel{}).Count(&stats.TotalApiKeys)
// 活跃 API Key 数
r.db.WithContext(ctx).Model(&model.ApiKey{}).
Where("status = ?", model.StatusActive).
r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("status = ?", service.StatusActive).
Count(&stats.ActiveApiKeys)
// 总账户数
r.db.WithContext(ctx).Model(&model.Account{}).Count(&stats.TotalAccounts)
r.db.WithContext(ctx).Model(&accountModel{}).Count(&stats.TotalAccounts)
// 正常账户数 (schedulable=true, status=active)
r.db.WithContext(ctx).Model(&model.Account{}).
Where("status = ? AND schedulable = ?", model.StatusActive, true).
r.db.WithContext(ctx).Model(&accountModel{}).
Where("status = ? AND schedulable = ?", service.StatusActive, true).
Count(&stats.NormalAccounts)
// 异常账户数 (status=error)
r.db.WithContext(ctx).Model(&model.Account{}).
Where("status = ?", model.StatusError).
r.db.WithContext(ctx).Model(&accountModel{}).
Where("status = ?", service.StatusError).
Count(&stats.ErrorAccounts)
// 限流账户数
r.db.WithContext(ctx).Model(&model.Account{}).
r.db.WithContext(ctx).Model(&accountModel{}).
Where("rate_limited_at IS NOT NULL AND rate_limit_reset_at > ?", time.Now()).
Count(&stats.RateLimitAccounts)
// 过载账户数
r.db.WithContext(ctx).Model(&model.Account{}).
r.db.WithContext(ctx).Model(&accountModel{}).
Where("overload_until IS NOT NULL AND overload_until > ?", time.Now()).
Count(&stats.OverloadAccounts)
@@ -202,7 +186,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
TotalActualCost float64 `gorm:"column:total_actual_cost"`
AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
}
r.db.WithContext(ctx).Model(&model.UsageLog{}).
r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
@@ -235,7 +219,7 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
TodayCost float64 `gorm:"column:today_cost"`
TodayActualCost float64 `gorm:"column:today_actual_cost"`
}
r.db.WithContext(ctx).Model(&model.UsageLog{}).
r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
@@ -263,11 +247,11 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
return &stats, nil
}
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).Where("account_id = ?", accountID)
db := r.db.WithContext(ctx).Model(&usageLogModel{}).Where("account_id = ?", accountID)
if err := db.Count(&total).Error; err != nil {
return nil, nil, err
@@ -277,57 +261,47 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
}
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
err := r.db.WithContext(ctx).
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
Order("id DESC").
Find(&logs).Error
return logs, nil, err
return usageLogModelsToService(logs), nil, err
}
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
err := r.db.WithContext(ctx).
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
Order("id DESC").
Find(&logs).Error
return logs, nil, err
return usageLogModelsToService(logs), nil, err
}
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
err := r.db.WithContext(ctx).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Order("id DESC").
Find(&logs).Error
return logs, nil, err
return usageLogModelsToService(logs), nil, err
}
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
err := r.db.WithContext(ctx).
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
Order("id DESC").
Find(&logs).Error
return logs, nil, err
return usageLogModelsToService(logs), nil, err
}
func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
return r.db.WithContext(ctx).Delete(&usageLogModel{}, id).Error
}
// GetAccountTodayStats 获取账号今日统计
@@ -340,7 +314,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
Cost float64 `gorm:"column:cost"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
@@ -368,7 +342,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
Cost float64 `gorm:"column:cost"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
@@ -499,12 +473,12 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
today := timezone.Today()
// API Key 统计
r.db.WithContext(ctx).Model(&model.ApiKey{}).
r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("user_id = ?", userID).
Count(&stats.TotalApiKeys)
r.db.WithContext(ctx).Model(&model.ApiKey{}).
Where("user_id = ? AND status = ?", userID, model.StatusActive).
r.db.WithContext(ctx).Model(&apiKeyModel{}).
Where("user_id = ? AND status = ?", userID, service.StatusActive).
Count(&stats.ActiveApiKeys)
// 累计 Token 统计
@@ -518,7 +492,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
TotalActualCost float64 `gorm:"column:total_actual_cost"`
AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
}
r.db.WithContext(ctx).Model(&model.UsageLog{}).
r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
@@ -552,7 +526,7 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
TodayCost float64 `gorm:"column:today_cost"`
TodayActualCost float64 `gorm:"column:today_actual_cost"`
}
r.db.WithContext(ctx).Model(&model.UsageLog{}).
r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
@@ -591,7 +565,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
dateFormat = "YYYY-MM-DD"
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
TO_CHAR(created_at, ?) as date,
COUNT(*) as requests,
@@ -618,7 +592,7 @@ func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
var results []ModelStat
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
model,
COUNT(*) as requests,
@@ -644,11 +618,11 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin)
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
var logs []usageLogModel
var total int64
db := r.db.WithContext(ctx).Model(&model.UsageLog{})
db := r.db.WithContext(ctx).Model(&usageLogModel{})
// Apply filters
if filters.UserID > 0 {
@@ -675,17 +649,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return usageLogModelsToService(logs), paginationResultFromTotal(total, params), nil
}
// UsageStats represents usage statistics
@@ -713,7 +677,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
UserID int64 `gorm:"column:user_id"`
TotalCost float64 `gorm:"column:total_cost"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("user_id, COALESCE(SUM(actual_cost), 0) as total_cost").
Where("user_id IN ?", userIDs).
Group("user_id").
@@ -733,7 +697,7 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
UserID int64 `gorm:"column:user_id"`
TodayCost float64 `gorm:"column:today_cost"`
}
err = r.db.WithContext(ctx).Model(&model.UsageLog{}).
err = r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("user_id, COALESCE(SUM(actual_cost), 0) as today_cost").
Where("user_id IN ? AND created_at >= ?", userIDs, today).
Group("user_id").
@@ -773,7 +737,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
ApiKeyID int64 `gorm:"column:api_key_id"`
TotalCost float64 `gorm:"column:total_cost"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost").
Where("api_key_id IN ?", apiKeyIDs).
Group("api_key_id").
@@ -793,7 +757,7 @@ func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
ApiKeyID int64 `gorm:"column:api_key_id"`
TodayCost float64 `gorm:"column:today_cost"`
}
err = r.db.WithContext(ctx).Model(&model.UsageLog{}).
err = r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost").
Where("api_key_id IN ? AND created_at >= ?", apiKeyIDs, today).
Group("api_key_id").
@@ -822,7 +786,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
dateFormat = "YYYY-MM-DD"
}
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
TO_CHAR(created_at, ?) as date,
COUNT(*) as requests,
@@ -854,7 +818,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
var results []ModelStat
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
db := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
model,
COUNT(*) as requests,
@@ -896,7 +860,7 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
AverageDurationMs float64 `gorm:"column:avg_duration_ms"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
@@ -950,7 +914,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
ActualCost float64 `gorm:"column:actual_cost"`
}
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
err := r.db.WithContext(ctx).Model(&usageLogModel{}).
Select(`
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
COUNT(*) as requests,
@@ -1011,7 +975,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var avgDuration struct {
AvgDurationMs float64 `gorm:"column:avg_duration_ms"`
}
r.db.WithContext(ctx).Model(&model.UsageLog{}).
r.db.WithContext(ctx).Model(&usageLogModel{}).
Select("COALESCE(AVG(duration_ms), 0) as avg_duration_ms").
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
Scan(&avgDuration)
@@ -1090,3 +1054,137 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Models: models,
}, nil
}
type usageLogModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
ApiKeyID int64 `gorm:"index;not null"`
AccountID int64 `gorm:"index;not null"`
RequestID string `gorm:"size:64"`
Model string `gorm:"size:100;index;not null"`
GroupID *int64 `gorm:"index"`
SubscriptionID *int64 `gorm:"index"`
InputTokens int `gorm:"default:0;not null"`
OutputTokens int `gorm:"default:0;not null"`
CacheCreationTokens int `gorm:"default:0;not null"`
CacheReadTokens int `gorm:"default:0;not null"`
CacheCreation5mTokens int `gorm:"default:0;not null"`
CacheCreation1hTokens int `gorm:"default:0;not null"`
InputCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null"`
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null"`
BillingType int8 `gorm:"type:smallint;default:0;not null"`
Stream bool `gorm:"default:false;not null"`
DurationMs *int
FirstTokenMs *int
CreatedAt time.Time `gorm:"index;not null"`
User *userModel `gorm:"foreignKey:UserID"`
ApiKey *apiKeyModel `gorm:"foreignKey:ApiKeyID"`
Account *accountModel `gorm:"foreignKey:AccountID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
Subscription *userSubscriptionModel `gorm:"foreignKey:SubscriptionID"`
}
func (usageLogModel) TableName() string { return "usage_logs" }
func usageLogModelToService(m *usageLogModel) *service.UsageLog {
if m == nil {
return nil
}
return &service.UsageLog{
ID: m.ID,
UserID: m.UserID,
ApiKeyID: m.ApiKeyID,
AccountID: m.AccountID,
RequestID: m.RequestID,
Model: m.Model,
GroupID: m.GroupID,
SubscriptionID: m.SubscriptionID,
InputTokens: m.InputTokens,
OutputTokens: m.OutputTokens,
CacheCreationTokens: m.CacheCreationTokens,
CacheReadTokens: m.CacheReadTokens,
CacheCreation5mTokens: m.CacheCreation5mTokens,
CacheCreation1hTokens: m.CacheCreation1hTokens,
InputCost: m.InputCost,
OutputCost: m.OutputCost,
CacheCreationCost: m.CacheCreationCost,
CacheReadCost: m.CacheReadCost,
TotalCost: m.TotalCost,
ActualCost: m.ActualCost,
RateMultiplier: m.RateMultiplier,
BillingType: m.BillingType,
Stream: m.Stream,
DurationMs: m.DurationMs,
FirstTokenMs: m.FirstTokenMs,
CreatedAt: m.CreatedAt,
User: userModelToService(m.User),
ApiKey: apiKeyModelToService(m.ApiKey),
Account: accountModelToService(m.Account),
Group: groupModelToService(m.Group),
Subscription: userSubscriptionModelToService(m.Subscription),
}
}
func usageLogModelsToService(models []usageLogModel) []service.UsageLog {
out := make([]service.UsageLog, 0, len(models))
for i := range models {
if s := usageLogModelToService(&models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func usageLogModelFromService(log *service.UsageLog) *usageLogModel {
if log == nil {
return nil
}
return &usageLogModel{
ID: log.ID,
UserID: log.UserID,
ApiKeyID: log.ApiKeyID,
AccountID: log.AccountID,
RequestID: log.RequestID,
Model: log.Model,
GroupID: log.GroupID,
SubscriptionID: log.SubscriptionID,
InputTokens: log.InputTokens,
OutputTokens: log.OutputTokens,
CacheCreationTokens: log.CacheCreationTokens,
CacheReadTokens: log.CacheReadTokens,
CacheCreation5mTokens: log.CacheCreation5mTokens,
CacheCreation1hTokens: log.CacheCreation1hTokens,
InputCost: log.InputCost,
OutputCost: log.OutputCost,
CacheCreationCost: log.CacheCreationCost,
CacheReadCost: log.CacheReadCost,
TotalCost: log.TotalCost,
ActualCost: log.ActualCost,
RateMultiplier: log.RateMultiplier,
BillingType: log.BillingType,
Stream: log.Stream,
DurationMs: log.DurationMs,
FirstTokenMs: log.FirstTokenMs,
CreatedAt: log.CreatedAt,
}
}
func applyUsageLogModelToService(log *service.UsageLog, m *usageLogModel) {
if log == nil || m == nil {
return
}
log.ID = m.ID
log.CreatedAt = m.CreatedAt
}

View File

@@ -7,10 +7,10 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
@@ -32,8 +32,8 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKey, account *model.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *model.UsageLog {
log := &model.UsageLog{
func (s *UsageLogRepoSuite) createUsageLog(user *userModel, apiKey *apiKeyModel, account *accountModel, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -51,11 +51,11 @@ func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKe
// --- Create / GetByID ---
func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-create"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-create"})
log := &model.UsageLog{
log := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -72,9 +72,9 @@ func (s *UsageLogRepoSuite) TestCreate() {
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-getbyid"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -92,9 +92,9 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-delete"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -108,9 +108,9 @@ func (s *UsageLogRepoSuite) TestDelete() {
// --- ListByUser ---
func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyuser"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
@@ -124,9 +124,9 @@ func (s *UsageLogRepoSuite) TestListByUser() {
// --- ListByApiKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyapikey"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
@@ -140,9 +140,9 @@ func (s *UsageLogRepoSuite) TestListByApiKey() {
// --- ListByAccount ---
func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyaccount"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -155,9 +155,9 @@ func (s *UsageLogRepoSuite) TestListByAccount() {
// --- GetUserStats ---
func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userstats"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -175,9 +175,9 @@ func (s *UsageLogRepoSuite) TestGetUserStats() {
// --- ListWithFilters ---
func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filters"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -194,29 +194,29 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
now := time.Now()
todayStart := timezone.Today()
userToday := mustCreateUser(s.T(), s.db, &model.User{
userToday := mustCreateUser(s.T(), s.db, &userModel{
Email: "today@example.com",
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
UpdatedAt: now,
})
userOld := mustCreateUser(s.T(), s.db, &model.User{
userOld := mustCreateUser(s.T(), s.db, &userModel{
Email: "old@example.com",
CreatedAt: todayStart.Add(-24 * time.Hour),
UpdatedAt: todayStart.Add(-24 * time.Hour),
})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: model.StatusDisabled})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: service.StatusDisabled})
resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-normal", Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-error", Status: model.StatusError, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
accNormal := mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-normal", Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-error", Status: service.StatusError, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
mustCreateAccount(s.T(), s.db, &accountModel{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
d1, d2, d3 := 100, 200, 300
logToday := &model.UsageLog{
logToday := &service.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
@@ -233,7 +233,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
}
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
logOld := &model.UsageLog{
logOld := &service.UsageLog{
UserID: userOld.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
@@ -247,7 +247,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
}
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
logPerf := &model.UsageLog{
logPerf := &service.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
@@ -293,9 +293,9 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
// --- GetUserDashboardStats ---
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userdash"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -308,9 +308,9 @@ func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
// --- GetAccountTodayStats ---
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-today"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -323,11 +323,11 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
// --- GetBatchUserUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batch"})
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
@@ -348,10 +348,10 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
// --- GetBatchApiKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batchkey"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
@@ -370,9 +370,9 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
// --- GetGlobalStats ---
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-global"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -395,9 +395,9 @@ func maxTime(a, b time.Time) time.Time {
// --- ListByUserAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-timerange"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -414,9 +414,9 @@ func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
// --- ListByApiKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytimerange"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -433,9 +433,9 @@ func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
// --- ListByAccountAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-acctimerange"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -452,14 +452,14 @@ func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
// --- ListByModelAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modeltimerange"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models
log1 := &model.UsageLog{
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -472,7 +472,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -485,7 +485,7 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
log3 := &model.UsageLog{
log3 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -508,9 +508,9 @@ func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
// --- GetAccountWindowStats ---
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-windowstats"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-windowstats"})
now := time.Now()
windowStart := now.Add(-10 * time.Minute)
@@ -528,9 +528,9 @@ func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
// --- GetUserUsageTrendByUserID ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrend"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -545,9 +545,9 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
}
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrendhourly"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -564,14 +564,14 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
// --- GetUserModelStats ---
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelstats"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models
log1 := &model.UsageLog{
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -584,7 +584,7 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -611,9 +611,9 @@ func (s *UsageLogRepoSuite) TestGetUserModelStats() {
// --- GetUsageTrendWithFilters ---
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -639,9 +639,9 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
}
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters-h"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -658,13 +658,13 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
// --- GetModelStatsWithFilters ---
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelfilters"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &model.UsageLog{
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -677,7 +677,7 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -712,14 +712,14 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
// --- GetAccountUsageStats ---
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-accstats"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
// Create logs on different days
log1 := &model.UsageLog{
log1 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -732,7 +732,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
log2 := &service.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -758,7 +758,7 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
}
func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-emptystats"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-emptystats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
startTime := base
@@ -774,11 +774,11 @@ func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
// --- GetUserUsageTrend ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrends"})
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base)
@@ -796,10 +796,10 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
// --- GetApiKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrends"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
@@ -815,9 +815,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
}
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrendh"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
@@ -834,9 +834,9 @@ func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
// --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterskey"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
@@ -848,9 +848,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
}
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterstime"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
@@ -867,9 +867,9 @@ func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
}
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterscombined"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &apiKeyModel{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &accountModel{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)

View File

@@ -2,12 +2,13 @@ package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/lib/pq"
"gorm.io/gorm"
)
@@ -19,48 +20,56 @@ func NewUserRepository(db *gorm.DB) service.UserRepository {
return &userRepository{db: db}
}
func (r *userRepository) Create(ctx context.Context, user *model.User) error {
err := r.db.WithContext(ctx).Create(user).Error
func (r *userRepository) Create(ctx context.Context, user *service.User) error {
m := userModelFromService(user)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUserModelToService(user, m)
}
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).First(&user, id).Error
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
var m userModel
err := r.db.WithContext(ctx).First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
return userModelToService(&m), nil
}
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
var m userModel
err := r.db.WithContext(ctx).Where("email = ?", email).First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
return userModelToService(&m), nil
}
func (r *userRepository) Update(ctx context.Context, user *model.User) error {
err := r.db.WithContext(ctx).Save(user).Error
func (r *userRepository) Update(ctx context.Context, user *service.User) error {
m := userModelFromService(user)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyUserModelToService(user, m)
}
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
return r.db.WithContext(ctx).Delete(&userModel{}, id).Error
}
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists users with optional filtering by status, role, and search query
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
var users []model.User
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
var users []userModel
var total int64
db := r.db.WithContext(ctx).Model(&model.User{})
db := r.db.WithContext(ctx).Model(&userModel{})
// Apply filters
if status != "" {
@@ -89,17 +98,20 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// Batch load subscriptions for all users (avoid N+1)
if len(users) > 0 {
userIDs := make([]int64, len(users))
userMap := make(map[int64]*model.User, len(users))
userMap := make(map[int64]*service.User, len(users))
outUsers := make([]service.User, 0, len(users))
for i := range users {
userIDs[i] = users[i].ID
userMap[users[i].ID] = &users[i]
u := userModelToService(&users[i])
outUsers = append(outUsers, *u)
userMap[u.ID] = &outUsers[len(outUsers)-1]
}
// Query active subscriptions with groups in one query
var subscriptions []model.UserSubscription
var subscriptions []userSubscriptionModel
if err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id IN ? AND status = ?", userIDs, model.SubscriptionStatusActive).
Where("user_id IN ? AND status = ?", userIDs, service.SubscriptionStatusActive).
Find(&subscriptions).Error; err != nil {
return nil, nil, err
}
@@ -107,32 +119,29 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// Associate subscriptions with users
for i := range subscriptions {
if user, ok := userMap[subscriptions[i].UserID]; ok {
user.Subscriptions = append(user.Subscriptions, subscriptions[i])
user.Subscriptions = append(user.Subscriptions, *userSubscriptionModelToService(&subscriptions[i]))
}
}
return outUsers, paginationResultFromTotal(total, params), nil
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
outUsers := make([]service.User, 0, len(users))
for i := range users {
outUsers = append(outUsers, *userModelToService(&users[i]))
}
return users, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return outUsers, paginationResultFromTotal(total, params), nil
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
Update("balance", gorm.Expr("balance + ?", amount)).Error
}
// DeductBalance 扣减用户余额,仅当余额充足时执行
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
result := r.db.WithContext(ctx).Model(&model.User{}).
result := r.db.WithContext(ctx).Model(&userModel{}).
Where("id = ? AND balance >= ?", id, amount).
Update("balance", gorm.Expr("balance - ?", amount))
if result.Error != nil {
@@ -145,34 +154,104 @@ func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount flo
}
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
return r.db.WithContext(ctx).Model(&userModel{}).Where("id = ?", id).
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error
err := r.db.WithContext(ctx).Model(&userModel{}).Where("email = ?", email).Count(&count).Error
return count > 0, err
}
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.User{}).
result := r.db.WithContext(ctx).Model(&userModel{}).
Where("? = ANY(allowed_groups)", groupID).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
return result.RowsAffected, result.Error
}
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
var user model.User
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
var m userModel
err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
Where("role = ? AND status = ?", service.RoleAdmin, service.StatusActive).
Order("id ASC").
First(&user).Error
First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
return userModelToService(&m), nil
}
type userModel struct {
ID int64 `gorm:"primaryKey"`
Email string `gorm:"uniqueIndex;size:255;not null"`
Username string `gorm:"size:100;default:''"`
Wechat string `gorm:"size:100;default:''"`
Notes string `gorm:"type:text;default:''"`
PasswordHash string `gorm:"size:255;not null"`
Role string `gorm:"size:20;default:user;not null"`
Balance float64 `gorm:"type:decimal(20,8);default:0;not null"`
Concurrency int `gorm:"default:5;not null"`
Status string `gorm:"size:20;default:active;not null"`
AllowedGroups pq.Int64Array `gorm:"type:bigint[]"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (userModel) TableName() string { return "users" }
func userModelToService(m *userModel) *service.User {
if m == nil {
return nil
}
return &service.User{
ID: m.ID,
Email: m.Email,
Username: m.Username,
Wechat: m.Wechat,
Notes: m.Notes,
PasswordHash: m.PasswordHash,
Role: m.Role,
Balance: m.Balance,
Concurrency: m.Concurrency,
Status: m.Status,
AllowedGroups: []int64(m.AllowedGroups),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func userModelFromService(u *service.User) *userModel {
if u == nil {
return nil
}
return &userModel{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
AllowedGroups: pq.Int64Array(u.AllowedGroups),
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func applyUserModelToService(dst *service.User, src *userModel) {
if dst == nil || src == nil {
return
}
dst.ID = src.ID
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}

View File

@@ -7,7 +7,6 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
@@ -35,11 +34,12 @@ func TestUserRepoSuite(t *testing.T) {
// --- Create / GetByID / GetByEmail / Update / Delete ---
func (s *UserRepoSuite) TestCreate() {
user := &model.User{
Email: "create@test.com",
Username: "testuser",
Role: model.RoleUser,
Status: model.StatusActive,
user := &service.User{
Email: "create@test.com",
Username: "testuser",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
}
err := s.repo.Create(s.ctx, user)
@@ -57,7 +57,7 @@ func (s *UserRepoSuite) TestGetByID_NotFound() {
}
func (s *UserRepoSuite) TestGetByEmail() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byemail@test.com"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "byemail@test.com"})
got, err := s.repo.GetByEmail(s.ctx, user.Email)
s.Require().NoError(err, "GetByEmail")
@@ -70,7 +70,7 @@ func (s *UserRepoSuite) TestGetByEmail_NotFound() {
}
func (s *UserRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com", Username: "original"})
user := userModelToService(mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com", Username: "original"}))
user.Username = "updated"
err := s.repo.Update(s.ctx, user)
@@ -82,7 +82,7 @@ func (s *UserRepoSuite) TestUpdate() {
}
func (s *UserRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
err := s.repo.Delete(s.ctx, user.ID)
s.Require().NoError(err, "Delete")
@@ -94,8 +94,8 @@ func (s *UserRepoSuite) TestDelete() {
// --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() {
mustCreateUser(s.T(), s.db, &model.User{Email: "list1@test.com"})
mustCreateUser(s.T(), s.db, &model.User{Email: "list2@test.com"})
mustCreateUser(s.T(), s.db, &userModel{Email: "list1@test.com"})
mustCreateUser(s.T(), s.db, &userModel{Email: "list2@test.com"})
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
@@ -104,28 +104,28 @@ func (s *UserRepoSuite) TestList() {
}
func (s *UserRepoSuite) TestListWithFilters_Status() {
mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com", Status: model.StatusActive})
mustCreateUser(s.T(), s.db, &model.User{Email: "disabled@test.com", Status: model.StatusDisabled})
mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com", Status: service.StatusActive})
mustCreateUser(s.T(), s.db, &userModel{Email: "disabled@test.com", Status: service.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, "", "")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(model.StatusActive, users[0].Status)
s.Require().Equal(service.StatusActive, users[0].Status)
}
func (s *UserRepoSuite) TestListWithFilters_Role() {
mustCreateUser(s.T(), s.db, &model.User{Email: "user@test.com", Role: model.RoleUser})
mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
mustCreateUser(s.T(), s.db, &userModel{Email: "user@test.com", Role: service.RoleUser})
mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.RoleAdmin, "")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(model.RoleAdmin, users[0].Role)
s.Require().Equal(service.RoleAdmin, users[0].Role)
}
func (s *UserRepoSuite) TestListWithFilters_Search() {
mustCreateUser(s.T(), s.db, &model.User{Email: "alice@test.com", Username: "Alice"})
mustCreateUser(s.T(), s.db, &model.User{Email: "bob@test.com", Username: "Bob"})
mustCreateUser(s.T(), s.db, &userModel{Email: "alice@test.com", Username: "Alice"})
mustCreateUser(s.T(), s.db, &userModel{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
s.Require().NoError(err)
@@ -134,8 +134,8 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
}
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com", Username: "JohnDoe"})
mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com", Username: "JaneSmith"})
mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com", Username: "JohnDoe"})
mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
s.Require().NoError(err)
@@ -144,8 +144,8 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
}
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
mustCreateUser(s.T(), s.db, &model.User{Email: "w1@test.com", Wechat: "wx_hello"})
mustCreateUser(s.T(), s.db, &model.User{Email: "w2@test.com", Wechat: "wx_world"})
mustCreateUser(s.T(), s.db, &userModel{Email: "w1@test.com", Wechat: "wx_hello"})
mustCreateUser(s.T(), s.db, &userModel{Email: "w2@test.com", Wechat: "wx_world"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
s.Require().NoError(err)
@@ -154,19 +154,19 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
}
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub@test.com", Status: model.StatusActive})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sub"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub@test.com", Status: service.StatusActive})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-sub"})
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
_ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(1 * time.Hour),
})
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
_ = mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-1 * time.Hour),
})
@@ -179,29 +179,29 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
}
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: model.RoleUser,
Status: model.StatusActive,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
})
target := mustCreateUser(s.T(), s.db, &model.User{
target := mustCreateUser(s.T(), s.db, &userModel{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: model.RoleAdmin,
Status: model.StatusActive,
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 1,
})
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "c@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
Role: service.RoleAdmin,
Status: service.StatusDisabled,
})
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, model.RoleAdmin, "b@")
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, service.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
@@ -211,7 +211,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
// --- Balance operations ---
func (s *UserRepoSuite) TestUpdateBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "bal@test.com", Balance: 10})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "bal@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
s.Require().NoError(err, "UpdateBalance")
@@ -222,7 +222,7 @@ func (s *UserRepoSuite) TestUpdateBalance() {
}
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "balneg@test.com", Balance: 10})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "balneg@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
s.Require().NoError(err, "UpdateBalance with negative")
@@ -233,7 +233,7 @@ func (s *UserRepoSuite) TestUpdateBalance_Negative() {
}
func (s *UserRepoSuite) TestDeductBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deduct@test.com", Balance: 10})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "deduct@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
s.Require().NoError(err, "DeductBalance")
@@ -244,7 +244,7 @@ func (s *UserRepoSuite) TestDeductBalance() {
}
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "insuf@test.com", Balance: 5})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "insuf@test.com", Balance: 5})
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance")
@@ -252,7 +252,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
}
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exact@test.com", Balance: 10})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exact@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
s.Require().NoError(err, "DeductBalance exact amount")
@@ -265,7 +265,7 @@ func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
// --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "conc@test.com", Concurrency: 5})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "conc@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
s.Require().NoError(err, "UpdateConcurrency")
@@ -276,7 +276,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency() {
}
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "concneg@test.com", Concurrency: 5})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "concneg@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
s.Require().NoError(err, "UpdateConcurrency negative")
@@ -289,7 +289,7 @@ func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
// --- ExistsByEmail ---
func (s *UserRepoSuite) TestExistsByEmail() {
mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
s.Require().NoError(err, "ExistsByEmail")
@@ -304,11 +304,11 @@ func (s *UserRepoSuite) TestExistsByEmail() {
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
groupID := int64(42)
userA := mustCreateUser(s.T(), s.db, &model.User{
userA := mustCreateUser(s.T(), s.db, &userModel{
Email: "a1@example.com",
AllowedGroups: pq.Int64Array{groupID, 7},
})
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "a2@example.com",
AllowedGroups: pq.Int64Array{7},
})
@@ -325,7 +325,7 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
}
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "nomatch@test.com",
AllowedGroups: pq.Int64Array{1, 2, 3},
})
@@ -338,15 +338,15 @@ func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
// --- GetFirstAdmin ---
func (s *UserRepoSuite) TestGetFirstAdmin() {
admin1 := mustCreateUser(s.T(), s.db, &model.User{
admin1 := mustCreateUser(s.T(), s.db, &userModel{
Email: "admin1@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
Role: service.RoleAdmin,
Status: service.StatusActive,
})
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "admin2@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
Role: service.RoleAdmin,
Status: service.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
@@ -355,10 +355,10 @@ func (s *UserRepoSuite) TestGetFirstAdmin() {
}
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "user@example.com",
Role: model.RoleUser,
Status: model.StatusActive,
Role: service.RoleUser,
Status: service.StatusActive,
})
_, err := s.repo.GetFirstAdmin(s.ctx)
@@ -366,15 +366,15 @@ func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
}
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
mustCreateUser(s.T(), s.db, &model.User{
mustCreateUser(s.T(), s.db, &userModel{
Email: "disabled@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
Role: service.RoleAdmin,
Status: service.StatusDisabled,
})
activeAdmin := mustCreateUser(s.T(), s.db, &model.User{
activeAdmin := mustCreateUser(s.T(), s.db, &userModel{
Email: "active@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
Role: service.RoleAdmin,
Status: service.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
@@ -385,26 +385,26 @@ func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
// --- Combined original test ---
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := mustCreateUser(s.T(), s.db, &model.User{
user1 := mustCreateUser(s.T(), s.db, &userModel{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: model.RoleUser,
Status: model.StatusActive,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
})
user2 := mustCreateUser(s.T(), s.db, &model.User{
user2 := mustCreateUser(s.T(), s.db, &userModel{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: model.RoleAdmin,
Status: model.StatusActive,
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 1,
})
_ = mustCreateUser(s.T(), s.db, &model.User{
_ = mustCreateUser(s.T(), s.db, &userModel{
Email: "c@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
Role: service.RoleAdmin,
Status: service.StatusDisabled,
})
got, err := s.repo.GetByID(s.ctx, user1.ID)
@@ -441,7 +441,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch")
params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, model.StatusActive, model.RoleAdmin, "b@")
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")

View File

@@ -4,111 +4,113 @@ import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm"
)
// UserSubscriptionRepository 用户订阅仓库
type userSubscriptionRepository struct {
db *gorm.DB
}
// NewUserSubscriptionRepository 创建用户订阅仓库
func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
return &userSubscriptionRepository{db: db}
}
// Create 创建订阅
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
err := r.db.WithContext(ctx).Create(sub).Error
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *service.UserSubscription) error {
m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Create(m).Error
if err == nil {
applyUserSubscriptionModelToService(sub, m)
}
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
}
// GetByID 根据ID获取订阅
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
var m userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("User").
Preload("Group").
Preload("AssignedByUser").
First(&sub, id).Error
First(&m, id).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return &sub, nil
return userSubscriptionModelToService(&m), nil
}
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
var m userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND group_id = ?", userID, groupID).
First(&sub).Error
First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return &sub, nil
return userSubscriptionModelToService(&m), nil
}
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
var m userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND group_id = ? AND status = ? AND expires_at > ?",
userID, groupID, model.SubscriptionStatusActive, time.Now()).
First(&sub).Error
userID, groupID, service.SubscriptionStatusActive, time.Now()).
First(&m).Error
if err != nil {
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return &sub, nil
return userSubscriptionModelToService(&m), nil
}
// Update 更新订阅
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *service.UserSubscription) error {
sub.UpdatedAt = time.Now()
return r.db.WithContext(ctx).Save(sub).Error
m := userSubscriptionModelFromService(sub)
err := r.db.WithContext(ctx).Save(m).Error
if err == nil {
applyUserSubscriptionModelToService(sub, m)
}
return err
}
// Delete 删除订阅
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error
return r.db.WithContext(ctx).Delete(&userSubscriptionModel{}, id).Error
}
// ListByUserID 获取用户的所有订阅
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
var subs []userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ?", userID).
Order("created_at DESC").
Find(&subs).Error
return subs, err
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
}
// ListActiveByUserID 获取用户的所有有效订阅
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
var subs []userSubscriptionModel
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND status = ? AND expires_at > ?",
userID, model.SubscriptionStatusActive, time.Now()).
userID, service.SubscriptionStatusActive, time.Now()).
Order("created_at DESC").
Find(&subs).Error
return subs, err
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
}
// ListByGroupID 获取分组的所有订阅(分页)
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
var subs []userSubscriptionModel
var total int64
query := r.db.WithContext(ctx).Model(&model.UserSubscription{}).Where("group_id = ?", groupID)
query := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).Where("group_id = ?", groupID)
if err := query.Count(&total).Error; err != nil {
return nil, nil, err
}
@@ -124,26 +126,14 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return subs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
}
// List 获取所有订阅(分页,支持筛选)
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
var subs []userSubscriptionModel
var total int64
query := r.db.WithContext(ctx).Model(&model.UserSubscription{})
query := r.db.WithContext(ctx).Model(&userSubscriptionModel{})
if userID != nil {
query = query.Where("user_id = ?", *userID)
}
@@ -170,22 +160,87 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return nil, nil, err
}
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return subs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}, nil
return userSubscriptionModelsToService(subs), paginationResultFromTotal(total, params), nil
}
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("user_id = ? AND group_id = ?", userID, groupID).
Count(&count).Error
return count > 0, err
}
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"expires_at": newExpiresAt,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"status": status,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", subscriptionID).
Updates(map[string]any{
"notes": notes,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_window_start": start,
"weekly_window_start": start,
"monthly_window_start": start,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_usage_usd": 0,
"daily_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"weekly_usage_usd": 0,
"weekly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"monthly_usage_usd": 0,
"monthly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// IncrementUsage 增加使用量
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
return r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_usage_usd": gorm.Expr("daily_usage_usd + ?", costUSD),
@@ -195,131 +250,150 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
}).Error
}
// ResetDailyUsage 重置日使用量
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_usage_usd": 0,
"daily_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// ResetWeeklyUsage 重置周使用量
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"weekly_usage_usd": 0,
"weekly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// ResetMonthlyUsage 重置月使用量
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"monthly_usage_usd": 0,
"monthly_window_start": newWindowStart,
"updated_at": time.Now(),
}).Error
}
// ActivateWindows 激活所有窗口(首次使用时)
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"daily_window_start": activateTime,
"weekly_window_start": activateTime,
"monthly_window_start": activateTime,
"updated_at": time.Now(),
}).Error
}
// UpdateStatus 更新订阅状态
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"status": status,
"updated_at": time.Now(),
}).Error
}
// ExtendExpiry 延长订阅过期时间
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"expires_at": newExpiresAt,
"updated_at": time.Now(),
}).Error
}
// UpdateNotes 更新订阅备注
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
"notes": notes,
"updated_at": time.Now(),
}).Error
}
// ListExpired 获取所有已过期但状态仍为active的订阅
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Find(&subs).Error
return subs, err
}
// BatchUpdateExpiredStatus 批量更新过期订阅状态
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
result := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
Updates(map[string]any{
"status": model.SubscriptionStatusExpired,
"status": service.SubscriptionStatusExpired,
"updated_at": time.Now(),
})
return result.RowsAffected, result.Error
}
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("user_id = ? AND group_id = ?", userID, groupID).
Count(&count).Error
return count > 0, err
// Extra repository helpers (currently used only by integration tests).
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]service.UserSubscription, error) {
var subs []userSubscriptionModel
err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", service.SubscriptionStatusActive, time.Now()).
Find(&subs).Error
if err != nil {
return nil, err
}
return userSubscriptionModelsToService(subs), nil
}
// CountByGroupID 获取分组的订阅数量
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ?", groupID).
Count(&count).Error
return count, err
}
// CountActiveByGroupID 获取分组的有效订阅数量
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
err := r.db.WithContext(ctx).Model(&userSubscriptionModel{}).
Where("group_id = ? AND status = ? AND expires_at > ?",
groupID, model.SubscriptionStatusActive, time.Now()).
groupID, service.SubscriptionStatusActive, time.Now()).
Count(&count).Error
return count, err
}
// DeleteByGroupID 删除分组相关的所有订阅记录
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{})
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&userSubscriptionModel{})
return result.RowsAffected, result.Error
}
type userSubscriptionModel struct {
ID int64 `gorm:"primaryKey"`
UserID int64 `gorm:"index;not null"`
GroupID int64 `gorm:"index;not null"`
StartsAt time.Time `gorm:"not null"`
ExpiresAt time.Time `gorm:"not null"`
Status string `gorm:"size:20;default:active;not null"`
DailyWindowStart *time.Time
WeeklyWindowStart *time.Time
MonthlyWindowStart *time.Time
DailyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
WeeklyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
MonthlyUsageUSD float64 `gorm:"type:decimal(20,10);default:0;not null"`
AssignedBy *int64 `gorm:"index"`
AssignedAt time.Time `gorm:"not null"`
Notes string `gorm:"type:text"`
CreatedAt time.Time `gorm:"not null"`
UpdatedAt time.Time `gorm:"not null"`
User *userModel `gorm:"foreignKey:UserID"`
Group *groupModel `gorm:"foreignKey:GroupID"`
AssignedByUser *userModel `gorm:"foreignKey:AssignedBy"`
}
func (userSubscriptionModel) TableName() string { return "user_subscriptions" }
func userSubscriptionModelToService(m *userSubscriptionModel) *service.UserSubscription {
if m == nil {
return nil
}
return &service.UserSubscription{
ID: m.ID,
UserID: m.UserID,
GroupID: m.GroupID,
StartsAt: m.StartsAt,
ExpiresAt: m.ExpiresAt,
Status: m.Status,
DailyWindowStart: m.DailyWindowStart,
WeeklyWindowStart: m.WeeklyWindowStart,
MonthlyWindowStart: m.MonthlyWindowStart,
DailyUsageUSD: m.DailyUsageUSD,
WeeklyUsageUSD: m.WeeklyUsageUSD,
MonthlyUsageUSD: m.MonthlyUsageUSD,
AssignedBy: m.AssignedBy,
AssignedAt: m.AssignedAt,
Notes: m.Notes,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
User: userModelToService(m.User),
Group: groupModelToService(m.Group),
AssignedByUser: userModelToService(m.AssignedByUser),
}
}
func userSubscriptionModelsToService(models []userSubscriptionModel) []service.UserSubscription {
out := make([]service.UserSubscription, 0, len(models))
for i := range models {
if s := userSubscriptionModelToService(&models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func userSubscriptionModelFromService(s *service.UserSubscription) *userSubscriptionModel {
if s == nil {
return nil
}
return &userSubscriptionModel{
ID: s.ID,
UserID: s.UserID,
GroupID: s.GroupID,
StartsAt: s.StartsAt,
ExpiresAt: s.ExpiresAt,
Status: s.Status,
DailyWindowStart: s.DailyWindowStart,
WeeklyWindowStart: s.WeeklyWindowStart,
MonthlyWindowStart: s.MonthlyWindowStart,
DailyUsageUSD: s.DailyUsageUSD,
WeeklyUsageUSD: s.WeeklyUsageUSD,
MonthlyUsageUSD: s.MonthlyUsageUSD,
AssignedBy: s.AssignedBy,
AssignedAt: s.AssignedAt,
Notes: s.Notes,
CreatedAt: s.CreatedAt,
UpdatedAt: s.UpdatedAt,
}
}
func applyUserSubscriptionModelToService(sub *service.UserSubscription, m *userSubscriptionModel) {
if sub == nil || m == nil {
return
}
sub.ID = m.ID
sub.CreatedAt = m.CreatedAt
sub.UpdatedAt = m.UpdatedAt
}

View File

@@ -7,8 +7,8 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
@@ -33,13 +33,13 @@ func TestUserSubscriptionRepoSuite(t *testing.T) {
// --- Create / GetByID / Update / Delete ---
func (s *UserSubscriptionRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub-create@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-create"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "sub-create@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-create"})
sub := &model.UserSubscription{
sub := &service.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
@@ -54,14 +54,14 @@ func (s *UserSubscriptionRepoSuite) TestCreate() {
}
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "preload@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
admin := mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "preload@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-preload"})
admin := mustCreateUser(s.T(), s.db, &userModel{Email: "admin@test.com", Role: service.RoleAdmin})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
AssignedBy: &admin.ID,
})
@@ -82,14 +82,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
}
func (s *UserSubscriptionRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-update"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "update@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-update"})
sub := userSubscriptionModelToService(mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
}))
sub.Notes = "updated notes"
err := s.repo.Update(s.ctx, sub)
@@ -101,12 +101,12 @@ func (s *UserSubscriptionRepoSuite) TestUpdate() {
}
func (s *UserSubscriptionRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delete"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delete@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delete"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -120,12 +120,12 @@ func (s *UserSubscriptionRepoSuite) TestDelete() {
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byuser@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-byuser"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "byuser@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-byuser"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -141,14 +141,14 @@ func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-active"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "active@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-active"})
// Create active subscription (future expiry)
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour),
})
@@ -158,14 +158,14 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "expired@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-expired"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "expired@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-expired"})
// Create expired subscription (past expiry but active status)
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour),
})
@@ -176,20 +176,20 @@ func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnor
// --- ListByUserID / ListActiveByUserID ---
func (s *UserSubscriptionRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list2"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listby@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusExpired,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
@@ -202,46 +202,46 @@ func (s *UserSubscriptionRepoSuite) TestListByUserID() {
}
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listactive@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act2"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listactive@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-act2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusExpired,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
s.Require().NoError(err, "ListActiveByUserID")
s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusActive, subs[0].Status)
s.Require().Equal(service.SubscriptionStatusActive, subs[0].Status)
}
// --- ListByGroupID ---
func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listgrp"})
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "u1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "u2@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -258,13 +258,13 @@ func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
// --- List with filters ---
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "list@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "list@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-list"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -275,20 +275,20 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-filter"})
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "filter2@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-filter"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -299,20 +299,20 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grpfilter@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f2"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "grpfilter@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f1"})
g2 := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-f2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -323,37 +323,37 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "statfilter@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-stat"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "statfilter@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-stat"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, model.SubscriptionStatusExpired)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusExpired, subs[0].Status)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
}
// --- Usage tracking ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usage@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-usage"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "usage@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-usage"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -368,12 +368,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accum@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-accum"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "accum@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-accum"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -386,12 +386,12 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
}
func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "activate@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-activate"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "activate@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-activate"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -408,12 +408,12 @@ func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
}
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetd@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetd"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetd@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetd"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
DailyUsageUSD: 10.0,
WeeklyUsageUSD: 20.0,
@@ -431,12 +431,12 @@ func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
}
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetw@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetw"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetw@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetw"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
WeeklyUsageUSD: 15.0,
MonthlyUsageUSD: 30.0,
@@ -454,12 +454,12 @@ func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
}
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetm@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetm"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "resetm@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-resetm"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
MonthlyUsageUSD: 100.0,
})
@@ -477,30 +477,30 @@ func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
// --- UpdateStatus / ExtendExpiry / UpdateNotes ---
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "status@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-status"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "status@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-status"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.UpdateStatus(s.ctx, sub.ID, model.SubscriptionStatusExpired)
err := s.repo.UpdateStatus(s.ctx, sub.ID, service.SubscriptionStatusExpired)
s.Require().NoError(err, "UpdateStatus")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(model.SubscriptionStatusExpired, got.Status)
s.Require().Equal(service.SubscriptionStatusExpired, got.Status)
}
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "extend@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-extend"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "extend@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-extend"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -514,12 +514,12 @@ func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
}
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "notes@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-notes"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
user := mustCreateUser(s.T(), s.db, &userModel{Email: "notes@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-notes"})
sub := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -534,19 +534,19 @@ func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
// --- ListExpired / BatchUpdateExpiredStatus ---
func (s *UserSubscriptionRepoSuite) TestListExpired() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listexp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listexp"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "listexp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-listexp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
@@ -556,19 +556,19 @@ func (s *UserSubscriptionRepoSuite) TestListExpired() {
}
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batch@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-batch"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "batch@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-batch"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
@@ -577,22 +577,22 @@ func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
s.Require().Equal(int64(1), affected)
gotActive, _ := s.repo.GetByID(s.ctx, active.ID)
s.Require().Equal(model.SubscriptionStatusActive, gotActive.Status)
s.Require().Equal(service.SubscriptionStatusActive, gotActive.Status)
gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().Equal(model.SubscriptionStatusExpired, gotExpired.Status)
s.Require().Equal(service.SubscriptionStatusExpired, gotExpired.Status)
}
// --- ExistsByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-exists"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "exists@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-exists"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
@@ -608,20 +608,20 @@ func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
// --- CountByGroupID / CountActiveByGroupID ---
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cnt2@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-count"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
@@ -631,20 +631,20 @@ func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
}
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-cntact"})
user1 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &userModel{Email: "cntact2@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-cntact"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time
})
@@ -656,19 +656,19 @@ func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
// --- DeleteByGroupID ---
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delgrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delgrp"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "delgrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-delgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
Status: service.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
@@ -683,19 +683,19 @@ func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
// --- Combined original test ---
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "subr@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-subr"})
user := mustCreateUser(s.T(), s.db, &userModel{Email: "subr@example.com"})
group := mustCreateGroup(s.T(), s.db, &groupModel{Name: "g-subr"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
active := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour),
})
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
expiredActive := mustCreateSubscription(s.T(), s.db, &userSubscriptionModel{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour),
})
@@ -729,5 +729,5 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().Equal(int64(1), affected, "expected 1 affected row")
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().NoError(err, "GetByID expired")
s.Require().Equal(model.SubscriptionStatusExpired, updated.Status, "expected status expired")
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
}

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,6 @@ import (
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -84,7 +83,11 @@ func validateAdminApiKey(
return false
}
c.Set(string(ContextKeyUser), admin)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: admin.ID,
Concurrency: admin.Concurrency,
})
c.Set(string(ContextKeyUserRole), admin.Role)
c.Set("auth_method", "admin_api_key")
return true
}
@@ -121,12 +124,16 @@ func validateJWTForAdmin(
}
// 检查管理员权限
if user.Role != model.RoleAdmin {
if !user.IsAdmin() {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return false
}
c.Set(string(ContextKeyUser), user)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: user.ID,
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
c.Set("auth_method", "jwt")
return true

View File

@@ -1,7 +1,7 @@
package middleware
import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -10,15 +10,14 @@ import (
// 必须在JWTAuth中间件之后使用
func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) {
// 从上下文获取用户
user, exists := GetUserFromContext(c)
if !exists {
role, ok := GetUserRoleFromContext(c)
if !ok {
AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context")
return
}
// 检查是否为管理员
if user.Role != model.RoleAdmin {
if role != service.RoleAdmin {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return
}

View File

@@ -5,11 +5,9 @@ import (
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
@@ -46,7 +44,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if errors.Is(err, service.ErrApiKeyNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return
}
@@ -121,28 +119,32 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
// 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), apiKey.User)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next()
}
}
// GetApiKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*model.ApiKey, bool) {
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey))
if !exists {
return nil, false
}
apiKey, ok := value.(*model.ApiKey)
apiKey, ok := value.(*service.ApiKey)
return apiKey, ok
}
// GetSubscriptionFromContext 从上下文中获取订阅信息
func GetSubscriptionFromContext(c *gin.Context) (*model.UserSubscription, bool) {
func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool) {
value, exists := c.Get(string(ContextKeySubscription))
if !exists {
return nil, false
}
subscription, ok := value.(*model.UserSubscription)
subscription, ok := value.(*service.UserSubscription)
return subscription, ok
}

View File

@@ -0,0 +1,28 @@
package middleware
import "github.com/gin-gonic/gin"
// AuthSubject is the minimal authenticated identity stored in gin context.
// Decision: {UserID int64, Concurrency int}
type AuthSubject struct {
UserID int64
Concurrency int
}
func GetAuthSubjectFromContext(c *gin.Context) (AuthSubject, bool) {
value, exists := c.Get(string(ContextKeyUser))
if !exists {
return AuthSubject{}, false
}
subject, ok := value.(AuthSubject)
return subject, ok
}
func GetUserRoleFromContext(c *gin.Context) (string, bool) {
value, exists := c.Get(string(ContextKeyUserRole))
if !exists {
return "", false
}
role, ok := value.(string)
return role, ok
}

View File

@@ -4,7 +4,6 @@ import (
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -62,19 +61,14 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
return
}
// 将用户信息存入上下文
c.Set(string(ContextKeyUser), user)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: user.ID,
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
c.Next()
}
}
// GetUserFromContext 从上下文中获取用户
func GetUserFromContext(c *gin.Context) (*model.User, bool) {
value, exists := c.Get(string(ContextKeyUser))
if !exists {
return nil, false
}
user, ok := value.(*model.User)
return user, ok
}
// Deprecated: prefer GetAuthSubjectFromContext in auth_subject.go.

View File

@@ -8,6 +8,8 @@ type ContextKey string
const (
// ContextKeyUser 用户上下文键
ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色string
ContextKeyUserRole ContextKey = "user_role"
// ContextKeyApiKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键

View File

@@ -0,0 +1,324 @@
package service
import "time"
type Account struct {
ID int64
Name string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
Status string
ErrorMessage string
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
Schedulable bool
RateLimitedAt *time.Time
RateLimitResetAt *time.Time
OverloadUntil *time.Time
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string
Proxy *Proxy
AccountGroups []AccountGroup
GroupIDs []int64
Groups []*Group
}
func (a *Account) IsActive() bool {
return a.Status == StatusActive
}
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
}
now := time.Now()
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false
}
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
return true
}
func (a *Account) IsRateLimited() bool {
if a.RateLimitResetAt == nil {
return false
}
return time.Now().Before(*a.RateLimitResetAt)
}
func (a *Account) IsOverloaded() bool {
if a.OverloadUntil == nil {
return false
}
return time.Now().Before(*a.OverloadUntil)
}
func (a *Account) IsOAuth() bool {
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
}
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}
func (a *Account) GetCredential(key string) string {
if a.Credentials == nil {
return ""
}
if v, ok := a.Credentials[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["model_mapping"]
if !ok || raw == nil {
return nil
}
if m, ok := raw.(map[string]any); ok {
result := make(map[string]string)
for k, v := range m {
if s, ok := v.(string); ok {
result[k] = s
}
}
if len(result) > 0 {
return result
}
}
return nil
}
func (a *Account) IsModelSupported(requestedModel string) bool {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
return true
}
_, exists := mapping[requestedModel]
return exists
}
func (a *Account) GetMappedModel(requestedModel string) string {
mapping := a.GetModelMapping()
if len(mapping) == 0 {
return requestedModel
}
if mappedModel, exists := mapping[requestedModel]; exists {
return mappedModel
}
return requestedModel
}
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey {
return ""
}
baseURL := a.GetCredential("base_url")
if baseURL == "" {
return "https://api.anthropic.com"
}
return baseURL
}
func (a *Account) GetExtraString(key string) string {
if a.Extra == nil {
return ""
}
if v, ok := a.Extra[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
func (a *Account) GetCustomErrorCodes() []int {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["custom_error_codes"]
if !ok || raw == nil {
return nil
}
if arr, ok := raw.([]any); ok {
result := make([]int, 0, len(arr))
for _, v := range arr {
if f, ok := v.(float64); ok {
result = append(result, int(f))
}
}
return result
}
return nil
}
func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
if !a.IsCustomErrorCodesEnabled() {
return true
}
codes := a.GetCustomErrorCodes()
if len(codes) == 0 {
return true
}
for _, code := range codes {
if code == statusCode {
return true
}
}
return false
}
func (a *Account) IsInterceptWarmupEnabled() bool {
if a.Credentials == nil {
return false
}
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
if enabled, ok := v.(bool); ok {
return enabled
}
}
return false
}
func (a *Account) IsOpenAI() bool {
return a.Platform == PlatformOpenAI
}
func (a *Account) IsAnthropic() bool {
return a.Platform == PlatformAnthropic
}
func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth
}
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
}
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeApiKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL
}
}
return "https://api.openai.com"
}
func (a *Account) GetOpenAIAccessToken() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("access_token")
}
func (a *Account) GetOpenAIRefreshToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("refresh_token")
}
func (a *Account) GetOpenAIIDToken() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("id_token")
}
func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIApiKey() {
return ""
}
return a.GetCredential("api_key")
}
func (a *Account) GetOpenAIUserAgent() string {
if !a.IsOpenAI() {
return ""
}
return a.GetCredential("user_agent")
}
func (a *Account) GetChatGPTAccountID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_account_id")
}
func (a *Account) GetChatGPTUserID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("chatgpt_user_id")
}
func (a *Account) GetOpenAIOrganizationID() string {
if !a.IsOpenAIOAuth() {
return ""
}
return a.GetCredential("organization_id")
}
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
if !a.IsOpenAIOAuth() {
return nil
}
expiresAtStr := a.GetCredential("expires_at")
if expiresAtStr == "" {
return nil
}
t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil {
if v, ok := a.Credentials["expires_at"].(float64); ok {
tt := time.Unix(int64(v), 0)
return &tt
}
return nil
}
return &t
}
func (a *Account) IsOpenAITokenExpired() bool {
expiresAt := a.GetOpenAITokenExpiresAt()
if expiresAt == nil {
return false
}
return time.Now().Add(60 * time.Second).After(*expiresAt)
}

View File

@@ -0,0 +1,13 @@
package service
import "time"
type AccountGroup struct {
AccountID int64
GroupID int64
Priority int
CreatedAt time.Time
Account *Account
Group *Group
}

View File

@@ -6,7 +6,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -15,29 +14,29 @@ var (
)
type AccountRepository interface {
Create(ctx context.Context, account *model.Account) error
GetByID(ctx context.Context, id int64) (*model.Account, error)
Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, id int64) (*Account, error)
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error)
Update(ctx context.Context, account *model.Account) error
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
Update(ctx context.Context, account *Account) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error)
ListActive(ctx context.Context) ([]model.Account, error)
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error)
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
UpdateLastUsed(ctx context.Context, id int64) error
SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]model.Account, error)
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error)
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error)
ListSchedulable(ctx context.Context) ([]Account, error)
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
@@ -99,7 +98,7 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
}
// Create 创建账号
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*model.Account, error) {
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
// 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 {
for _, groupID := range req.GroupIDs {
@@ -111,7 +110,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
}
// 创建账号
account := &model.Account{
account := &Account{
Name: req.Name,
Platform: req.Platform,
Type: req.Type,
@@ -120,7 +119,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
@@ -138,7 +137,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
}
// GetByID 根据ID获取账号
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) {
func (s *AccountService) GetByID(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get account: %w", err)
@@ -147,7 +146,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
}
// List 获取账号列表
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
accounts, pagination, err := s.accountRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list accounts: %w", err)
@@ -156,7 +155,7 @@ func (s *AccountService) List(ctx context.Context, params pagination.PaginationP
}
// ListByPlatform 根据平台获取账号列表
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
accounts, err := s.accountRepo.ListByPlatform(ctx, platform)
if err != nil {
return nil, fmt.Errorf("list accounts by platform: %w", err)
@@ -165,7 +164,7 @@ func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([
}
// ListByGroup 根据分组获取账号列表
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("list accounts by group: %w", err)
@@ -174,7 +173,7 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode
}
// Update 更新账号
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) {
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get account: %w", err)
@@ -290,13 +289,13 @@ func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
// 根据平台执行不同的测试逻辑
switch account.Platform {
case model.PlatformAnthropic:
case PlatformAnthropic:
// TODO: 测试Anthropic API凭证
return nil
case model.PlatformOpenAI:
case PlatformOpenAI:
// TODO: 测试OpenAI API凭证
return nil
case model.PlatformGemini:
case PlatformGemini:
// TODO: 测试Gemini API凭证
return nil
default:

View File

@@ -14,7 +14,6 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/gin-gonic/gin"
@@ -127,7 +126,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
}
// testClaudeAccountConnection tests an Anthropic Claude account's connection
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context()
// Determine the model to use
@@ -254,7 +253,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
}
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
ctx := c.Request.Context()
// Default to openai.DefaultTestModel for OpenAI testing

View File

@@ -7,24 +7,23 @@ import (
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
type UsageLogRepository interface {
Create(ctx context.Context, log *model.UsageLog) error
GetByID(ctx context.Context, id int64) (*model.UsageLog, error)
Create(ctx context.Context, log *UsageLog) error
GetByID(ctx context.Context, id int64) (*UsageLog, error)
Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
@@ -44,7 +43,7 @@ type UsageLogRepository interface {
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)
// Admin usage listing/stats
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error)
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
// Account stats
@@ -163,7 +162,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
}
// Setup Token账号根据session_window推算没有profile scope无法调用usage API
if account.Type == model.AccountTypeSetupToken {
if account.Type == AccountTypeSetupToken {
usage := s.estimateSetupTokenUsage(account)
// 添加窗口统计
s.addWindowStats(ctx, account, usage)
@@ -175,7 +174,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
}
// addWindowStats 为usage数据添加窗口期统计
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model.Account, usage *UsageInfo) {
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
if usage.FiveHour == nil {
return
}
@@ -225,7 +224,7 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
}
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return nil, fmt.Errorf("no access token available")
@@ -320,7 +319,7 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
}
// estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量
func (s *AccountUsageService) estimateSetupTokenUsage(account *model.Account) *UsageInfo {
func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageInfo {
info := &UsageInfo{}
// 如果有session_window信息

View File

@@ -7,62 +7,61 @@ import (
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// AdminService interface defines admin management operations
type AdminService interface {
// User management
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error)
GetUser(ctx context.Context, id int64) (*model.User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error)
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
GetAllGroups(ctx context.Context) ([]model.Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error)
GetGroup(ctx context.Context, id int64) (*model.Group, error)
CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error)
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error)
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error)
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error)
GetAccount(ctx context.Context, id int64) (*model.Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error)
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
DeleteAccount(ctx context.Context, id int64) error
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
ClearAccountError(ctx context.Context, id int64) (*Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
// Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
GetAllProxies(ctx context.Context) ([]model.Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*model.Proxy, error)
CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error)
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error)
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
DeleteProxy(ctx context.Context, id int64) error
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error)
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
// Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error)
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
DeleteRedeemCode(ctx context.Context, id int64) error
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error)
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
}
// Input types for admin operations
@@ -252,7 +251,7 @@ func NewAdminService(
}
// User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
if err != nil {
@@ -261,20 +260,21 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, st
return users, result.Total, nil
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User, error) {
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
return s.userRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) {
user := &model.User{
Email: input.Email,
Username: input.Username,
Wechat: input.Wechat,
Notes: input.Notes,
Role: "user", // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
Status: model.StatusActive,
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
user := &User{
Email: input.Email,
Username: input.Username,
Wechat: input.Wechat,
Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
Status: StatusActive,
AllowedGroups: input.AllowedGroups,
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
@@ -285,7 +285,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
return user, nil
}
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) {
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
@@ -335,16 +335,16 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 {
code, err := model.GenerateRedeemCode()
code, err := GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &model.RedeemCode{
adjustmentRecord := &RedeemCode{
Code: code,
Type: model.AdjustmentTypeAdminConcurrency,
Type: AdjustmentTypeAdminConcurrency,
Value: float64(concurrencyDiff),
Status: model.StatusUsed,
Status: StatusUsed,
UsedBy: &user.ID,
}
now := time.Now()
@@ -369,7 +369,7 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
return s.userRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) {
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
@@ -406,17 +406,17 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
code, err := model.GenerateRedeemCode()
code, err := GenerateRedeemCode()
if err != nil {
log.Printf("failed to generate adjustment redeem code: %v", err)
return user, nil
}
adjustmentRecord := &model.RedeemCode{
adjustmentRecord := &RedeemCode{
Code: code,
Type: model.AdjustmentTypeAdminBalance,
Type: AdjustmentTypeAdminBalance,
Value: balanceDiff,
Status: model.StatusUsed,
Status: StatusUsed,
UsedBy: &user.ID,
Notes: notes,
}
@@ -431,7 +431,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
@@ -452,7 +452,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
}
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
if err != nil {
@@ -461,36 +461,36 @@ func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, p
return groups, result.Total, nil
}
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]model.Group, error) {
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) {
return s.groupRepo.ListActive(ctx)
}
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) {
return s.groupRepo.ListActiveByPlatform(ctx, platform)
}
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*model.Group, error) {
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) {
return s.groupRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error) {
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
platform := input.Platform
if platform == "" {
platform = model.PlatformAnthropic
platform = PlatformAnthropic
}
subscriptionType := input.SubscriptionType
if subscriptionType == "" {
subscriptionType = model.SubscriptionTypeStandard
subscriptionType = SubscriptionTypeStandard
}
group := &model.Group{
group := &Group{
Name: input.Name,
Description: input.Description,
Platform: platform,
RateMultiplier: input.RateMultiplier,
IsExclusive: input.IsExclusive,
Status: model.StatusActive,
Status: StatusActive,
SubscriptionType: subscriptionType,
DailyLimitUSD: input.DailyLimitUSD,
WeeklyLimitUSD: input.WeeklyLimitUSD,
@@ -502,7 +502,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return group, nil
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error) {
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return nil, err
@@ -571,7 +571,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
return nil
}
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil {
@@ -581,7 +581,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
}
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
if err != nil {
@@ -590,21 +590,21 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
return accounts, result.Total, nil
}
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*model.Account, error) {
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) {
return s.accountRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error) {
account := &model.Account{
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
account := &Account{
Name: input.Name,
Platform: input.Platform,
Type: input.Type,
Credentials: model.JSONB(input.Credentials),
Extra: model.JSONB(input.Extra),
Credentials: input.Credentials,
Extra: input.Extra,
ProxyID: input.ProxyID,
Concurrency: input.Concurrency,
Priority: input.Priority,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
@@ -618,7 +618,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
return account, nil
}
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error) {
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
@@ -631,10 +631,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Type = input.Type
}
if len(input.Credentials) > 0 {
account.Credentials = model.JSONB(input.Credentials)
account.Credentials = input.Credentials
}
if len(input.Extra) > 0 {
account.Extra = model.JSONB(input.Extra)
account.Extra = input.Extra
}
if input.ProxyID != nil {
account.ProxyID = input.ProxyID
@@ -730,7 +730,7 @@ func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) {
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
@@ -739,12 +739,12 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int
return account, nil
}
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*model.Account, error) {
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
account.Status = model.StatusActive
account.Status = StatusActive
account.ErrorMessage = ""
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, err
@@ -752,7 +752,7 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*mo
return account, nil
}
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) {
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err
}
@@ -760,7 +760,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
}
// Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil {
@@ -769,27 +769,27 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return proxies, result.Total, nil
}
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]model.Proxy, error) {
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
return s.proxyRepo.ListActive(ctx)
}
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
return s.proxyRepo.ListActiveWithAccountCount(ctx)
}
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*model.Proxy, error) {
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
return s.proxyRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error) {
proxy := &model.Proxy{
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) {
proxy := &Proxy{
Name: input.Name,
Protocol: input.Protocol,
Host: input.Host,
Port: input.Port,
Username: input.Username,
Password: input.Password,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err
@@ -797,7 +797,7 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
return proxy, nil
}
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error) {
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
return nil, err
@@ -835,9 +835,9 @@ func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
return s.proxyRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error) {
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
// Return mock data for now - would need a dedicated repository method
return []model.Account{}, 0, nil
return []Account{}, 0, nil
}
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
@@ -845,7 +845,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
}
// Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil {
@@ -854,13 +854,13 @@ func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize i
return codes, result.Total, nil
}
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) {
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
return s.redeemCodeRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error) {
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) {
// 如果是订阅类型,验证必须有 GroupID
if input.Type == model.RedeemTypeSubscription {
if input.Type == RedeemTypeSubscription {
if input.GroupID == nil {
return nil, errors.New("group_id is required for subscription type")
}
@@ -874,20 +874,20 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
}
}
codes := make([]model.RedeemCode, 0, input.Count)
codes := make([]RedeemCode, 0, input.Count)
for i := 0; i < input.Count; i++ {
codeValue, err := model.GenerateRedeemCode()
codeValue, err := GenerateRedeemCode()
if err != nil {
return nil, err
}
code := model.RedeemCode{
code := RedeemCode{
Code: codeValue,
Type: input.Type,
Value: input.Value,
Status: model.StatusUnused,
Status: StatusUnused,
}
// 订阅类型专用字段
if input.Type == model.RedeemTypeSubscription {
if input.Type == RedeemTypeSubscription {
code.GroupID = input.GroupID
code.ValidityDays = input.ValidityDays
if code.ValidityDays <= 0 {
@@ -916,12 +916,12 @@ func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int
return deleted, nil
}
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) {
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
code, err := s.redeemCodeRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
code.Status = model.StatusExpired
code.Status = StatusExpired
if err := s.redeemCodeRepo.Update(ctx, code); err != nil {
return nil, err
}

View File

@@ -0,0 +1,20 @@
package service
import "time"
type ApiKey struct {
ID int64
UserID int64
Key string
Name string
GroupID *int64
Status string
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
}
func (k *ApiKey) IsActive() bool {
return k.Status == StatusActive
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/redis/go-redis/v9"
@@ -30,17 +29,17 @@ const (
)
type ApiKeyRepository interface {
Create(ctx context.Context, key *model.ApiKey) error
GetByID(ctx context.Context, id int64) (*model.ApiKey, error)
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
Update(ctx context.Context, key *model.ApiKey) error
Create(ctx context.Context, key *ApiKey) error
GetByID(ctx context.Context, id int64) (*ApiKey, error)
GetByKey(ctx context.Context, key string) (*ApiKey, error)
Update(ctx context.Context, key *ApiKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}
@@ -168,7 +167,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User, group *model.Group) bool {
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
@@ -179,7 +178,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User,
}
// Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*model.ApiKey, error) {
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -235,12 +234,12 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// 创建API Key记录
apiKey := &model.ApiKey{
apiKey := &ApiKey{
UserID: userID,
Key: key,
Name: req.Name,
GroupID: req.GroupID,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
@@ -251,7 +250,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
@@ -260,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
}
// GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -269,7 +268,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, e
}
// GetByKey 根据Key字符串获取API Key用于认证
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
@@ -289,7 +288,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
}
// Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -364,7 +363,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
}
// ValidateKey 验证API Key是否有效用于认证中间件
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.ApiKey, *model.User, error) {
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
// 获取API Key
apiKey, err := s.GetByKey(ctx, key)
if err != nil {
@@ -408,7 +407,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]model.Group, error) {
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -434,7 +433,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
}
// 过滤出用户有权限的分组
availableGroups := make([]model.Group, 0)
availableGroups := make([]Group, 0)
for _, group := range allGroups {
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
availableGroups = append(availableGroups, group)
@@ -445,7 +444,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.Group, subscribedGroupIDs map[int64]bool) bool {
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID]
@@ -454,7 +453,7 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.
return user.CanBindGroup(group.ID, group.IsExclusive)
}
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)

View File

@@ -9,7 +9,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
@@ -64,12 +63,12 @@ func NewAuthService(
}
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) {
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "")
}
// RegisterWithVerification 用户注册支持邮件验证返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *model.User, error) {
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
@@ -113,13 +112,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
}
// 创建用户
user := &model.User{
user := &User{
Email: email,
PasswordHash: hashedPassword,
Role: model.RoleUser,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.userRepo.Create(ctx, user); err != nil {
@@ -251,7 +250,7 @@ func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool {
}
// Login 用户登录返回JWT token
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *model.User, error) {
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *User, error) {
// 查找用户
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
@@ -307,7 +306,7 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
}
// GenerateToken 生成JWT token
func (s *AuthService) GenerateToken(user *model.User) (string, error) {
func (s *AuthService) GenerateToken(user *User) (string, error) {
now := time.Now()
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)

View File

@@ -7,7 +7,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// 错误定义
@@ -224,7 +223,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0
// 订阅模式检查缓存用量未超过限额Group限额从参数传入
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *model.User, apiKey *model.ApiKey, group *model.Group, subscription *model.UserSubscription) error {
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
// 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
@@ -252,7 +251,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI
}
// checkSubscriptionEligibility 检查订阅模式资格
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *model.Group, subscription *model.UserSubscription) error {
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *Group, subscription *UserSubscription) error {
// 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil {
@@ -262,7 +261,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
}
// 检查订阅状态
if subData.Status != model.SubscriptionStatusActive {
if subData.Status != SubscriptionStatusActive {
return ErrSubscriptionInvalid
}
@@ -288,7 +287,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
}
// checkSubscriptionLimitsFallback 降级检查订阅限额
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *model.UserSubscription, group *model.Group) error {
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
if subscription == nil {
return ErrSubscriptionInvalid
}

View File

@@ -11,8 +11,6 @@ import (
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
)
type CRSSyncService struct {
@@ -180,7 +178,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
),
}
var proxies []model.Proxy
var proxies []Proxy
if input.SyncProxies {
proxies, _ = s.proxyRepo.ListActive(ctx)
}
@@ -197,7 +195,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
if targetType == "" {
targetType = "oauth"
}
if targetType != model.AccountTypeOAuth && targetType != model.AccountTypeSetupToken {
if targetType != AccountTypeOAuth && targetType != AccountTypeSetupToken {
item.Action = "skipped"
item.Error = "unsupported authType: " + targetType
result.Skipped++
@@ -268,12 +266,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
account := &model.Account{
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic,
Platform: PlatformAnthropic,
Type: targetType,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
@@ -288,7 +286,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue
}
// 🔄 Refresh OAuth token after creation
if targetType == model.AccountTypeOAuth {
if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account)
@@ -301,11 +299,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
// Update existing
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic
existing.Platform = PlatformAnthropic
existing.Type = targetType
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
@@ -323,7 +321,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
// 🔄 Refresh OAuth token after update
if targetType == model.AccountTypeOAuth {
if targetType == AccountTypeOAuth {
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing)
@@ -385,12 +383,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
account := &model.Account{
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformAnthropic,
Type: model.AccountTypeApiKey,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
Platform: PlatformAnthropic,
Type: AccountTypeApiKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
@@ -410,11 +408,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformAnthropic
existing.Type = model.AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
existing.Platform = PlatformAnthropic
existing.Type = AccountTypeApiKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
@@ -508,12 +506,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
account := &model.Account{
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI,
Type: model.AccountTypeOAuth,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
@@ -538,11 +536,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI
existing.Type = model.AccountTypeOAuth
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
existing.Platform = PlatformOpenAI
existing.Type = AccountTypeOAuth
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
@@ -629,12 +627,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
}
if existing == nil {
account := &model.Account{
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformOpenAI,
Type: model.AccountTypeApiKey,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
Platform: PlatformOpenAI,
Type: AccountTypeApiKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
Concurrency: concurrency,
Priority: priority,
@@ -654,11 +652,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformOpenAI
existing.Type = model.AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
existing.Platform = PlatformOpenAI
existing.Type = AccountTypeApiKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
@@ -683,9 +681,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
return result, nil
}
// mergeJSONB merges two JSONB maps without removing keys that are absent in updates.
func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB {
out := make(model.JSONB)
func mergeMap(existing map[string]any, updates map[string]any) map[string]any {
out := make(map[string]any, len(existing)+len(updates))
for k, v := range existing {
out[k] = v
}
@@ -695,7 +692,7 @@ func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB {
return out
}
func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]model.Proxy, src *crsProxy, defaultName string) (*int64, error) {
func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]Proxy, src *crsProxy, defaultName string) (*int64, error) {
if !enabled || src == nil {
return nil, nil
}
@@ -731,14 +728,14 @@ func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cac
}
// Create new proxy
proxy := &model.Proxy{
proxy := &Proxy{
Name: defaultProxyName(defaultName, protocol, host, port),
Protocol: protocol,
Host: host,
Port: port,
Username: username,
Password: password,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err
@@ -897,8 +894,8 @@ func crsExportAccounts(ctx context.Context, client *http.Client, baseURL, adminT
// refreshOAuthToken attempts to refresh OAuth token for a synced account
// Returns updated credentials or nil if refresh failed/not applicable
func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.Account) model.JSONB {
if account.Type != model.AccountTypeOAuth {
func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account) map[string]any {
if account.Type != AccountTypeOAuth {
return nil
}
@@ -906,7 +903,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
var err error
switch account.Platform {
case model.PlatformAnthropic:
case PlatformAnthropic:
if s.oauthService == nil {
return nil
}
@@ -931,7 +928,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
newCredentials["scope"] = tokenInfo.Scope
}
}
case model.PlatformOpenAI:
case PlatformOpenAI:
if s.openaiOAuthService == nil {
return nil
}
@@ -956,5 +953,5 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
return nil
}
return model.JSONB(newCredentials)
return newCredentials
}

View File

@@ -0,0 +1,96 @@
package service
// Status constants
const (
StatusActive = "active"
StatusDisabled = "disabled"
StatusError = "error"
StatusUnused = "unused"
StatusUsed = "used"
StatusExpired = "expired"
)
// Role constants
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// Platform constants
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
)
// Account type constants
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeApiKey = "apikey" // API Key类型账号
)
// Redeem type constants
const (
RedeemTypeBalance = "balance"
RedeemTypeConcurrency = "concurrency"
RedeemTypeSubscription = "subscription"
)
// Admin adjustment type constants
const (
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
)
// Group subscription type constants
const (
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
)
// Subscription status constants
const (
SubscriptionStatusActive = "active"
SubscriptionStatusExpired = "expired"
SubscriptionStatusSuspended = "suspended"
)
// Setting keys
const (
// 注册设置
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// OEM设置
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
)
// Admin API Key prefix (distinct from user "sk-" keys)
const AdminApiKeyPrefix = "admin-"

View File

@@ -11,7 +11,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
)
var (
@@ -69,13 +68,13 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
// GetSmtpConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
keys := []string{
model.SettingKeySmtpHost,
model.SettingKeySmtpPort,
model.SettingKeySmtpUsername,
model.SettingKeySmtpPassword,
model.SettingKeySmtpFrom,
model.SettingKeySmtpFromName,
model.SettingKeySmtpUseTLS,
SettingKeySmtpHost,
SettingKeySmtpPort,
SettingKeySmtpUsername,
SettingKeySmtpPassword,
SettingKeySmtpFrom,
SettingKeySmtpFromName,
SettingKeySmtpUseTLS,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -83,27 +82,27 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err)
}
host := settings[model.SettingKeySmtpHost]
host := settings[SettingKeySmtpHost]
if host == "" {
return nil, ErrEmailNotConfigured
}
port := 587 // 默认端口
if portStr := settings[model.SettingKeySmtpPort]; portStr != "" {
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}
useTLS := settings[model.SettingKeySmtpUseTLS] == "true"
useTLS := settings[SettingKeySmtpUseTLS] == "true"
return &SmtpConfig{
Host: host,
Port: port,
Username: settings[model.SettingKeySmtpUsername],
Password: settings[model.SettingKeySmtpPassword],
From: settings[model.SettingKeySmtpFrom],
FromName: settings[model.SettingKeySmtpFromName],
Username: settings[SettingKeySmtpUsername],
Password: settings[SettingKeySmtpPassword],
From: settings[SettingKeySmtpFrom],
FromName: settings[SettingKeySmtpFromName],
UseTLS: useTLS,
}, nil
}

View File

@@ -17,7 +17,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -265,12 +264,12 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
}
// SelectAccount 选择账号(粘性会话+优先级)
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) {
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
}
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
@@ -289,19 +288,19 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
}
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
var accounts []model.Account
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformAnthropic)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformAnthropic)
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 3. 按优先级+最久未用选择(考虑模型支持)
var selected *model.Account
var selected *Account
for i := range accounts {
acc := &accounts[i]
// 检查模型支持
@@ -341,12 +340,12 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
}
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) {
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case model.AccountTypeOAuth, model.AccountTypeSetupToken:
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case model.AccountTypeApiKey:
case AccountTypeApiKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -357,7 +356,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco
}
}
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) {
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
@@ -372,10 +371,7 @@ const (
retryDelay = 3 * time.Second // 重试等待时间
)
// shouldRetryUpstreamError 判断是否应该重试上游错误
// OAuth/Setup Token 账号:仅 403 重试
// API Key 账号:未配置的错误码重试
func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, statusCode int) bool {
func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
// OAuth/Setup Token 账号:仅 403 重试
if account.IsOAuth() {
return statusCode == 403
@@ -386,7 +382,7 @@ func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, status
}
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) {
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
// 解析请求获取model和stream
@@ -412,7 +408,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
// 应用模型映射仅对apikey类型账号
originalModel := req.Model
if account.Type == model.AccountTypeApiKey {
if account.Type == AccountTypeApiKey {
mappedModel := account.GetMappedModel(req.Model)
if mappedModel != req.Model {
// 替换请求体中的模型名
@@ -504,10 +500,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
}, nil
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == model.AccountTypeApiKey {
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
}
@@ -631,7 +627,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
return claude.DefaultBetaHeader
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
// 处理上游错误,标记账号状态
@@ -686,7 +682,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
// handleRetryExhaustedError 处理重试耗尽后的错误
// OAuth 403标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
statusCode := resp.StatusCode
@@ -717,7 +713,7 @@ type streamingResult struct {
firstTokenMs *int
}
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -856,7 +852,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
}
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -915,10 +911,10 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
ApiKey *model.ApiKey
User *model.User
Account *model.Account
Subscription *model.UserSubscription // 可选:订阅信息
ApiKey *ApiKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
@@ -952,14 +948,14 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := model.BillingTypeBalance
billingType := BillingTypeBalance
if isSubscriptionBilling {
billingType = model.BillingTypeSubscription
billingType = BillingTypeSubscription
}
// 创建使用日志
durationMs := int(result.Duration.Milliseconds())
usageLog := &model.UsageLog{
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
@@ -1038,9 +1034,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *model.Account, body []byte) error {
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == model.AccountTypeApiKey {
if account.Type == AccountTypeApiKey {
var req struct {
Model string `json:"model"`
}
@@ -1113,10 +1109,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == model.AccountTypeApiKey {
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}

View File

@@ -0,0 +1,48 @@
package service
import "time"
type Group struct {
ID int64
Name string
Description string
Platform string
RateMultiplier float64
IsExclusive bool
Status string
SubscriptionType string
DailyLimitUSD *float64
WeeklyLimitUSD *float64
MonthlyLimitUSD *float64
CreatedAt time.Time
UpdatedAt time.Time
AccountGroups []AccountGroup
AccountCount int64
}
func (g *Group) IsActive() bool {
return g.Status == StatusActive
}
func (g *Group) IsSubscriptionType() bool {
return g.SubscriptionType == SubscriptionTypeSubscription
}
func (g *Group) IsFreeSubscription() bool {
return g.IsSubscriptionType() && g.RateMultiplier == 0
}
func (g *Group) HasDailyLimit() bool {
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
}
func (g *Group) HasWeeklyLimit() bool {
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
}
func (g *Group) HasMonthlyLimit() bool {
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -15,16 +14,16 @@ var (
)
type GroupRepository interface {
Create(ctx context.Context, group *model.Group) error
GetByID(ctx context.Context, id int64) (*model.Group, error)
Update(ctx context.Context, group *model.Group) error
Create(ctx context.Context, group *Group) error
GetByID(ctx context.Context, id int64) (*Group, error)
Update(ctx context.Context, group *Group) error
Delete(ctx context.Context, id int64) error
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Group, error)
ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error)
List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]Group, error)
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
@@ -61,7 +60,7 @@ func NewGroupService(groupRepo GroupRepository) *GroupService {
}
// Create 创建分组
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*model.Group, error) {
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) {
// 检查名称是否已存在
exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
if err != nil {
@@ -72,12 +71,14 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
}
// 创建分组
group := &model.Group{
Name: req.Name,
Description: req.Description,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: model.StatusActive,
group := &Group{
Name: req.Name,
Description: req.Description,
Platform: PlatformAnthropic,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
@@ -88,7 +89,7 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
}
// GetByID 根据ID获取分组
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) {
func (s *GroupService) GetByID(ctx context.Context, id int64) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
@@ -97,7 +98,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err
}
// List 获取分组列表
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
groups, pagination, err := s.groupRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list groups: %w", err)
@@ -106,7 +107,7 @@ func (s *GroupService) List(ctx context.Context, params pagination.PaginationPar
}
// ListActive 获取活跃分组列表
func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
func (s *GroupService) ListActive(ctx context.Context) ([]Group, error) {
groups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active groups: %w", err)
@@ -115,7 +116,7 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
}
// Update 更新分组
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) {
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)

View File

@@ -6,7 +6,6 @@ import (
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
@@ -274,7 +273,7 @@ func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, pr
}
// RefreshAccountToken refreshes token for an account
func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*TokenInfo, error) {
func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" {
return nil, fmt.Errorf("no refresh token available")

View File

@@ -16,7 +16,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/gin-gonic/gin"
)
@@ -119,12 +118,12 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
}
// SelectAccount selects an OpenAI account with sticky session support
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) {
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
}
// SelectAccountForModel selects an account supporting the requested model
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
// 1. Check sticky session
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
@@ -139,19 +138,19 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
}
// 2. Get schedulable OpenAI accounts
var accounts []model.Account
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformOpenAI)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformOpenAI)
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 3. Select by priority + LRU
var selected *model.Account
var selected *Account
for i := range accounts {
acc := &accounts[i]
// Check model support
@@ -189,15 +188,15 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
}
// GetAccessToken gets the access token for an OpenAI account
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) {
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case model.AccountTypeOAuth:
case AccountTypeOAuth:
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
}
return accessToken, "oauth", nil
case model.AccountTypeApiKey:
case AccountTypeApiKey:
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -209,7 +208,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *mode
}
// Forward forwards request to OpenAI API
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*OpenAIForwardResult, error) {
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
startTime := time.Now()
// Parse request body once (avoid multiple parse/serialize cycles)
@@ -234,7 +233,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// For OAuth accounts using ChatGPT internal API, add store: false
if account.Type == model.AccountTypeOAuth {
if account.Type == AccountTypeOAuth {
reqBody["store"] = false
bodyModified = true
}
@@ -296,7 +295,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == model.AccountTypeOAuth {
if account.Type == AccountTypeOAuth {
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
@@ -312,14 +311,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}, nil
}
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token string, isStream bool) (*http.Request, error) {
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
// Determine target URL based on account type
var targetURL string
switch account.Type {
case model.AccountTypeOAuth:
case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL
case model.AccountTypeApiKey:
case AccountTypeApiKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
@@ -340,7 +339,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
req.Header.Set("authorization", "Bearer "+token)
// Set headers specific to OAuth accounts (ChatGPT internal API)
if account.Type == model.AccountTypeOAuth {
if account.Type == AccountTypeOAuth {
// Required: set Host for ChatGPT API (must use req.Host, not Header.Set)
req.Host = "chatgpt.com"
// Required: set chatgpt-account-id header
@@ -380,7 +379,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
return req, nil
}
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*OpenAIForwardResult, error) {
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
// Check custom error codes
@@ -436,7 +435,7 @@ type openaiStreamingResult struct {
firstTokenMs *int
}
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
// Set SSE response headers
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
@@ -552,7 +551,7 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
}
}
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
@@ -618,10 +617,10 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
ApiKey *model.ApiKey
User *model.User
Account *model.Account
Subscription *model.UserSubscription
ApiKey *ApiKey
User *User
Account *Account
Subscription *UserSubscription
}
// RecordUsage records usage and deducts balance
@@ -660,14 +659,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Determine billing type
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := model.BillingTypeBalance
billingType := BillingTypeBalance
if isSubscriptionBilling {
billingType = model.BillingTypeSubscription
billingType = BillingTypeSubscription
}
// Create usage log
durationMs := int(result.Duration.Milliseconds())
usageLog := &model.UsageLog{
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
@@ -200,7 +199,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
}
// RefreshAccountToken refreshes token for an OpenAI account
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*OpenAITokenInfo, error) {
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
if !account.IsOpenAI() {
return nil, fmt.Errorf("account is not an OpenAI account")
}

View File

@@ -0,0 +1,35 @@
package service
import (
"fmt"
"time"
)
type Proxy struct {
ID int64
Name string
Protocol string
Host string
Port int
Username string
Password string
Status string
CreatedAt time.Time
UpdatedAt time.Time
}
func (p *Proxy) IsActive() bool {
return p.Status == StatusActive
}
func (p *Proxy) URL() string {
if p.Username != "" && p.Password != "" {
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
}
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
}
type ProxyWithAccountCount struct {
Proxy
AccountCount int64
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -14,15 +13,15 @@ var (
)
type ProxyRepository interface {
Create(ctx context.Context, proxy *model.Proxy) error
GetByID(ctx context.Context, id int64) (*model.Proxy, error)
Update(ctx context.Context, proxy *model.Proxy) error
Create(ctx context.Context, proxy *Proxy) error
GetByID(ctx context.Context, id int64) (*Proxy, error)
Update(ctx context.Context, proxy *Proxy) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Proxy, error)
ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]Proxy, error)
ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
@@ -62,16 +61,16 @@ func NewProxyService(proxyRepo ProxyRepository) *ProxyService {
}
// Create 创建代理
func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*model.Proxy, error) {
func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*Proxy, error) {
// 创建代理
proxy := &model.Proxy{
proxy := &Proxy{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Port: req.Port,
Username: req.Username,
Password: req.Password,
Status: model.StatusActive,
Status: StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
@@ -82,7 +81,7 @@ func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*mod
}
// GetByID 根据ID获取代理
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get proxy: %w", err)
@@ -91,7 +90,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err
}
// List 获取代理列表
func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) {
proxies, pagination, err := s.proxyRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list proxies: %w", err)
@@ -100,7 +99,7 @@ func (s *ProxyService) List(ctx context.Context, params pagination.PaginationPar
}
// ListActive 获取活跃代理列表
func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
func (s *ProxyService) ListActive(ctx context.Context) ([]Proxy, error) {
proxies, err := s.proxyRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active proxies: %w", err)
@@ -109,7 +108,7 @@ func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
}
// Update 更新代理
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) {
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get proxy: %w", err)

View File

@@ -8,7 +8,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// RateLimitService 处理限流和过载状态管理
@@ -27,7 +26,7 @@ func NewRateLimitService(accountRepo AccountRepository, cfg *config.Config) *Rat
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
// apikey 类型账号:检查自定义错误码配置
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
if !account.ShouldHandleErrorCode(statusCode) {
@@ -60,7 +59,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod
}
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) {
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err)
return
@@ -70,7 +69,7 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.A
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
func (s *RateLimitService) handle429(ctx context.Context, account *model.Account, headers http.Header) {
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header) {
// 解析重置时间戳
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
if resetTimestamp == "" {
@@ -113,7 +112,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
// handle529 处理529过载错误
// 根据配置设置过载冷却时间
func (s *RateLimitService) handle529(ctx context.Context, account *model.Account) {
func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes
if cooldownMinutes <= 0 {
cooldownMinutes = 10 // 默认10分钟
@@ -129,7 +128,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account
}
// UpdateSessionWindow 从成功响应更新5h窗口状态
func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *model.Account, headers http.Header) {
func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Account, headers http.Header) {
status := headers.Get("anthropic-ratelimit-unified-5h-status")
if status == "" {
return

View File

@@ -0,0 +1,41 @@
package service
import (
"crypto/rand"
"encoding/hex"
"time"
)
type RedeemCode struct {
ID int64
Code string
Type string
Value float64
Status string
UsedBy *int64
UsedAt *time.Time
Notes string
CreatedAt time.Time
GroupID *int64
ValidityDays int
User *User
Group *Group
}
func (r *RedeemCode) IsUsed() bool {
return r.Status == StatusUsed
}
func (r *RedeemCode) CanUse() bool {
return r.Status == StatusUnused
}
func GenerateRedeemCode() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}

View File

@@ -10,7 +10,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9"
)
@@ -39,17 +38,17 @@ type RedeemCache interface {
}
type RedeemCodeRepository interface {
Create(ctx context.Context, code *model.RedeemCode) error
CreateBatch(ctx context.Context, codes []model.RedeemCode) error
GetByID(ctx context.Context, id int64) (*model.RedeemCode, error)
GetByCode(ctx context.Context, code string) (*model.RedeemCode, error)
Update(ctx context.Context, code *model.RedeemCode) error
Create(ctx context.Context, code *RedeemCode) error
CreateBatch(ctx context.Context, codes []RedeemCode) error
GetByID(ctx context.Context, id int64) (*RedeemCode, error)
GetByCode(ctx context.Context, code string) (*RedeemCode, error)
Update(ctx context.Context, code *RedeemCode) error
Delete(ctx context.Context, id int64) error
Use(ctx context.Context, id, userID int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error)
ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error)
List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error)
ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error)
}
// GenerateCodesRequest 生成兑换码请求
@@ -116,7 +115,7 @@ func (s *RedeemService) GenerateRandomCode() (string, error) {
}
// GenerateCodes 批量生成兑换码
func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]model.RedeemCode, error) {
func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]RedeemCode, error) {
if req.Count <= 0 {
return nil, errors.New("count must be greater than 0")
}
@@ -131,21 +130,21 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
codeType := req.Type
if codeType == "" {
codeType = model.RedeemTypeBalance
codeType = RedeemTypeBalance
}
codes := make([]model.RedeemCode, 0, req.Count)
codes := make([]RedeemCode, 0, req.Count)
for i := 0; i < req.Count; i++ {
code, err := s.GenerateRandomCode()
if err != nil {
return nil, fmt.Errorf("generate code: %w", err)
}
codes = append(codes, model.RedeemCode{
codes = append(codes, RedeemCode{
Code: code,
Type: codeType,
Value: req.Value,
Status: model.StatusUnused,
Status: StatusUnused,
})
}
@@ -210,7 +209,7 @@ func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
}
// Redeem 使用兑换码
func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*model.RedeemCode, error) {
func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*RedeemCode, error) {
// 检查限流
if err := s.checkRedeemRateLimit(ctx, userID); err != nil {
return nil, err
@@ -239,7 +238,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
}
// 验证兑换码类型的前置条件
if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil {
if redeemCode.Type == RedeemTypeSubscription && redeemCode.GroupID == nil {
return nil, infraerrors.BadRequest("REDEEM_CODE_INVALID", "invalid subscription redeem code: missing group_id")
}
@@ -261,7 +260,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 执行兑换逻辑(兑换码已被锁定,此时可安全操作)
switch redeemCode.Type {
case model.RedeemTypeBalance:
case RedeemTypeBalance:
// 增加用户余额
if err := s.userRepo.UpdateBalance(ctx, userID, redeemCode.Value); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
@@ -275,13 +274,13 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
}()
}
case model.RedeemTypeConcurrency:
case RedeemTypeConcurrency:
// 增加用户并发数
if err := s.userRepo.UpdateConcurrency(ctx, userID, int(redeemCode.Value)); err != nil {
return nil, fmt.Errorf("update user concurrency: %w", err)
}
case model.RedeemTypeSubscription:
case RedeemTypeSubscription:
validityDays := redeemCode.ValidityDays
if validityDays <= 0 {
validityDays = 30
@@ -320,7 +319,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
}
// GetByID 根据ID获取兑换码
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get redeem code: %w", err)
@@ -329,7 +328,7 @@ func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCod
}
// GetByCode 根据Code获取兑换码
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*RedeemCode, error) {
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil {
return nil, fmt.Errorf("get redeem code: %w", err)
@@ -338,7 +337,7 @@ func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.Rede
}
// List 获取兑换码列表(管理员功能)
func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
codes, pagination, err := s.redeemRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list redeem codes: %w", err)
@@ -383,7 +382,7 @@ func (s *RedeemService) GetStats(ctx context.Context) (map[string]any, error) {
}
// GetUserHistory 获取用户的兑换历史
func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) {
codes, err := s.redeemRepo.ListByUser(ctx, userID, limit)
if err != nil {
return nil, fmt.Errorf("get user redeem history: %w", err)

View File

@@ -0,0 +1,10 @@
package service
import "time"
type Setting struct {
ID int64
Key string
Value string
UpdatedAt time.Time
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
)
var (
@@ -19,7 +18,7 @@ var (
)
type SettingRepository interface {
Get(ctx context.Context, key string) (*model.Setting, error)
Get(ctx context.Context, key string) (*Setting, error)
GetValue(ctx context.Context, key string) (string, error)
Set(ctx context.Context, key, value string) error
GetMultiple(ctx context.Context, keys []string) (map[string]string, error)
@@ -43,7 +42,7 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti
}
// GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*model.SystemSettings, error) {
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx)
if err != nil {
return nil, fmt.Errorf("get all settings: %w", err)
@@ -53,18 +52,18 @@ func (s *SettingService) GetAllSettings(ctx context.Context) (*model.SystemSetti
}
// GetPublicSettings 获取公开设置(无需登录)
func (s *SettingService) GetPublicSettings(ctx context.Context) (*model.PublicSettings, error) {
func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings, error) {
keys := []string{
model.SettingKeyRegistrationEnabled,
model.SettingKeyEmailVerifyEnabled,
model.SettingKeyTurnstileEnabled,
model.SettingKeyTurnstileSiteKey,
model.SettingKeySiteName,
model.SettingKeySiteLogo,
model.SettingKeySiteSubtitle,
model.SettingKeyApiBaseUrl,
model.SettingKeyContactInfo,
model.SettingKeyDocUrl,
SettingKeyRegistrationEnabled,
SettingKeyEmailVerifyEnabled,
SettingKeyTurnstileEnabled,
SettingKeyTurnstileSiteKey,
SettingKeySiteName,
SettingKeySiteLogo,
SettingKeySiteSubtitle,
SettingKeyApiBaseUrl,
SettingKeyContactInfo,
SettingKeyDocUrl,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -72,64 +71,64 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*model.PublicSe
return nil, fmt.Errorf("get public settings: %w", err)
}
return &model.PublicSettings{
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
SiteLogo: settings[model.SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
ContactInfo: settings[model.SettingKeyContactInfo],
DocUrl: settings[model.SettingKeyDocUrl],
return &PublicSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl],
}, nil
}
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *model.SystemSettings) error {
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
updates := make(map[string]string)
// 注册设置
updates[model.SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
updates[model.SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[model.SettingKeySmtpHost] = settings.SmtpHost
updates[model.SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[model.SettingKeySmtpUsername] = settings.SmtpUsername
updates[SettingKeySmtpHost] = settings.SmtpHost
updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[SettingKeySmtpUsername] = settings.SmtpUsername
if settings.SmtpPassword != "" {
updates[model.SettingKeySmtpPassword] = settings.SmtpPassword
updates[SettingKeySmtpPassword] = settings.SmtpPassword
}
updates[model.SettingKeySmtpFrom] = settings.SmtpFrom
updates[model.SettingKeySmtpFromName] = settings.SmtpFromName
updates[model.SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
updates[SettingKeySmtpFrom] = settings.SmtpFrom
updates[SettingKeySmtpFromName] = settings.SmtpFromName
updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates[model.SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
updates[model.SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
updates[SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey
if settings.TurnstileSecretKey != "" {
updates[model.SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
}
// OEM设置
updates[model.SettingKeySiteName] = settings.SiteName
updates[model.SettingKeySiteLogo] = settings.SiteLogo
updates[model.SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[model.SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[model.SettingKeyContactInfo] = settings.ContactInfo
updates[model.SettingKeyDocUrl] = settings.DocUrl
updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[SettingKeyContactInfo] = settings.ContactInfo
updates[SettingKeyDocUrl] = settings.DocUrl
// 默认配置
updates[model.SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[model.SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
return s.settingRepo.SetMultiple(ctx, updates)
}
// IsRegistrationEnabled 检查是否开放注册
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
if err != nil {
// 默认开放注册
return true
@@ -139,7 +138,7 @@ func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
// IsEmailVerifyEnabled 检查是否开启邮件验证
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyEmailVerifyEnabled)
value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled)
if err != nil {
return false
}
@@ -148,7 +147,7 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
// GetSiteName 获取网站名称
func (s *SettingService) GetSiteName(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeySiteName)
value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
if err != nil || value == "" {
return "Sub2API"
}
@@ -157,7 +156,7 @@ func (s *SettingService) GetSiteName(ctx context.Context) string {
// GetDefaultConcurrency 获取默认并发量
func (s *SettingService) GetDefaultConcurrency(ctx context.Context) int {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultConcurrency)
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultConcurrency)
if err != nil {
return s.cfg.Default.UserConcurrency
}
@@ -169,7 +168,7 @@ func (s *SettingService) GetDefaultConcurrency(ctx context.Context) int {
// GetDefaultBalance 获取默认余额
func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultBalance)
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultBalance)
if err != nil {
return s.cfg.Default.UserBalance
}
@@ -182,7 +181,7 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
// InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置
_, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
_, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
if err == nil {
// 已有设置,不需要初始化
return nil
@@ -193,62 +192,62 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 初始化默认设置
defaults := map[string]string{
model.SettingKeyRegistrationEnabled: "true",
model.SettingKeyEmailVerifyEnabled: "false",
model.SettingKeySiteName: "Sub2API",
model.SettingKeySiteLogo: "",
model.SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
model.SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
model.SettingKeySmtpPort: "587",
model.SettingKeySmtpUseTLS: "false",
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false",
SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySmtpPort: "587",
SettingKeySmtpUseTLS: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
}
// parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *model.SystemSettings {
result := &model.SystemSettings{
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
SmtpHost: settings[model.SettingKeySmtpHost],
SmtpUsername: settings[model.SettingKeySmtpUsername],
SmtpFrom: settings[model.SettingKeySmtpFrom],
SmtpFromName: settings[model.SettingKeySmtpFromName],
SmtpUseTLS: settings[model.SettingKeySmtpUseTLS] == "true",
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
SiteLogo: settings[model.SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
ContactInfo: settings[model.SettingKeyContactInfo],
DocUrl: settings[model.SettingKeyDocUrl],
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SmtpHost: settings[SettingKeySmtpHost],
SmtpUsername: settings[SettingKeySmtpUsername],
SmtpFrom: settings[SettingKeySmtpFrom],
SmtpFromName: settings[SettingKeySmtpFromName],
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl],
}
// 解析整数类型
if port, err := strconv.Atoi(settings[model.SettingKeySmtpPort]); err == nil {
if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
result.SmtpPort = port
} else {
result.SmtpPort = 587
}
if concurrency, err := strconv.Atoi(settings[model.SettingKeyDefaultConcurrency]); err == nil {
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
result.DefaultConcurrency = concurrency
} else {
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
}
// 解析浮点数类型
if balance, err := strconv.ParseFloat(settings[model.SettingKeyDefaultBalance], 64); err == nil {
if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil {
result.DefaultBalance = balance
} else {
result.DefaultBalance = s.cfg.Default.UserBalance
}
// 敏感信息直接返回,方便测试连接时使用
result.SmtpPassword = settings[model.SettingKeySmtpPassword]
result.TurnstileSecretKey = settings[model.SettingKeyTurnstileSecretKey]
result.SmtpPassword = settings[SettingKeySmtpPassword]
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
return result
}
@@ -263,7 +262,7 @@ func (s *SettingService) getStringOrDefault(settings map[string]string, key, def
// IsTurnstileEnabled 检查是否启用 Turnstile 验证
func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileEnabled)
value, err := s.settingRepo.GetValue(ctx, SettingKeyTurnstileEnabled)
if err != nil {
return false
}
@@ -272,7 +271,7 @@ func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool {
// GetTurnstileSecretKey 获取 Turnstile Secret Key
func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileSecretKey)
value, err := s.settingRepo.GetValue(ctx, SettingKeyTurnstileSecretKey)
if err != nil {
return ""
}
@@ -287,10 +286,10 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error
return "", fmt.Errorf("generate random bytes: %w", err)
}
key := model.AdminApiKeyPrefix + hex.EncodeToString(bytes)
key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
// 存储到 settings 表
if err := s.settingRepo.Set(ctx, model.SettingKeyAdminApiKey, key); err != nil {
if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
return "", fmt.Errorf("save admin api key: %w", err)
}
@@ -300,7 +299,7 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error
// GetAdminApiKeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", false, nil
@@ -324,7 +323,7 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
// GetAdminApiKey 获取完整的管理员 API Key仅供内部验证使用
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串
@@ -336,5 +335,5 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
// DeleteAdminApiKey 删除管理员 API Key
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, model.SettingKeyAdminApiKey)
return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
}

View File

@@ -0,0 +1,42 @@
package service
type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
SmtpHost string
SmtpPort int
SmtpUsername string
SmtpPassword string
SmtpFrom string
SmtpFromName string
SmtpUseTLS bool
TurnstileEnabled bool
TurnstileSiteKey string
TurnstileSecretKey string
SiteName string
SiteLogo string
SiteSubtitle string
ApiBaseUrl string
ContactInfo string
DocUrl string
DefaultConcurrency int
DefaultBalance float64
}
type PublicSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
TurnstileEnabled bool
TurnstileSiteKey string
SiteName string
SiteLogo string
SiteSubtitle string
ApiBaseUrl string
ContactInfo string
DocUrl string
Version string
}

View File

@@ -7,7 +7,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -48,7 +47,7 @@ type AssignSubscriptionInput struct {
}
// AssignSubscription 分配订阅给用户(不允许重复分配)
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
// 检查分组是否存在且为订阅类型
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
if err != nil {
@@ -91,7 +90,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
// - 已过期:从当前时间开始计算新的过期时间,并激活订阅
//
// 如果没有订阅:创建新订阅
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
// 检查分组是否存在且为订阅类型
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
if err != nil {
@@ -132,8 +131,8 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 如果订阅已过期或被暂停恢复为active状态
if existingSub.Status != model.SubscriptionStatusActive {
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
if existingSub.Status != SubscriptionStatusActive {
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, SubscriptionStatusActive); err != nil {
return nil, false, fmt.Errorf("update subscription status: %w", err)
}
}
@@ -185,19 +184,19 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// createSubscription 创建新订阅(内部方法)
func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
validityDays := input.ValidityDays
if validityDays <= 0 {
validityDays = 30
}
now := time.Now()
sub := &model.UserSubscription{
sub := &UserSubscription{
UserID: input.UserID,
GroupID: input.GroupID,
StartsAt: now,
ExpiresAt: now.AddDate(0, 0, validityDays),
Status: model.SubscriptionStatusActive,
Status: SubscriptionStatusActive,
AssignedAt: now,
Notes: input.Notes,
CreatedAt: now,
@@ -229,14 +228,14 @@ type BulkAssignSubscriptionInput struct {
type BulkAssignResult struct {
SuccessCount int
FailedCount int
Subscriptions []model.UserSubscription
Subscriptions []UserSubscription
Errors []string
}
// BulkAssignSubscription 批量分配订阅
func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) {
result := &BulkAssignResult{
Subscriptions: make([]model.UserSubscription, 0),
Subscriptions: make([]UserSubscription, 0),
Errors: make([]string, 0),
}
@@ -286,7 +285,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
}
// ExtendSubscription 延长订阅
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) {
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*UserSubscription, error) {
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil {
return nil, ErrSubscriptionNotFound
@@ -299,8 +298,8 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
}
// 如果订阅已过期恢复为active状态
if sub.Status == model.SubscriptionStatusExpired {
if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
if sub.Status == SubscriptionStatusExpired {
if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, SubscriptionStatusActive); err != nil {
return nil, err
}
}
@@ -319,12 +318,12 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
}
// GetByID 根据ID获取订阅
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubscription, error) {
return s.userSubRepo.GetByID(ctx, id)
}
// GetActiveSubscription 获取用户对特定分组的有效订阅
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) {
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
@@ -333,7 +332,7 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
}
// ListUserSubscriptions 获取用户的所有订阅
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]UserSubscription, error) {
subs, err := s.userSubRepo.ListByUserID(ctx, userID)
if err != nil {
return nil, err
@@ -343,7 +342,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID
}
// ListActiveUserSubscriptions 获取用户的所有有效订阅
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]UserSubscription, error) {
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, err
@@ -353,7 +352,7 @@ func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, u
}
// ListGroupSubscriptions 获取分组的所有订阅
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *pagination.PaginationResult, error) {
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
subs, pag, err := s.userSubRepo.ListByGroupID(ctx, groupID, params)
if err != nil {
@@ -364,7 +363,7 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
}
// List 获取所有订阅(分页,支持筛选)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status)
if err != nil {
@@ -376,7 +375,7 @@ func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, user
// normalizeExpiredWindows 将已过期窗口的数据清零(仅影响返回数据,不影响数据库)
// 这确保前端显示正确的当前窗口状态,而不是过期窗口的历史数据
func normalizeExpiredWindows(subs []model.UserSubscription) {
func normalizeExpiredWindows(subs []UserSubscription) {
for i := range subs {
sub := &subs[i]
// 日窗口过期:清零展示数据
@@ -403,7 +402,7 @@ func startOfDay(t time.Time) time.Time {
}
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *model.UserSubscription) error {
func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *UserSubscription) error {
if sub.IsWindowActivated() {
return nil
}
@@ -414,7 +413,7 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *m
}
// CheckAndResetWindows 检查并重置过期的窗口
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error {
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error {
// 使用当天零点作为新窗口起始时间
windowStart := startOfDay(time.Now())
needsInvalidateCache := false
@@ -458,7 +457,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod
}
// CheckUsageLimits 检查使用限额(返回错误如果超限)
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.UserSubscription, group *model.Group, additionalCost float64) error {
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error {
if !sub.CheckDailyLimit(group, additionalCost) {
return ErrDailyLimitExceeded
}
@@ -620,16 +619,16 @@ func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (i
}
// ValidateSubscription 验证订阅是否有效
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *model.UserSubscription) error {
if sub.Status == model.SubscriptionStatusExpired {
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *UserSubscription) error {
if sub.Status == SubscriptionStatusExpired {
return ErrSubscriptionExpired
}
if sub.Status == model.SubscriptionStatusSuspended {
if sub.Status == SubscriptionStatusSuspended {
return ErrSubscriptionSuspended
}
if sub.IsExpired() {
// 更新状态
_ = s.userSubRepo.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
_ = s.userSubRepo.UpdateStatus(ctx, sub.ID, SubscriptionStatusExpired)
return ErrSubscriptionExpired
}
return nil

View File

@@ -8,7 +8,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// TokenRefreshService OAuth token自动刷新服务
@@ -142,19 +141,19 @@ func (s *TokenRefreshService) processRefresh() {
// listActiveAccounts 获取所有active状态的账号
// 使用ListActive确保刷新所有活跃账号的token包括临时禁用的
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]model.Account, error) {
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]Account, error) {
return s.accountRepo.ListActive(ctx)
}
// refreshWithRetry 带重试的刷新
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *model.Account, refresher TokenRefresher) error {
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher) error {
var lastErr error
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
newCredentials, err := refresher.Refresh(ctx, account)
if err == nil {
// 刷新成功更新账号credentials
account.Credentials = model.JSONB(newCredentials)
account.Credentials = newCredentials
if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
}

View File

@@ -4,22 +4,20 @@ import (
"context"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// TokenRefresher 定义平台特定的token刷新策略接口
// 通过此接口可以扩展支持不同平台Anthropic/OpenAI/Gemini
type TokenRefresher interface {
// CanRefresh 检查此刷新器是否能处理指定账号
CanRefresh(account *model.Account) bool
CanRefresh(account *Account) bool
// NeedsRefresh 检查账号的token是否需要刷新
NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool
NeedsRefresh(account *Account, refreshWindow time.Duration) bool
// Refresh 执行token刷新返回更新后的credentials
// 注意返回的map应该保留原有credentials中的所有字段只更新token相关字段
Refresh(ctx context.Context, account *model.Account) (map[string]any, error)
Refresh(ctx context.Context, account *Account) (map[string]any, error)
}
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
@@ -37,14 +35,14 @@ func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher {
// CanRefresh 检查是否能处理此账号
// 只处理 anthropic 平台的 oauth 类型账号
// setup-token 虽然也是OAuth但有效期1年不需要频繁刷新
func (r *ClaudeTokenRefresher) CanRefresh(account *model.Account) bool {
return account.Platform == model.PlatformAnthropic &&
account.Type == model.AccountTypeOAuth
func (r *ClaudeTokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformAnthropic &&
account.Type == AccountTypeOAuth
}
// NeedsRefresh 检查token是否需要刷新
// 基于 expires_at 字段判断是否在刷新窗口内
func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
func (r *ClaudeTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
expiresAtStr := account.GetCredential("expires_at")
if expiresAtStr == "" {
return false
@@ -61,7 +59,7 @@ func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
// Refresh 执行token刷新
// 保留原有credentials中的所有字段只更新token相关字段
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) {
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
return nil, err
@@ -103,14 +101,14 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService) *OpenAIToke
// CanRefresh 检查是否能处理此账号
// 只处理 openai 平台的 oauth 类型账号
func (r *OpenAITokenRefresher) CanRefresh(account *model.Account) bool {
return account.Platform == model.PlatformOpenAI &&
account.Type == model.AccountTypeOAuth
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformOpenAI &&
account.Type == AccountTypeOAuth
}
// NeedsRefresh 检查token是否需要刷新
// 基于 expires_at 字段判断是否在刷新窗口内
func (r *OpenAITokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
expiresAt := account.GetOpenAITokenExpiresAt()
if expiresAt == nil {
return false
@@ -121,7 +119,7 @@ func (r *OpenAITokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
// Refresh 执行token刷新
// 保留原有credentials中的所有字段只更新token相关字段
func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) {
func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.openaiOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return nil, err

View File

@@ -0,0 +1,53 @@
package service
import "time"
const (
BillingTypeBalance int8 = 0 // 钱包余额
BillingTypeSubscription int8 = 1 // 订阅套餐
)
type UsageLog struct {
ID int64
UserID int64
ApiKeyID int64
AccountID int64
RequestID string
Model string
GroupID *int64
SubscriptionID *int64
InputTokens int
OutputTokens int
CacheCreationTokens int
CacheReadTokens int
CacheCreation5mTokens int
CacheCreation1hTokens int
InputCost float64
OutputCost float64
CacheCreationCost float64
CacheReadCost float64
TotalCost float64
ActualCost float64
RateMultiplier float64
BillingType int8
Stream bool
DurationMs *int
FirstTokenMs *int
CreatedAt time.Time
User *User
ApiKey *ApiKey
Account *Account
Group *Group
Subscription *UserSubscription
}
func (u *UsageLog) TotalTokens() int {
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
}

View File

@@ -6,7 +6,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
@@ -66,7 +65,7 @@ func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository) *Usa
}
// Create 创建使用日志
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*model.UsageLog, error) {
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*UsageLog, error) {
// 验证用户存在
_, err := s.userRepo.GetByID(ctx, req.UserID)
if err != nil {
@@ -74,7 +73,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
}
// 创建使用日志
usageLog := &model.UsageLog{
usageLog := &UsageLog{
UserID: req.UserID,
ApiKeyID: req.ApiKeyID,
AccountID: req.AccountID,
@@ -112,7 +111,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
}
// GetByID 根据ID获取使用日志
func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) {
log, err := s.usageRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get usage log: %w", err)
@@ -121,7 +120,7 @@ func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog,
}
// ListByUser 获取用户的使用日志列表
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
@@ -130,7 +129,7 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
}
// ListByApiKey 获取API Key的使用日志列表
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
@@ -139,7 +138,7 @@ func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params
}
// ListByAccount 获取账号的使用日志列表
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
@@ -243,7 +242,7 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
}
// calculateStats 计算统计数据
func (s *UsageService) calculateStats(logs []model.UsageLog) *UsageStats {
func (s *UsageService) calculateStats(logs []UsageLog) *UsageStats {
stats := &UsageStats{}
for _, log := range logs {
@@ -313,7 +312,7 @@ func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs [
}
// ListWithFilters lists usage logs with admin filters.
func (s *UsageService) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (s *UsageService) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error) {
logs, result, err := s.usageRepo.ListWithFilters(ctx, params, filters)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs with filters: %w", err)

View File

@@ -0,0 +1,63 @@
package service
import (
"time"
"golang.org/x/crypto/bcrypt"
)
type User struct {
ID int64
Email string
Username string
Wechat string
Notes string
PasswordHash string
Role string
Balance float64
Concurrency int
Status string
AllowedGroups []int64
CreatedAt time.Time
UpdatedAt time.Time
ApiKeys []ApiKey
Subscriptions []UserSubscription
}
func (u *User) IsAdmin() bool {
return u.Role == RoleAdmin
}
func (u *User) IsActive() bool {
return u.Status == StatusActive
}
// CanBindGroup checks whether a user can bind to a given group.
// For standard groups:
// - If AllowedGroups is non-empty, only allow binding to IDs in that list.
// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group.
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
if len(u.AllowedGroups) > 0 {
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
}
return !isExclusive
}
func (u *User) SetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.PasswordHash = string(hash)
return nil
}
func (u *User) CheckPassword(password string) bool {
return bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) == nil
}

View File

@@ -5,9 +5,7 @@ import (
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"golang.org/x/crypto/bcrypt"
)
var (
@@ -17,15 +15,15 @@ var (
)
type UserRepository interface {
Create(ctx context.Context, user *model.User) error
GetByID(ctx context.Context, id int64) (*model.User, error)
GetByEmail(ctx context.Context, email string) (*model.User, error)
GetFirstAdmin(ctx context.Context) (*model.User, error)
Update(ctx context.Context, user *model.User) error
Create(ctx context.Context, user *User) error
GetByID(ctx context.Context, id int64) (*User, error)
GetByEmail(ctx context.Context, email string) (*User, error)
GetFirstAdmin(ctx context.Context) (*User, error)
Update(ctx context.Context, user *User) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]User, *pagination.PaginationResult, error)
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
@@ -61,7 +59,7 @@ func NewUserService(userRepo UserRepository) *UserService {
}
// GetFirstAdmin 获取首个管理员用户(用于 Admin API Key 认证)
func (s *UserService) GetFirstAdmin(ctx context.Context) (*model.User, error) {
func (s *UserService) GetFirstAdmin(ctx context.Context) (*User, error) {
admin, err := s.userRepo.GetFirstAdmin(ctx)
if err != nil {
return nil, fmt.Errorf("get first admin: %w", err)
@@ -70,7 +68,7 @@ func (s *UserService) GetFirstAdmin(ctx context.Context) (*model.User, error) {
}
// GetProfile 获取用户资料
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) {
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
@@ -79,7 +77,7 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User
}
// UpdateProfile 更新用户资料
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) {
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
@@ -125,18 +123,14 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
}
// 验证当前密码
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
if !user.CheckPassword(req.CurrentPassword) {
return ErrPasswordIncorrect
}
// 生成新密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
if err := user.SetPassword(req.NewPassword); err != nil {
return fmt.Errorf("set password: %w", err)
}
user.PasswordHash = string(hashedPassword)
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
@@ -145,7 +139,7 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
}
// GetByID 根据ID获取用户管理员功能
func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) {
func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
@@ -154,7 +148,7 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error
}
// List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list users: %w", err)

View File

@@ -0,0 +1,124 @@
package service
import "time"
type UserSubscription struct {
ID int64
UserID int64
GroupID int64
StartsAt time.Time
ExpiresAt time.Time
Status string
DailyWindowStart *time.Time
WeeklyWindowStart *time.Time
MonthlyWindowStart *time.Time
DailyUsageUSD float64
WeeklyUsageUSD float64
MonthlyUsageUSD float64
AssignedBy *int64
AssignedAt time.Time
Notes string
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
AssignedByUser *User
}
func (s *UserSubscription) IsActive() bool {
return s.Status == SubscriptionStatusActive && time.Now().Before(s.ExpiresAt)
}
func (s *UserSubscription) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
}
func (s *UserSubscription) DaysRemaining() int {
if s.IsExpired() {
return 0
}
return int(time.Until(s.ExpiresAt).Hours() / 24)
}
func (s *UserSubscription) IsWindowActivated() bool {
return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil
}
func (s *UserSubscription) NeedsDailyReset() bool {
if s.DailyWindowStart == nil {
return false
}
return time.Since(*s.DailyWindowStart) >= 24*time.Hour
}
func (s *UserSubscription) NeedsWeeklyReset() bool {
if s.WeeklyWindowStart == nil {
return false
}
return time.Since(*s.WeeklyWindowStart) >= 7*24*time.Hour
}
func (s *UserSubscription) NeedsMonthlyReset() bool {
if s.MonthlyWindowStart == nil {
return false
}
return time.Since(*s.MonthlyWindowStart) >= 30*24*time.Hour
}
func (s *UserSubscription) DailyResetTime() *time.Time {
if s.DailyWindowStart == nil {
return nil
}
t := s.DailyWindowStart.Add(24 * time.Hour)
return &t
}
func (s *UserSubscription) WeeklyResetTime() *time.Time {
if s.WeeklyWindowStart == nil {
return nil
}
t := s.WeeklyWindowStart.Add(7 * 24 * time.Hour)
return &t
}
func (s *UserSubscription) MonthlyResetTime() *time.Time {
if s.MonthlyWindowStart == nil {
return nil
}
t := s.MonthlyWindowStart.Add(30 * 24 * time.Hour)
return &t
}
func (s *UserSubscription) CheckDailyLimit(group *Group, additionalCost float64) bool {
if !group.HasDailyLimit() {
return true
}
return s.DailyUsageUSD+additionalCost <= *group.DailyLimitUSD
}
func (s *UserSubscription) CheckWeeklyLimit(group *Group, additionalCost float64) bool {
if !group.HasWeeklyLimit() {
return true
}
return s.WeeklyUsageUSD+additionalCost <= *group.WeeklyLimitUSD
}
func (s *UserSubscription) CheckMonthlyLimit(group *Group, additionalCost float64) bool {
if !group.HasMonthlyLimit() {
return true
}
return s.MonthlyUsageUSD+additionalCost <= *group.MonthlyLimitUSD
}
func (s *UserSubscription) CheckAllLimits(group *Group, additionalCost float64) (daily, weekly, monthly bool) {
daily = s.CheckDailyLimit(group, additionalCost)
weekly = s.CheckWeeklyLimit(group, additionalCost)
monthly = s.CheckMonthlyLimit(group, additionalCost)
return
}

View File

@@ -4,22 +4,21 @@ import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
type UserSubscriptionRepository interface {
Create(ctx context.Context, sub *model.UserSubscription) error
GetByID(ctx context.Context, id int64) (*model.UserSubscription, error)
GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
Update(ctx context.Context, sub *model.UserSubscription) error
Create(ctx context.Context, sub *UserSubscription) error
GetByID(ctx context.Context, id int64) (*UserSubscription, error)
GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*UserSubscription, error)
GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*UserSubscription, error)
Update(ctx context.Context, sub *UserSubscription) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error)
ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error)
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error

View File

@@ -10,10 +10,10 @@ import (
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"golang.org/x/crypto/bcrypt"
"gopkg.in/yaml.v3"
"gorm.io/driver/postgres"
"gorm.io/gorm"
@@ -271,8 +271,7 @@ func initializeDatabase(cfg *SetupConfig) error {
}
}()
// 使用 model 包的 AutoMigrate,确保模型定义统一
return model.AutoMigrate(db)
return repository.AutoMigrate(db)
}
func createAdminUser(cfg *SetupConfig) error {
@@ -299,29 +298,28 @@ func createAdminUser(cfg *SetupConfig) error {
// Check if admin already exists
var count int64
db.Model(&model.User{}).Where("role = ?", "admin").Count(&count)
if err := db.Table("users").Where("role = ?", service.RoleAdmin).Count(&count).Error; err != nil {
return err
}
if count > 0 {
return nil // Admin already exists
}
// Hash password
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(cfg.Admin.Password), bcrypt.DefaultCost)
if err != nil {
admin := &service.User{
Email: cfg.Admin.Email,
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 0,
Concurrency: 5,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := admin.SetPassword(cfg.Admin.Password); err != nil {
return err
}
// Create admin user
admin := &model.User{
Email: cfg.Admin.Email,
PasswordHash: string(hashedPassword),
Role: model.RoleAdmin,
Status: model.StatusActive,
Balance: 0,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
return db.Create(admin).Error
return repository.NewUserRepository(db).Create(context.Background(), admin)
}
func writeConfigFile(cfg *SetupConfig) error {