refactor: 自定义业务错误

This commit is contained in:
Forest
2025-12-25 20:52:47 +08:00
parent f51ad2e126
commit eeaff85e47
60 changed files with 1222 additions and 622 deletions

View File

@@ -19,14 +19,16 @@ linters:
files: files:
- "**/internal/service/**" - "**/internal/service/**"
deny: deny:
- pkg: sub2api/internal/repository - pkg: github.com/Wei-Shaw/sub2api/internal/repository
desc: "service must not import repository" desc: "service must not import repository"
- pkg: gorm.io/gorm
desc: "service must not import gorm"
handler-no-repository: handler-no-repository:
list-mode: original list-mode: original
files: files:
- "**/internal/handler/**" - "**/internal/handler/**"
deny: deny:
- pkg: sub2api/internal/repository - pkg: github.com/Wei-Shaw/sub2api/internal/repository
desc: "handler must not import repository" desc: "handler must not import repository"
errcheck: errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`. # Report about not checking of errors in type assertions: `a := b.(MyStruct)`.

View File

@@ -117,7 +117,7 @@ func (h *AccountHandler) List(c *gin.Context) {
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search) accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list accounts: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -156,7 +156,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
account, err := h.adminService.GetAccount(c.Request.Context(), accountID) account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.NotFound(c, "Account not found") response.ErrorFrom(c, err)
return return
} }
@@ -184,7 +184,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to create account: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -218,7 +218,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update account: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -236,7 +236,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteAccount(c.Request.Context(), accountID) err = h.adminService.DeleteAccount(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to delete account: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -297,7 +297,7 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
SyncProxies: syncProxies, SyncProxies: syncProxies,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Sync failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -332,7 +332,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
// Use OpenAI OAuth service to refresh token // Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account) tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil { if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -349,7 +349,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
// Use Anthropic/Claude OAuth service to refresh token // Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account) tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil { if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -372,7 +372,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
Credentials: newCredentials, Credentials: newCredentials,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update account credentials: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -403,7 +403,7 @@ func (h *AccountHandler) GetStats(c *gin.Context) {
stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime) stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get account stats: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -421,7 +421,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
account, err := h.adminService.ClearAccountError(c.Request.Context(), accountID) account, err := h.adminService.ClearAccountError(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to clear error: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -570,7 +570,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
Extra: req.Extra, Extra: req.Extra,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to bulk update accounts: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -595,7 +595,7 @@ func (h *OAuthHandler) GenerateAuthURL(c *gin.Context) {
result, err := h.oauthService.GenerateAuthURL(c.Request.Context(), req.ProxyID) result, err := h.oauthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -613,7 +613,7 @@ func (h *OAuthHandler) GenerateSetupTokenURL(c *gin.Context) {
result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID) result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to generate setup token URL: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -642,7 +642,7 @@ func (h *OAuthHandler) ExchangeCode(c *gin.Context) {
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -664,7 +664,7 @@ func (h *OAuthHandler) ExchangeSetupTokenCode(c *gin.Context) {
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -692,7 +692,7 @@ func (h *OAuthHandler) CookieAuth(c *gin.Context) {
Scope: "full", Scope: "full",
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -714,7 +714,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
Scope: "inference", Scope: "inference",
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -732,7 +732,7 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID) usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -750,7 +750,7 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
err = h.rateLimitService.ClearRateLimit(c.Request.Context(), accountID) err = h.rateLimitService.ClearRateLimit(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to clear rate limit: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -768,7 +768,7 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
stats, err := h.accountUsageService.GetTodayStats(c.Request.Context(), accountID) stats, err := h.accountUsageService.GetTodayStats(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get today stats: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -797,7 +797,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
account, err := h.adminService.SetAccountSchedulable(c.Request.Context(), accountID, req.Schedulable) account, err := h.adminService.SetAccountSchedulable(c.Request.Context(), accountID, req.Schedulable)
if err != nil { if err != nil {
response.InternalError(c, "Failed to update schedulable status: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -65,7 +65,7 @@ func (h *GroupHandler) List(c *gin.Context) {
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive) groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list groups: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -87,7 +87,7 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
} }
if err != nil { if err != nil {
response.InternalError(c, "Failed to get groups: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -105,7 +105,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
group, err := h.adminService.GetGroup(c.Request.Context(), groupID) group, err := h.adminService.GetGroup(c.Request.Context(), groupID)
if err != nil { if err != nil {
response.NotFound(c, "Group not found") response.ErrorFrom(c, err)
return return
} }
@@ -133,7 +133,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
MonthlyLimitUSD: req.MonthlyLimitUSD, MonthlyLimitUSD: req.MonthlyLimitUSD,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to create group: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -168,7 +168,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
MonthlyLimitUSD: req.MonthlyLimitUSD, MonthlyLimitUSD: req.MonthlyLimitUSD,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update group: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -186,7 +186,7 @@ func (h *GroupHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteGroup(c.Request.Context(), groupID) err = h.adminService.DeleteGroup(c.Request.Context(), groupID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to delete group: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -225,7 +225,7 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize) keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get group API keys: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -40,7 +40,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI) result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
if err != nil { if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -71,7 +71,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -103,7 +103,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL) tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
if err != nil { if err != nil {
response.BadRequest(c, "Failed to refresh token: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -122,7 +122,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
// Get account // Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID) account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil { if err != nil {
response.NotFound(c, "Account not found") response.ErrorFrom(c, err)
return return
} }
@@ -141,7 +141,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
// Use OpenAI OAuth service to refresh token // Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account) tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil { if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -159,7 +159,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
Credentials: newCredentials, Credentials: newCredentials,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update account credentials: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -192,7 +192,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
ProxyID: req.ProxyID, ProxyID: req.ProxyID,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -220,7 +220,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to create account: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -53,7 +53,7 @@ func (h *ProxyHandler) List(c *gin.Context) {
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search) proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list proxies: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -69,7 +69,7 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
if withCount { if withCount {
proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context()) proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
response.Success(c, proxies) response.Success(c, proxies)
@@ -78,7 +78,7 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
proxies, err := h.adminService.GetAllProxies(c.Request.Context()) proxies, err := h.adminService.GetAllProxies(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -96,7 +96,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID) proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID)
if err != nil { if err != nil {
response.NotFound(c, "Proxy not found") response.ErrorFrom(c, err)
return return
} }
@@ -121,7 +121,7 @@ func (h *ProxyHandler) Create(c *gin.Context) {
Password: strings.TrimSpace(req.Password), Password: strings.TrimSpace(req.Password),
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to create proxy: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -153,7 +153,7 @@ func (h *ProxyHandler) Update(c *gin.Context) {
Status: strings.TrimSpace(req.Status), Status: strings.TrimSpace(req.Status),
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update proxy: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -171,7 +171,7 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteProxy(c.Request.Context(), proxyID) err = h.adminService.DeleteProxy(c.Request.Context(), proxyID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to delete proxy: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -189,7 +189,7 @@ func (h *ProxyHandler) Test(c *gin.Context) {
result, err := h.adminService.TestProxy(c.Request.Context(), proxyID) result, err := h.adminService.TestProxy(c.Request.Context(), proxyID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to test proxy: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -229,7 +229,7 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize) accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get proxy accounts: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -272,7 +272,7 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
// Check for duplicates (same host, port, username, password) // Check for duplicates (same host, port, username, password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password) exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
if err != nil { if err != nil {
response.InternalError(c, "Failed to check proxy existence: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -43,7 +43,7 @@ func (h *RedeemHandler) List(c *gin.Context) {
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search) codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list redeem codes: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -61,7 +61,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID) code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID)
if err != nil { if err != nil {
response.NotFound(c, "Redeem code not found") response.ErrorFrom(c, err)
return return
} }
@@ -85,7 +85,7 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
ValidityDays: req.ValidityDays, ValidityDays: req.ValidityDays,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to generate redeem codes: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -103,7 +103,7 @@ func (h *RedeemHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID) err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to delete redeem code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -123,7 +123,7 @@ func (h *RedeemHandler) BatchDelete(c *gin.Context) {
deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs) deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs)
if err != nil { if err != nil {
response.InternalError(c, "Failed to batch delete redeem codes: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -144,7 +144,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID) code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to expire redeem code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -178,7 +178,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
// Get all codes without pagination (use large page size) // Get all codes without pagination (use large page size)
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "") codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
if err != nil { if err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -27,7 +27,7 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
func (h *SettingHandler) GetSettings(c *gin.Context) { func (h *SettingHandler) GetSettings(c *gin.Context) {
settings, err := h.settingService.GetAllSettings(c.Request.Context()) settings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -111,14 +111,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
response.InternalError(c, "Failed to update settings: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
// 重新获取设置返回 // 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context()) updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to get updated settings: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -166,7 +166,7 @@ func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
err := h.emailService.TestSmtpConnectionWithConfig(config) err := h.emailService.TestSmtpConnectionWithConfig(config)
if err != nil { if err != nil {
response.BadRequest(c, "SMTP connection test failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -252,7 +252,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
` `
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil { if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
response.BadRequest(c, "Failed to send test email: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -264,7 +264,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) { func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context()) maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to get admin API key status: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -279,7 +279,7 @@ func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) { func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context()) key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to generate admin API key: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -292,7 +292,7 @@ func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
// DELETE /api/v1/admin/settings/admin-api-key // DELETE /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) { func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil { if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
response.InternalError(c, "Failed to delete admin API key: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -78,7 +78,7 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status) subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -96,7 +96,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID) subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
if err != nil { if err != nil {
response.NotFound(c, "Subscription not found") response.ErrorFrom(c, err)
return return
} }
@@ -141,7 +141,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
Notes: req.Notes, Notes: req.Notes,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to assign subscription: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -168,7 +168,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
Notes: req.Notes, Notes: req.Notes,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to bulk assign subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -192,7 +192,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days) subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
if err != nil { if err != nil {
response.InternalError(c, "Failed to extend subscription: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -210,7 +210,7 @@ func (h *SubscriptionHandler) Revoke(c *gin.Context) {
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID) err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to revoke subscription: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -230,7 +230,7 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize) subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list group subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -248,7 +248,7 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID) subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list user subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -90,7 +90,7 @@ func (h *UsageHandler) List(c *gin.Context) {
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters) records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -158,7 +158,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if apiKeyID > 0 { if apiKeyID > 0 {
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime) stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
response.Success(c, stats) response.Success(c, stats)
@@ -168,7 +168,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if userID > 0 { if userID > 0 {
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime) stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
response.Success(c, stats) response.Success(c, stats)
@@ -178,7 +178,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
// Get global stats // Get global stats
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime) stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -197,7 +197,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
// Limit to 30 results // Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword) users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
if err != nil { if err != nil {
response.InternalError(c, "Failed to search users: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -236,7 +236,7 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30) keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
if err != nil { if err != nil {
response.InternalError(c, "Failed to search API keys: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -64,7 +64,7 @@ func (h *UserHandler) List(c *gin.Context) {
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search) users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list users: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -82,7 +82,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
user, err := h.adminService.GetUser(c.Request.Context(), userID) user, err := h.adminService.GetUser(c.Request.Context(), userID)
if err != nil { if err != nil {
response.NotFound(c, "User not found") response.ErrorFrom(c, err)
return return
} }
@@ -109,7 +109,7 @@ func (h *UserHandler) Create(c *gin.Context) {
AllowedGroups: req.AllowedGroups, AllowedGroups: req.AllowedGroups,
}) })
if err != nil { if err != nil {
response.BadRequest(c, "Failed to create user: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -144,7 +144,7 @@ func (h *UserHandler) Update(c *gin.Context) {
AllowedGroups: req.AllowedGroups, AllowedGroups: req.AllowedGroups,
}) })
if err != nil { if err != nil {
response.InternalError(c, "Failed to update user: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -162,7 +162,7 @@ func (h *UserHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteUser(c.Request.Context(), userID) err = h.adminService.DeleteUser(c.Request.Context(), userID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to delete user: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -186,7 +186,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes) user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
if err != nil { if err != nil {
response.InternalError(c, "Failed to update balance: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -206,7 +206,7 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize) keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get user API keys: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -226,7 +226,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period) stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get user usage: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -57,7 +57,7 @@ func (h *APIKeyHandler) List(c *gin.Context) {
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params) keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list API keys: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -87,7 +87,7 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID) key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID)
if err != nil { if err != nil {
response.NotFound(c, "API key not found") response.ErrorFrom(c, err)
return return
} }
@@ -128,7 +128,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
} }
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq) key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq)
if err != nil { if err != nil {
response.InternalError(c, "Failed to create API key: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -173,7 +173,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq) key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq)
if err != nil { if err != nil {
response.InternalError(c, "Failed to update API key: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -203,7 +203,7 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID) err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to delete API key: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -227,7 +227,7 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID) groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get available groups: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -66,14 +66,14 @@ func (h *AuthHandler) Register(c *gin.Context) {
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过) // Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if req.VerifyCode == "" { if req.VerifyCode == "" {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil { if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
} }
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode) token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
if err != nil { if err != nil {
response.BadRequest(c, "Registration failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -95,13 +95,13 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
// Turnstile 验证 // Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil { if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email) result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
if err != nil { if err != nil {
response.BadRequest(c, "Failed to send verification code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -122,13 +122,13 @@ func (h *AuthHandler) Login(c *gin.Context) {
// Turnstile 验证 // Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil { if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password) token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
if err != nil { if err != nil {
response.Unauthorized(c, "Login failed: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -57,7 +57,7 @@ func (h *RedeemHandler) Redeem(c *gin.Context) {
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code) result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code)
if err != nil { if err != nil {
response.BadRequest(c, "Failed to redeem code: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -84,7 +84,7 @@ func (h *RedeemHandler) GetHistory(c *gin.Context) {
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit) codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get history: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -26,7 +26,7 @@ func NewSettingHandler(settingService *service.SettingService, version string) *
func (h *SettingHandler) GetPublicSettings(c *gin.Context) { func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
settings, err := h.settingService.GetPublicSettings(c.Request.Context()) settings, err := h.settingService.GetPublicSettings(c.Request.Context())
if err != nil { if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -58,7 +58,7 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -82,7 +82,7 @@ func (h *SubscriptionHandler) GetActive(c *gin.Context) {
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get active subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -107,7 +107,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
// Get all active subscriptions with progress // Get all active subscriptions with progress
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -146,7 +146,7 @@ func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
// Get all active subscriptions // Get all active subscriptions
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -55,7 +55,7 @@ func (h *UsageHandler) List(c *gin.Context) {
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation // [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id) apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
if err != nil { if err != nil {
response.NotFound(c, "API key not found") response.ErrorFrom(c, err)
return return
} }
if apiKey.UserID != user.ID { if apiKey.UserID != user.ID {
@@ -77,7 +77,7 @@ func (h *UsageHandler) List(c *gin.Context) {
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params) records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params)
} }
if err != nil { if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -107,7 +107,7 @@ func (h *UsageHandler) GetByID(c *gin.Context) {
record, err := h.usageService.GetByID(c.Request.Context(), usageID) record, err := h.usageService.GetByID(c.Request.Context(), usageID)
if err != nil { if err != nil {
response.NotFound(c, "Usage record not found") response.ErrorFrom(c, err)
return return
} }
@@ -204,7 +204,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime) stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime)
} }
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -259,7 +259,7 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID) stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get dashboard statistics") response.ErrorFrom(c, err)
return return
} }
@@ -286,7 +286,7 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity) trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get usage trend") response.ErrorFrom(c, err)
return return
} }
@@ -317,7 +317,7 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime) stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get model statistics") response.ErrorFrom(c, err)
return return
} }
@@ -362,7 +362,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
// Verify ownership of all requested API keys // Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000}) userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
if err != nil { if err != nil {
response.InternalError(c, "Failed to verify API key ownership") response.ErrorFrom(c, err)
return return
} }
@@ -386,7 +386,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs) stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get API key usage stats") response.ErrorFrom(c, err)
return return
} }

View File

@@ -49,7 +49,7 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
userData, err := h.userService.GetByID(c.Request.Context(), user.ID) userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
if err != nil { if err != nil {
response.InternalError(c, "Failed to get user profile: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -86,7 +86,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
} }
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq) err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq)
if err != nil { if err != nil {
response.BadRequest(c, "Failed to change password: "+err.Error()) response.ErrorFrom(c, err)
return return
} }
@@ -120,7 +120,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
} }
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq) updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq)
if err != nil { if err != nil {
response.BadRequest(c, "Failed to update profile: "+err.Error()) response.ErrorFrom(c, err)
return return
} }

View File

@@ -0,0 +1,157 @@
package errors
import (
"errors"
"fmt"
"net/http"
)
const (
UnknownCode = http.StatusInternalServerError
UnknownReason = ""
)
type Status struct {
Code int32 `json:"code"`
Reason string `json:"reason,omitempty"`
Message string `json:"message"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// ApplicationError is the standard error type used to control HTTP responses.
//
// Code is expected to be an HTTP status code (e.g. 400/401/403/404/409/500).
type ApplicationError struct {
Status
cause error
}
// Error is kept for backwards compatibility within this package.
type Error = ApplicationError
func (e *ApplicationError) Error() string {
if e == nil {
return "<nil>"
}
if e.cause == nil {
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v", e.Code, e.Reason, e.Message, e.Metadata)
}
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v cause=%v", e.Code, e.Reason, e.Message, e.Metadata, e.cause)
}
// Unwrap provides compatibility for Go 1.13 error chains.
func (e *ApplicationError) Unwrap() error { return e.cause }
// Is matches each error in the chain with the target value.
func (e *ApplicationError) Is(err error) bool {
if se := new(ApplicationError); errors.As(err, &se) {
return se.Code == e.Code && se.Reason == e.Reason
}
return false
}
// WithCause attaches the underlying cause of the error.
func (e *ApplicationError) WithCause(cause error) *ApplicationError {
err := Clone(e)
err.cause = cause
return err
}
// WithMetadata deep-copies the given metadata map.
func (e *ApplicationError) WithMetadata(md map[string]string) *ApplicationError {
err := Clone(e)
if md == nil {
err.Metadata = nil
return err
}
err.Metadata = make(map[string]string, len(md))
for k, v := range md {
err.Metadata[k] = v
}
return err
}
// New returns an error object for the code, message.
func New(code int, reason, message string) *ApplicationError {
return &ApplicationError{
Status: Status{
Code: int32(code),
Message: message,
Reason: reason,
},
}
}
// Newf New(code fmt.Sprintf(format, a...))
func Newf(code int, reason, format string, a ...any) *ApplicationError {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Errorf returns an error object for the code, message and error info.
func Errorf(code int, reason, format string, a ...any) error {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Code returns the http code for an error.
// It supports wrapped errors.
func Code(err error) int {
if err == nil {
return http.StatusOK
}
return int(FromError(err).Code)
}
// Reason returns the reason for a particular error.
// It supports wrapped errors.
func Reason(err error) string {
if err == nil {
return UnknownReason
}
return FromError(err).Reason
}
// Message returns the message for a particular error.
// It supports wrapped errors.
func Message(err error) string {
if err == nil {
return ""
}
return FromError(err).Message
}
// Clone deep clone error to a new error.
func Clone(err *ApplicationError) *ApplicationError {
if err == nil {
return nil
}
var metadata map[string]string
if err.Metadata != nil {
metadata = make(map[string]string, len(err.Metadata))
for k, v := range err.Metadata {
metadata[k] = v
}
}
return &ApplicationError{
cause: err.cause,
Status: Status{
Code: err.Code,
Reason: err.Reason,
Message: err.Message,
Metadata: metadata,
},
}
}
// FromError tries to convert an error to *ApplicationError.
// It supports wrapped errors.
func FromError(err error) *ApplicationError {
if err == nil {
return nil
}
if se := new(ApplicationError); errors.As(err, &se) {
return se
}
// Fall back to a generic internal error.
return New(UnknownCode, UnknownReason, err.Error()).WithCause(err)
}

View File

@@ -0,0 +1,168 @@
//go:build unit
package errors
import (
stderrors "errors"
"fmt"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestApplicationError_Basics(t *testing.T) {
tests := []struct {
name string
err *ApplicationError
want Status
wantIs bool
target error
wrapped error
}{
{
name: "new",
err: New(400, "BAD_REQUEST", "invalid input"),
want: Status{
Code: 400,
Reason: "BAD_REQUEST",
Message: "invalid input",
},
},
{
name: "is_matches_code_and_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "UNAUTHORIZED", "ignored message"),
wantIs: true,
},
{
name: "is_does_not_match_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "DIFFERENT", "ignored message"),
wantIs: false,
},
{
name: "from_error_unwraps_wrapped_application_error",
err: New(404, "NOT_FOUND", "missing"),
wrapped: fmt.Errorf("wrap: %w", New(404, "NOT_FOUND", "missing")),
want: Status{
Code: 404,
Reason: "NOT_FOUND",
Message: "missing",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.err != nil {
require.Equal(t, tt.want, tt.err.Status)
}
if tt.target != nil {
require.Equal(t, tt.wantIs, stderrors.Is(tt.err, tt.target))
}
if tt.wrapped != nil {
got := FromError(tt.wrapped)
require.Equal(t, tt.want, got.Status)
}
})
}
}
func TestApplicationError_WithMetadataDeepCopy(t *testing.T) {
tests := []struct {
name string
md map[string]string
}{
{name: "non_nil", md: map[string]string{"a": "1"}},
{name: "nil", md: nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
appErr := BadRequest("BAD_REQUEST", "invalid input").WithMetadata(tt.md)
if tt.md == nil {
require.Nil(t, appErr.Metadata)
return
}
tt.md["a"] = "changed"
require.Equal(t, "1", appErr.Metadata["a"])
})
}
}
func TestFromError_Generic(t *testing.T) {
tests := []struct {
name string
err error
wantCode int32
wantReason string
wantMsg string
}{
{
name: "plain_error",
err: stderrors.New("boom"),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: "boom",
},
{
name: "wrapped_plain_error",
err: fmt.Errorf("wrap: %w", io.EOF),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: "wrap: EOF",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FromError(tt.err)
require.Equal(t, tt.wantCode, got.Code)
require.Equal(t, tt.wantReason, got.Reason)
require.Equal(t, tt.wantMsg, got.Message)
require.Equal(t, tt.err, got.Unwrap())
})
}
}
func TestToHTTP(t *testing.T) {
tests := []struct {
name string
err error
wantStatusCode int
wantBody Status
}{
{
name: "nil_error",
err: nil,
wantStatusCode: http.StatusOK,
wantBody: Status{Code: int32(http.StatusOK)},
},
{
name: "application_error",
err: Forbidden("FORBIDDEN", "no access"),
wantStatusCode: http.StatusForbidden,
wantBody: Status{
Code: int32(http.StatusForbidden),
Reason: "FORBIDDEN",
Message: "no access",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
code, body := ToHTTP(tt.err)
require.Equal(t, tt.wantStatusCode, code)
require.Equal(t, tt.wantBody, body)
})
}
}

View File

@@ -0,0 +1,21 @@
package errors
import "net/http"
// ToHTTP converts an error into an HTTP status code and a JSON-serializable body.
//
// The returned body matches the project's Status shape:
// { code, reason, message, metadata }.
func ToHTTP(err error) (statusCode int, body Status) {
if err == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
appErr := FromError(err)
if appErr == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
cloned := Clone(appErr)
return int(cloned.Code), cloned.Status
}

View File

@@ -0,0 +1,114 @@
// nolint:mnd
package errors
import "net/http"
// BadRequest new BadRequest error that is mapped to a 400 response.
func BadRequest(reason, message string) *ApplicationError {
return New(http.StatusBadRequest, reason, message)
}
// IsBadRequest determines if err is an error which indicates a BadRequest error.
// It supports wrapped errors.
func IsBadRequest(err error) bool {
return Code(err) == http.StatusBadRequest
}
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
func TooManyRequests(reason, message string) *ApplicationError {
return New(http.StatusTooManyRequests, reason, message)
}
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
// It supports wrapped errors.
func IsTooManyRequests(err error) bool {
return Code(err) == http.StatusTooManyRequests
}
// Unauthorized new Unauthorized error that is mapped to a 401 response.
func Unauthorized(reason, message string) *ApplicationError {
return New(http.StatusUnauthorized, reason, message)
}
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
// It supports wrapped errors.
func IsUnauthorized(err error) bool {
return Code(err) == http.StatusUnauthorized
}
// Forbidden new Forbidden error that is mapped to a 403 response.
func Forbidden(reason, message string) *ApplicationError {
return New(http.StatusForbidden, reason, message)
}
// IsForbidden determines if err is an error which indicates a Forbidden error.
// It supports wrapped errors.
func IsForbidden(err error) bool {
return Code(err) == http.StatusForbidden
}
// NotFound new NotFound error that is mapped to a 404 response.
func NotFound(reason, message string) *ApplicationError {
return New(http.StatusNotFound, reason, message)
}
// IsNotFound determines if err is an error which indicates an NotFound error.
// It supports wrapped errors.
func IsNotFound(err error) bool {
return Code(err) == http.StatusNotFound
}
// Conflict new Conflict error that is mapped to a 409 response.
func Conflict(reason, message string) *ApplicationError {
return New(http.StatusConflict, reason, message)
}
// IsConflict determines if err is an error which indicates a Conflict error.
// It supports wrapped errors.
func IsConflict(err error) bool {
return Code(err) == http.StatusConflict
}
// InternalServer new InternalServer error that is mapped to a 500 response.
func InternalServer(reason, message string) *ApplicationError {
return New(http.StatusInternalServerError, reason, message)
}
// IsInternalServer determines if err is an error which indicates an Internal error.
// It supports wrapped errors.
func IsInternalServer(err error) bool {
return Code(err) == http.StatusInternalServerError
}
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
func ServiceUnavailable(reason, message string) *ApplicationError {
return New(http.StatusServiceUnavailable, reason, message)
}
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
// It supports wrapped errors.
func IsServiceUnavailable(err error) bool {
return Code(err) == http.StatusServiceUnavailable
}
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
func GatewayTimeout(reason, message string) *ApplicationError {
return New(http.StatusGatewayTimeout, reason, message)
}
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
// It supports wrapped errors.
func IsGatewayTimeout(err error) bool {
return Code(err) == http.StatusGatewayTimeout
}
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
func ClientClosed(reason, message string) *ApplicationError {
return New(499, reason, message)
}
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
// It supports wrapped errors.
func IsClientClosed(err error) bool {
return Code(err) == 499
}

View File

@@ -3,9 +3,11 @@ package middleware
import ( import (
"context" "context"
"crypto/subtle" "crypto/subtle"
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -96,7 +98,7 @@ func validateJWTForAdmin(
// 验证 JWT token // 验证 JWT token
claims, err := authService.ValidateToken(token) claims, err := authService.ValidateToken(token)
if err != nil { if err != nil {
if err == service.ErrTokenExpired { if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired") AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return false return false
} }

View File

@@ -2,9 +2,11 @@ package middleware
import ( import (
"context" "context"
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -37,7 +39,7 @@ func JWTAuth(authService *service.AuthService, userRepo interface {
// 验证token // 验证token
claims, err := authService.ValidateToken(tokenString) claims, err := authService.ValidateToken(tokenString)
if err != nil { if err != nil {
if err == service.ErrTokenExpired { if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired") AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return return
} }

View File

@@ -4,14 +4,17 @@ import (
"math" "math"
"net/http" "net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// Response 标准API响应格式 // Response 标准API响应格式
type Response struct { type Response struct {
Code int `json:"code"` Code int `json:"code"`
Message string `json:"message"` Message string `json:"message"`
Data any `json:"data,omitempty"` Reason string `json:"reason,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
Data any `json:"data,omitempty"`
} }
// PaginatedData 分页数据格式(匹配前端期望) // PaginatedData 分页数据格式(匹配前端期望)
@@ -44,11 +47,36 @@ func Created(c *gin.Context, data any) {
// Error 返回错误响应 // Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) { func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{ c.JSON(statusCode, Response{
Code: statusCode, Code: statusCode,
Message: message, Message: message,
Reason: "",
Metadata: nil,
}) })
} }
// ErrorWithDetails returns an error response compatible with the existing envelope while
// optionally providing structured error fields (reason/metadata).
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
Reason: reason,
Metadata: metadata,
})
}
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
// It returns true if an error was written.
func ErrorFrom(c *gin.Context, err error) bool {
if err == nil {
return false
}
statusCode, status := infraerrors.ToHTTP(err)
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}
// BadRequest 返回400错误 // BadRequest 返回400错误
func BadRequest(c *gin.Context, message string) { func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, message) Error(c, http.StatusBadRequest, message)

View File

@@ -0,0 +1,171 @@
//go:build unit
package response
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestErrorWithDetails(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
message string
reason string
metadata map[string]string
want Response
}{
{
name: "plain_error",
statusCode: http.StatusBadRequest,
message: "invalid request",
want: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
},
},
{
name: "structured_error",
statusCode: http.StatusForbidden,
message: "no access",
reason: "FORBIDDEN",
metadata: map[string]string{"k": "v"},
want: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"k": "v"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
require.Equal(t, tt.statusCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.want, got)
})
}
}
func TestErrorFrom(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
err error
wantWritten bool
wantHTTPCode int
wantBody Response
}{
{
name: "nil_error",
err: nil,
wantWritten: false,
},
{
name: "application_error",
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
wantWritten: true,
wantHTTPCode: http.StatusForbidden,
wantBody: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"scope": "admin"},
},
},
{
name: "bad_request_error",
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"),
wantWritten: true,
wantHTTPCode: http.StatusBadRequest,
wantBody: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
Reason: "INVALID_REQUEST",
},
},
{
name: "unauthorized_error",
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"),
wantWritten: true,
wantHTTPCode: http.StatusUnauthorized,
wantBody: Response{
Code: http.StatusUnauthorized,
Message: "unauthorized",
Reason: "UNAUTHORIZED",
},
},
{
name: "not_found_error",
err: infraerrors.NotFound("NOT_FOUND", "not found"),
wantWritten: true,
wantHTTPCode: http.StatusNotFound,
wantBody: Response{
Code: http.StatusNotFound,
Message: "not found",
Reason: "NOT_FOUND",
},
},
{
name: "conflict_error",
err: infraerrors.Conflict("CONFLICT", "conflict"),
wantWritten: true,
wantHTTPCode: http.StatusConflict,
wantBody: Response{
Code: http.StatusConflict,
Message: "conflict",
Reason: "CONFLICT",
},
},
{
name: "unknown_error_defaults_to_500",
err: errors.New("boom"),
wantWritten: true,
wantHTTPCode: http.StatusInternalServerError,
wantBody: Response{
Code: http.StatusInternalServerError,
Message: "boom",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
written := ErrorFrom(c, tt.err)
require.Equal(t, tt.wantWritten, written)
if !tt.wantWritten {
require.Equal(t, 200, w.Code)
require.Empty(t, w.Body.String())
return
}
require.Equal(t, tt.wantHTTPCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.wantBody, got)
})
}
}

View File

@@ -13,23 +13,23 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
type AccountRepository struct { type accountRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewAccountRepository(db *gorm.DB) *AccountRepository { func NewAccountRepository(db *gorm.DB) service.AccountRepository {
return &AccountRepository{db: db} return &accountRepository{db: db}
} }
func (r *AccountRepository) Create(ctx context.Context, account *model.Account) error { func (r *accountRepository) Create(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Create(account).Error return r.db.WithContext(ctx).Create(account).Error
} }
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) { func (r *accountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
var account model.Account var account model.Account
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
} }
// 填充 GroupIDs 和 Groups 虚拟字段 // 填充 GroupIDs 和 Groups 虚拟字段
account.GroupIDs = make([]int64, 0, len(account.AccountGroups)) account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
@@ -43,7 +43,7 @@ func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Accou
return &account, nil return &account, nil
} }
func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) { func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
if crsAccountID == "" { if crsAccountID == "" {
return nil, nil return nil, nil
} }
@@ -59,11 +59,11 @@ func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
return &account, nil return &account, nil
} }
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error { func (r *accountRepository) Update(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Save(account).Error return r.db.WithContext(ctx).Save(account).Error
} }
func (r *AccountRepository) Delete(ctx context.Context, id int64) error { func (r *accountRepository) Delete(ctx context.Context, id int64) error {
// 先删除账号与分组的绑定关系 // 先删除账号与分组的绑定关系
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil { if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return err return err
@@ -72,12 +72,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
} }
func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "") return r.ListWithFilters(ctx, params, "", "", "", "")
} }
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query // ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) { func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
var accounts []model.Account var accounts []model.Account
var total int64 var total int64
@@ -131,7 +131,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati
}, nil }, nil
} }
func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) { func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
@@ -142,7 +142,7 @@ func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]m
return accounts, err return accounts, err
} }
func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, error) { func (r *accountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive). Where("status = ?", model.StatusActive).
@@ -152,12 +152,12 @@ func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, er
return accounts, err return accounts, err
} }
func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error { func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
} }
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error { func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"status": model.StatusError, "status": model.StatusError,
@@ -165,7 +165,7 @@ func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg str
}).Error }).Error
} }
func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &model.AccountGroup{ ag := &model.AccountGroup{
AccountID: accountID, AccountID: accountID,
GroupID: groupID, GroupID: groupID,
@@ -174,12 +174,12 @@ func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID i
return r.db.WithContext(ctx).Create(ag).Error return r.db.WithContext(ctx).Create(ag).Error
} }
func (r *AccountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error { func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID). return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
Delete(&model.AccountGroup{}).Error Delete(&model.AccountGroup{}).Error
} }
func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) { func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
var groups []model.Group var groups []model.Group
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.group_id = groups.id"). Joins("JOIN account_groups ON account_groups.group_id = groups.id").
@@ -188,7 +188,7 @@ func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]m
return groups, err return groups, err
} }
func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) { func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, model.StatusActive). Where("platform = ? AND status = ?", platform, model.StatusActive).
@@ -198,7 +198,7 @@ func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string)
return accounts, err return accounts, err
} }
func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
// 删除现有绑定 // 删除现有绑定
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil { if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
return err return err
@@ -221,7 +221,7 @@ func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, gro
} }
// ListSchedulable 获取所有可调度的账号 // ListSchedulable 获取所有可调度的账号
func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) { func (r *accountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
@@ -235,7 +235,7 @@ func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Accoun
} }
// ListSchedulableByGroupID 按组获取可调度的账号 // ListSchedulableByGroupID 按组获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) { func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
@@ -251,7 +251,7 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
} }
// ListSchedulableByPlatform 按平台获取可调度的账号 // ListSchedulableByPlatform 按平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) { func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
@@ -266,7 +266,7 @@ func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platf
} }
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号 // ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) { func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
var accounts []model.Account var accounts []model.Account
now := time.Now() now := time.Now()
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
@@ -283,7 +283,7 @@ func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont
} }
// SetRateLimited 标记账号为限流状态(429) // SetRateLimited 标记账号为限流状态(429)
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now() now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -293,13 +293,13 @@ func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetA
} }
// SetOverloaded 标记账号为过载状态(529) // SetOverloaded 标记账号为过载状态(529)
func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error { func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("overload_until", until).Error Update("overload_until", until).Error
} }
// ClearRateLimit 清除账号的限流状态 // ClearRateLimit 清除账号的限流状态
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error { func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
"rate_limited_at": nil, "rate_limited_at": nil,
@@ -309,7 +309,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
} }
// UpdateSessionWindow 更新账号的5小时时间窗口信息 // UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]any{ updates := map[string]any{
"session_window_status": status, "session_window_status": status,
} }
@@ -323,14 +323,14 @@ func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
} }
// SetSchedulable 设置账号的调度开关 // SetSchedulable 设置账号的调度开关
func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("schedulable", schedulable).Error Update("schedulable", schedulable).Error
} }
// UpdateExtra updates specific fields in account's Extra JSONB field // UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields // It merges the updates into existing Extra data without overwriting other fields
func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 { if len(updates) == 0 {
return nil return nil
} }
@@ -358,7 +358,7 @@ func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates m
// BulkUpdate updates multiple accounts with the provided fields. // BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them. // It merges credentials/extra JSONB fields instead of overwriting them.
func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 { if len(ids) == 0 {
return 0, nil return 0, nil
} }

View File

@@ -18,13 +18,13 @@ type AccountRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *AccountRepository repo *accountRepository
} }
func (s *AccountRepoSuite) SetupTest() { func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewAccountRepository(s.db) s.repo = NewAccountRepository(s.db).(*accountRepository)
} }
func TestAccountRepoSuite(t *testing.T) { func TestAccountRepoSuite(t *testing.T) {
@@ -167,7 +167,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s.Run(tt.name, func() { s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源 // 每个 case 重新获取隔离资源
db := testTx(s.T()) db := testTx(s.T())
repo := NewAccountRepository(db) repo := NewAccountRepository(db).(*accountRepository)
ctx := context.Background() ctx := context.Background()
tt.setup(db) tt.setup(db)

View File

@@ -2,51 +2,55 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
type ApiKeyRepository struct { type apiKeyRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewApiKeyRepository(db *gorm.DB) *ApiKeyRepository { func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return &ApiKeyRepository{db: db} return &apiKeyRepository{db: db}
} }
func (r *ApiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error { func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Create(key).Error err := r.db.WithContext(ctx).Create(key).Error
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
} }
func (r *ApiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
var key model.ApiKey var key model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
} }
return &key, nil return &key, nil
} }
func (r *ApiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) { func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
var apiKey model.ApiKey var apiKey model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
} }
return &apiKey, nil return &apiKey, nil
} }
func (r *ApiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error { func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
} }
func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error { func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
} }
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []model.ApiKey
var total int64 var total int64
@@ -73,19 +77,19 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
}, nil }, nil
} }
func (r *ApiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
return count, err return count, err
} }
func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) { func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey var keys []model.ApiKey
var total int64 var total int64
@@ -113,7 +117,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
} }
// SearchApiKeys searches API keys by user ID and/or keyword (name) // SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) { func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
var keys []model.ApiKey var keys []model.ApiKey
db := r.db.WithContext(ctx).Model(&model.ApiKey{}) db := r.db.WithContext(ctx).Model(&model.ApiKey{})
@@ -135,7 +139,7 @@ func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
} }
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil // ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.ApiKey{}). result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
Where("group_id = ?", groupID). Where("group_id = ?", groupID).
Update("group_id", nil) Update("group_id", nil)
@@ -143,7 +147,7 @@ func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
} }
// CountByGroupID 获取分组的 API Key 数量 // CountByGroupID 获取分组的 API Key 数量
func (r *ApiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err return count, err

View File

@@ -16,13 +16,13 @@ type ApiKeyRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *ApiKeyRepository repo *apiKeyRepository
} }
func (s *ApiKeyRepoSuite) SetupTest() { func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewApiKeyRepository(s.db) s.repo = NewApiKeyRepository(s.db).(*apiKeyRepository)
} }
func TestApiKeyRepoSuite(t *testing.T) { func TestApiKeyRepoSuite(t *testing.T) {

View File

@@ -0,0 +1,40 @@
package repository
import (
"errors"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"gorm.io/gorm"
)
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
if err == nil {
return nil
}
if notFound != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return notFound.WithCause(err)
}
if conflict != nil && isUniqueConstraintViolation(err) {
return conflict.WithCause(err)
}
return err
}
func isUniqueConstraintViolation(err error) bool {
if err == nil {
return false
}
if errors.Is(err, gorm.ErrDuplicatedKey) {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "duplicate key") ||
strings.Contains(msg, "unique constraint") ||
strings.Contains(msg, "duplicate entry")
}

View File

@@ -2,47 +2,52 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type GroupRepository struct { type groupRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewGroupRepository(db *gorm.DB) *GroupRepository { func NewGroupRepository(db *gorm.DB) service.GroupRepository {
return &GroupRepository{db: db} return &groupRepository{db: db}
} }
func (r *GroupRepository) Create(ctx context.Context, group *model.Group) error { func (r *groupRepository) Create(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Create(group).Error err := r.db.WithContext(ctx).Create(group).Error
return translatePersistenceError(err, nil, service.ErrGroupExists)
} }
func (r *GroupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) { func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
var group model.Group var group model.Group
err := r.db.WithContext(ctx).First(&group, id).Error err := r.db.WithContext(ctx).First(&group, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
} }
return &group, nil return &group, nil
} }
func (r *GroupRepository) Update(ctx context.Context, group *model.Group) error { func (r *groupRepository) Update(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Save(group).Error return r.db.WithContext(ctx).Save(group).Error
} }
func (r *GroupRepository) Delete(ctx context.Context, id int64) error { func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
} }
func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil) return r.ListWithFilters(ctx, params, "", "", nil)
} }
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive // ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) { func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
var groups []model.Group var groups []model.Group
var total int64 var total int64
@@ -86,7 +91,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination
}, nil }, nil
} }
func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error) { func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
var groups []model.Group var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error
if err != nil { if err != nil {
@@ -100,7 +105,7 @@ func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error)
return groups, nil return groups, nil
} }
func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) { func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
var groups []model.Group var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error
if err != nil { if err != nil {
@@ -114,25 +119,80 @@ func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform str
return groups, nil return groups, nil
} }
func (r *GroupRepository) ExistsByName(ctx context.Context, name string) (bool, error) { func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *GroupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err return count, err
} }
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系 // DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func (r *GroupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{}) result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
// DB 返回底层数据库连接,用于事务处理 func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
func (r *GroupRepository) DB() *gorm.DB { group, err := r.GetByID(ctx, id)
return r.db if err != nil {
return nil, err
}
var affectedUserIDs []int64
if group.IsSubscriptionType() {
var subscriptions []model.UserSubscription
if err := r.db.WithContext(ctx).
Model(&model.UserSubscription{}).
Where("group_id = ?", id).
Select("user_id").
Find(&subscriptions).Error; err != nil {
return nil, err
}
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
}
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 删除订阅类型分组的订阅记录
if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
return err
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
return err
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}).
Where("? = ANY(allowed_groups)", id).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
return err
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return err
}
// 5. 删除分组本身(带锁,避免并发写)
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return affectedUserIDs, nil
} }

View File

@@ -16,13 +16,13 @@ type GroupRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *GroupRepository repo *groupRepository
} }
func (s *GroupRepoSuite) SetupTest() { func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewGroupRepository(s.db) s.repo = NewGroupRepository(s.db).(*groupRepository)
} }
func TestGroupRepoSuite(t *testing.T) { func TestGroupRepoSuite(t *testing.T) {
@@ -234,11 +234,3 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
count, _ := s.repo.GetAccountCount(s.ctx, g.ID) count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count) s.Require().Zero(count)
} }
// --- DB ---
func (s *GroupRepoSuite) TestDB() {
db := s.repo.DB()
s.Require().NotNil(db, "DB should return non-nil")
s.Require().Equal(s.db, db, "DB should return the underlying gorm.DB")
}

View File

@@ -2,47 +2,50 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
type ProxyRepository struct { type proxyRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewProxyRepository(db *gorm.DB) *ProxyRepository { func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
return &ProxyRepository{db: db} return &proxyRepository{db: db}
} }
func (r *ProxyRepository) Create(ctx context.Context, proxy *model.Proxy) error { func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Create(proxy).Error return r.db.WithContext(ctx).Create(proxy).Error
} }
func (r *ProxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
var proxy model.Proxy var proxy model.Proxy
err := r.db.WithContext(ctx).First(&proxy, id).Error err := r.db.WithContext(ctx).First(&proxy, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
} }
return &proxy, nil return &proxy, nil
} }
func (r *ProxyRepository) Update(ctx context.Context, proxy *model.Proxy) error { func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Save(proxy).Error return r.db.WithContext(ctx).Save(proxy).Error
} }
func (r *ProxyRepository) Delete(ctx context.Context, id int64) error { func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
} }
func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query // ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) { func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
var proxies []model.Proxy var proxies []model.Proxy
var total int64 var total int64
@@ -81,14 +84,14 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination
}, nil }, nil
} }
func (r *ProxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) { func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
var proxies []model.Proxy var proxies []model.Proxy
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error
return proxies, err return proxies, err
} }
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists // ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Proxy{}). err := r.db.WithContext(ctx).Model(&model.Proxy{}).
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password). Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
@@ -100,7 +103,7 @@ func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
} }
// CountAccountsByProxyID returns the number of accounts using a specific proxy // CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.Account{}). err := r.db.WithContext(ctx).Model(&model.Account{}).
Where("proxy_id = ?", proxyID). Where("proxy_id = ?", proxyID).
@@ -109,7 +112,7 @@ func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in
} }
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies // GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) { func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
type result struct { type result struct {
ProxyID int64 `gorm:"column:proxy_id"` ProxyID int64 `gorm:"column:proxy_id"`
Count int64 `gorm:"column:count"` Count int64 `gorm:"column:count"`
@@ -133,7 +136,7 @@ func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
} }
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending // ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *ProxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) { func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
var proxies []model.Proxy var proxies []model.Proxy
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive). Where("status = ?", model.StatusActive).

View File

@@ -17,13 +17,13 @@ type ProxyRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *ProxyRepository repo *proxyRepository
} }
func (s *ProxyRepoSuite) SetupTest() { func (s *ProxyRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewProxyRepository(s.db) s.repo = NewProxyRepository(s.db).(*proxyRepository)
} }
func TestProxyRepoSuite(t *testing.T) { func TestProxyRepoSuite(t *testing.T) {

View File

@@ -2,57 +2,60 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"time"
"gorm.io/gorm" "gorm.io/gorm"
) )
type RedeemCodeRepository struct { type redeemCodeRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewRedeemCodeRepository(db *gorm.DB) *RedeemCodeRepository { func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository {
return &RedeemCodeRepository{db: db} return &redeemCodeRepository{db: db}
} }
func (r *RedeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error { func (r *redeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Create(code).Error return r.db.WithContext(ctx).Create(code).Error
} }
func (r *RedeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error { func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error {
return r.db.WithContext(ctx).Create(&codes).Error return r.db.WithContext(ctx).Create(&codes).Error
} }
func (r *RedeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) { func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
var code model.RedeemCode var code model.RedeemCode
err := r.db.WithContext(ctx).First(&code, id).Error err := r.db.WithContext(ctx).First(&code, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
} }
return &code, nil return &code, nil
} }
func (r *RedeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) { func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
var redeemCode model.RedeemCode var redeemCode model.RedeemCode
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
} }
return &redeemCode, nil return &redeemCode, nil
} }
func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error { func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
} }
func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) { func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query // ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) { func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
var codes []model.RedeemCode var codes []model.RedeemCode
var total int64 var total int64
@@ -91,11 +94,11 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
}, nil }, nil
} }
func (r *RedeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error { func (r *redeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Save(code).Error return r.db.WithContext(ctx).Save(code).Error
} }
func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error { func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now() now := time.Now()
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}). result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
Where("id = ? AND status = ?", id, model.StatusUnused). Where("id = ? AND status = ?", id, model.StatusUnused).
@@ -108,13 +111,13 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error
return result.Error return result.Error
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 兑换码不存在或已被使用 return service.ErrRedeemCodeUsed.WithCause(gorm.ErrRecordNotFound)
} }
return nil return nil
} }
// ListByUser returns all redeem codes used by a specific user // ListByUser returns all redeem codes used by a specific user
func (r *RedeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) { func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
var codes []model.RedeemCode var codes []model.RedeemCode
if limit <= 0 { if limit <= 0 {
limit = 10 limit = 10

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -17,13 +18,13 @@ type RedeemCodeRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *RedeemCodeRepository repo *redeemCodeRepository
} }
func (s *RedeemCodeRepoSuite) SetupTest() { func (s *RedeemCodeRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewRedeemCodeRepository(s.db) s.repo = NewRedeemCodeRepository(s.db).(*redeemCodeRepository)
} }
func TestRedeemCodeRepoSuite(t *testing.T) { func TestRedeemCodeRepoSuite(t *testing.T) {
@@ -195,7 +196,7 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
// Second use should fail // Second use should fail
err = s.repo.Use(s.ctx, code.ID, user.ID) err = s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "Use expected error on second call") s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
} }
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
@@ -204,7 +205,7 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
err := s.repo.Use(s.ctx, code.ID, user.ID) err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code") s.Require().Error(err, "expected error for already used code")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
} }
// --- ListByUser --- // --- ListByUser ---
@@ -298,7 +299,7 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use") s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use")
err = s.repo.Use(s.ctx, codeB.ID, user.ID) err = s.repo.Use(s.ctx, codeB.ID, user.ID)
s.Require().Error(err, "Use expected error on second call") s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
codeA, err := s.repo.GetByCode(s.ctx, "CODEA") codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
s.Require().NoError(err, "GetByCode") s.Require().NoError(err, "GetByCode")

View File

@@ -1,14 +1,16 @@
package repository package repository
import "github.com/Wei-Shaw/sub2api/internal/service"
// Repositories 所有仓库的集合 // Repositories 所有仓库的集合
type Repositories struct { type Repositories struct {
User *UserRepository User service.UserRepository
ApiKey *ApiKeyRepository ApiKey service.ApiKeyRepository
Group *GroupRepository Group service.GroupRepository
Account *AccountRepository Account service.AccountRepository
Proxy *ProxyRepository Proxy service.ProxyRepository
RedeemCode *RedeemCodeRepository RedeemCode service.RedeemCodeRepository
UsageLog *UsageLogRepository UsageLog service.UsageLogRepository
Setting *SettingRepository Setting service.SettingRepository
UserSubscription *UserSubscriptionRepository UserSubscription service.UserSubscriptionRepository
} }

View File

@@ -2,35 +2,38 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/model"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
// SettingRepository 系统设置数据访问层 // SettingRepository 系统设置数据访问层
type SettingRepository struct { type settingRepository struct {
db *gorm.DB db *gorm.DB
} }
// NewSettingRepository 创建系统设置仓库实例 // NewSettingRepository 创建系统设置仓库实例
func NewSettingRepository(db *gorm.DB) *SettingRepository { func NewSettingRepository(db *gorm.DB) service.SettingRepository {
return &SettingRepository{db: db} return &settingRepository{db: db}
} }
// Get 根据Key获取设置值 // Get 根据Key获取设置值
func (r *SettingRepository) Get(ctx context.Context, key string) (*model.Setting, error) { func (r *settingRepository) Get(ctx context.Context, key string) (*model.Setting, error) {
var setting model.Setting var setting model.Setting
err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil)
} }
return &setting, nil return &setting, nil
} }
// GetValue 获取设置值字符串 // GetValue 获取设置值字符串
func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, error) { func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
setting, err := r.Get(ctx, key) setting, err := r.Get(ctx, key)
if err != nil { if err != nil {
return "", err return "", err
@@ -39,7 +42,7 @@ func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, e
} }
// Set 设置值(存在则更新,不存在则创建) // Set 设置值(存在则更新,不存在则创建)
func (r *SettingRepository) Set(ctx context.Context, key, value string) error { func (r *settingRepository) Set(ctx context.Context, key, value string) error {
setting := &model.Setting{ setting := &model.Setting{
Key: key, Key: key,
Value: value, Value: value,
@@ -53,7 +56,7 @@ func (r *SettingRepository) Set(ctx context.Context, key, value string) error {
} }
// GetMultiple 批量获取设置 // GetMultiple 批量获取设置
func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
var settings []model.Setting var settings []model.Setting
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
if err != nil { if err != nil {
@@ -68,7 +71,7 @@ func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map
} }
// SetMultiple 批量设置值 // SetMultiple 批量设置值
func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string]string) error { func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for key, value := range settings { for key, value := range settings {
setting := &model.Setting{ setting := &model.Setting{
@@ -88,7 +91,7 @@ func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string
} }
// GetAll 获取所有设置 // GetAll 获取所有设置
func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, error) { func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
var settings []model.Setting var settings []model.Setting
err := r.db.WithContext(ctx).Find(&settings).Error err := r.db.WithContext(ctx).Find(&settings).Error
if err != nil { if err != nil {
@@ -103,6 +106,6 @@ func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, erro
} }
// Delete 删除设置 // Delete 删除设置
func (r *SettingRepository) Delete(ctx context.Context, key string) error { func (r *settingRepository) Delete(ctx context.Context, key string) error {
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error
} }

View File

@@ -6,6 +6,7 @@ import (
"context" "context"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -14,13 +15,13 @@ type SettingRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *SettingRepository repo *settingRepository
} }
func (s *SettingRepoSuite) SetupTest() { func (s *SettingRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewSettingRepository(s.db) s.repo = NewSettingRepository(s.db).(*settingRepository)
} }
func TestSettingRepoSuite(t *testing.T) { func TestSettingRepoSuite(t *testing.T) {
@@ -45,7 +46,7 @@ func (s *SettingRepoSuite) TestSet_Upsert() {
func (s *SettingRepoSuite) TestGetValue_Missing() { func (s *SettingRepoSuite) TestGetValue_Missing() {
_, err := s.repo.GetValue(s.ctx, "nonexistent") _, err := s.repo.GetValue(s.ctx, "nonexistent")
s.Require().Error(err, "expected error for missing key") s.Require().Error(err, "expected error for missing key")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrSettingNotFound)
} }
func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() { func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() {
@@ -86,7 +87,7 @@ func (s *SettingRepoSuite) TestDelete() {
s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete") s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete")
_, err := s.repo.GetValue(s.ctx, "todelete") _, err := s.repo.GetValue(s.ctx, "todelete")
s.Require().Error(err, "expected missing key error after Delete") s.Require().Error(err, "expected missing key error after Delete")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrSettingNotFound)
} }
func (s *SettingRepoSuite) TestDelete_Idempotent() { func (s *SettingRepoSuite) TestDelete_Idempotent() {

View File

@@ -2,25 +2,28 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"time"
"gorm.io/gorm" "gorm.io/gorm"
) )
type UsageLogRepository struct { type usageLogRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewUsageLogRepository(db *gorm.DB) *UsageLogRepository { func NewUsageLogRepository(db *gorm.DB) service.UsageLogRepository {
return &UsageLogRepository{db: db} return &usageLogRepository{db: db}
} }
// getPerformanceStats 获取 RPM 和 TPM近5分钟平均值可选按用户过滤 // getPerformanceStats 获取 RPM 和 TPM近5分钟平均值可选按用户过滤
func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) { func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) {
fiveMinutesAgo := time.Now().Add(-5 * time.Minute) fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
var perfStats struct { var perfStats struct {
RequestCount int64 `gorm:"column:request_count"` RequestCount int64 `gorm:"column:request_count"`
@@ -43,20 +46,20 @@ func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int
return perfStats.RequestCount / 5, perfStats.TokenCount / 5 return perfStats.RequestCount / 5, perfStats.TokenCount / 5
} }
func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error { func (r *usageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
return r.db.WithContext(ctx).Create(log).Error return r.db.WithContext(ctx).Create(log).Error
} }
func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) { func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
var log model.UsageLog var log model.UsageLog
err := r.db.WithContext(ctx).First(&log, id).Error err := r.db.WithContext(ctx).First(&log, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil)
} }
return &log, nil return &log, nil
} }
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
@@ -83,7 +86,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
}, nil }, nil
} }
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
@@ -120,7 +123,7 @@ type UserStats struct {
CacheReadTokens int64 `json:"cache_read_tokens"` CacheReadTokens int64 `json:"cache_read_tokens"`
} }
func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) { func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
var stats UserStats var stats UserStats
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(` Select(`
@@ -139,7 +142,7 @@ func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
// DashboardStats 仪表盘统计 // DashboardStats 仪表盘统计
type DashboardStats = usagestats.DashboardStats type DashboardStats = usagestats.DashboardStats
func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) { func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
var stats DashboardStats var stats DashboardStats
today := timezone.Today() today := timezone.Today()
@@ -260,7 +263,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
return &stats, nil return &stats, nil
} }
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
@@ -287,7 +290,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
}, nil }, nil
} }
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime). Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
@@ -296,7 +299,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime). Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
@@ -305,7 +308,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
@@ -314,7 +317,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime). Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
@@ -323,12 +326,12 @@ func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelN
return logs, nil, err return logs, nil, err
} }
func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error { func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
} }
// GetAccountTodayStats 获取账号今日统计 // GetAccountTodayStats 获取账号今日统计
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) { func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
today := timezone.Today() today := timezone.Today()
var stats struct { var stats struct {
@@ -358,7 +361,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
} }
// GetAccountWindowStats 获取账号时间窗口内的统计 // GetAccountWindowStats 获取账号时间窗口内的统计
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
var stats struct { var stats struct {
Requests int64 `gorm:"column:requests"` Requests int64 `gorm:"column:requests"`
Tokens int64 `gorm:"column:tokens"` Tokens int64 `gorm:"column:tokens"`
@@ -398,7 +401,7 @@ type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date // GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) { func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
var results []ApiKeyUsageTrendPoint var results []ApiKeyUsageTrendPoint
// Choose date format based on granularity // Choose date format based on granularity
@@ -442,7 +445,7 @@ func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
} }
// GetUserUsageTrend returns usage trend data grouped by user and date // GetUserUsageTrend returns usage trend data grouped by user and date
func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) { func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) {
var results []UserUsageTrendPoint var results []UserUsageTrendPoint
// Choose date format based on granularity // Choose date format based on granularity
@@ -491,7 +494,7 @@ func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
type UserDashboardStats = usagestats.UserDashboardStats type UserDashboardStats = usagestats.UserDashboardStats
// GetUserDashboardStats 获取用户专属的仪表盘统计 // GetUserDashboardStats 获取用户专属的仪表盘统计
func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) { func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
var stats UserDashboardStats var stats UserDashboardStats
today := timezone.Today() today := timezone.Today()
@@ -578,7 +581,7 @@ func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
} }
// GetUserUsageTrendByUserID 获取指定用户的使用趋势 // GetUserUsageTrendByUserID 获取指定用户的使用趋势
func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) { func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
var results []TrendDataPoint var results []TrendDataPoint
var dateFormat string var dateFormat string
@@ -612,7 +615,7 @@ func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
} }
// GetUserModelStats 获取指定用户的模型统计 // GetUserModelStats 获取指定用户的模型统计
func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) { func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
var results []ModelStat var results []ModelStat
err := r.db.WithContext(ctx).Model(&model.UsageLog{}). err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
@@ -641,7 +644,7 @@ func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64
type UsageLogFilters = usagestats.UsageLogFilters type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin) // ListWithFilters lists usage logs with optional filters (for admin)
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) { func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog var logs []model.UsageLog
var total int64 var total int64
@@ -692,7 +695,7 @@ type UsageStats = usagestats.UsageStats
type BatchUserUsageStats = usagestats.BatchUserUsageStats type BatchUserUsageStats = usagestats.BatchUserUsageStats
// GetBatchUserUsageStats gets today and total actual_cost for multiple users // GetBatchUserUsageStats gets today and total actual_cost for multiple users
func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) { func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
if len(userIDs) == 0 { if len(userIDs) == 0 {
return make(map[int64]*BatchUserUsageStats), nil return make(map[int64]*BatchUserUsageStats), nil
} }
@@ -752,7 +755,7 @@ func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys // GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) { func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
if len(apiKeyIDs) == 0 { if len(apiKeyIDs) == 0 {
return make(map[int64]*BatchApiKeyUsageStats), nil return make(map[int64]*BatchApiKeyUsageStats), nil
} }
@@ -809,7 +812,7 @@ func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
} }
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters // GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) { func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
var results []TrendDataPoint var results []TrendDataPoint
var dateFormat string var dateFormat string
@@ -848,7 +851,7 @@ func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
} }
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters // GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) { func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
var results []ModelStat var results []ModelStat
db := r.db.WithContext(ctx).Model(&model.UsageLog{}). db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
@@ -882,7 +885,7 @@ func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
} }
// GetGlobalStats gets usage statistics for all users within a time range // GetGlobalStats gets usage statistics for all users within a time range
func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) { func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
var stats struct { var stats struct {
TotalRequests int64 `gorm:"column:total_requests"` TotalRequests int64 `gorm:"column:total_requests"`
TotalInputTokens int64 `gorm:"column:total_input_tokens"` TotalInputTokens int64 `gorm:"column:total_input_tokens"`
@@ -932,7 +935,7 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range // GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func (r *UsageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) { func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1 daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
if daysCount <= 0 { if daysCount <= 0 {
daysCount = 30 daysCount = 30

View File

@@ -19,13 +19,13 @@ type UsageLogRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *UsageLogRepository repo *usageLogRepository
} }
func (s *UsageLogRepoSuite) SetupTest() { func (s *UsageLogRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewUsageLogRepository(s.db) s.repo = NewUsageLogRepository(s.db).(*usageLogRepository)
} }
func TestUsageLogRepoSuite(t *testing.T) { func TestUsageLogRepoSuite(t *testing.T) {

View File

@@ -2,56 +2,61 @@ package repository
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm" "gorm.io/gorm"
) )
type UserRepository struct { type userRepository struct {
db *gorm.DB db *gorm.DB
} }
func NewUserRepository(db *gorm.DB) *UserRepository { func NewUserRepository(db *gorm.DB) service.UserRepository {
return &UserRepository{db: db} return &userRepository{db: db}
} }
func (r *UserRepository) Create(ctx context.Context, user *model.User) error { func (r *userRepository) Create(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Create(user).Error err := r.db.WithContext(ctx).Create(user).Error
return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*model.User, error) { func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
var user model.User var user model.User
err := r.db.WithContext(ctx).First(&user, id).Error err := r.db.WithContext(ctx).First(&user, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return &user, nil
} }
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) { func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
var user model.User var user model.User
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return &user, nil
} }
func (r *UserRepository) Update(ctx context.Context, user *model.User) error { func (r *userRepository) Update(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Save(user).Error err := r.db.WithContext(ctx).Save(user).Error
return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
func (r *UserRepository) Delete(ctx context.Context, id int64) error { func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
} }
func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) { func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "") return r.ListWithFilters(ctx, params, "", "", "")
} }
// ListWithFilters lists users with optional filtering by status, role, and search query // ListWithFilters lists users with optional filtering by status, role, and search query
func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) { func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
var users []model.User var users []model.User
var total int64 var total int64
@@ -120,13 +125,13 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.
}, nil }, nil
} }
func (r *UserRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("balance", gorm.Expr("balance + ?", amount)).Error Update("balance", gorm.Expr("balance + ?", amount)).Error
} }
// DeductBalance 扣减用户余额,仅当余额充足时执行 // DeductBalance 扣减用户余额,仅当余额充足时执行
func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
result := r.db.WithContext(ctx).Model(&model.User{}). result := r.db.WithContext(ctx).Model(&model.User{}).
Where("id = ? AND balance >= ?", id, amount). Where("id = ? AND balance >= ?", id, amount).
Update("balance", gorm.Expr("balance - ?", amount)) Update("balance", gorm.Expr("balance - ?", amount))
@@ -134,17 +139,17 @@ func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount flo
return result.Error return result.Error
} }
if result.RowsAffected == 0 { if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 余额不足或用户不存在 return service.ErrInsufficientBalance
} }
return nil return nil
} }
func (r *UserRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error { func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
} }
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error
return count > 0, err return count > 0, err
@@ -152,7 +157,7 @@ func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool,
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID // RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数 // 使用 PostgreSQL 的 array_remove 函数
func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.User{}). result := r.db.WithContext(ctx).Model(&model.User{}).
Where("? = ANY(allowed_groups)", groupID). Where("? = ANY(allowed_groups)", groupID).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID)) Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
@@ -160,14 +165,14 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
} }
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证) // GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *UserRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) { func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
var user model.User var user model.User
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive). Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
Order("id ASC"). Order("id ASC").
First(&user).Error First(&user).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
return &user, nil return &user, nil
} }

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gorm.io/gorm" "gorm.io/gorm"
@@ -18,13 +19,13 @@ type UserRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *UserRepository repo *userRepository
} }
func (s *UserRepoSuite) SetupTest() { func (s *UserRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewUserRepository(s.db) s.repo = NewUserRepository(s.db).(*userRepository)
} }
func TestUserRepoSuite(t *testing.T) { func TestUserRepoSuite(t *testing.T) {
@@ -247,7 +248,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
err := s.repo.DeductBalance(s.ctx, user.ID, 999) err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance") s.Require().Error(err, "expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound) s.Require().ErrorIs(err, service.ErrInsufficientBalance)
} }
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
@@ -432,7 +433,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
err = s.repo.DeductBalance(s.ctx, user1.ID, 999) err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().Error(err, "DeductBalance expected error for insufficient balance") s.Require().Error(err, "DeductBalance expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound, "DeductBalance unexpected error") s.Require().ErrorIs(err, service.ErrInsufficientBalance, "DeductBalance unexpected error")
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency") s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID) got5, err := s.repo.GetByID(s.ctx, user1.ID)

View File

@@ -6,27 +6,29 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm" "gorm.io/gorm"
) )
// UserSubscriptionRepository 用户订阅仓库 // UserSubscriptionRepository 用户订阅仓库
type UserSubscriptionRepository struct { type userSubscriptionRepository struct {
db *gorm.DB db *gorm.DB
} }
// NewUserSubscriptionRepository 创建用户订阅仓库 // NewUserSubscriptionRepository 创建用户订阅仓库
func NewUserSubscriptionRepository(db *gorm.DB) *UserSubscriptionRepository { func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
return &UserSubscriptionRepository{db: db} return &userSubscriptionRepository{db: db}
} }
// Create 创建订阅 // Create 创建订阅
func (r *UserSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error { func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
return r.db.WithContext(ctx).Create(sub).Error err := r.db.WithContext(ctx).Create(sub).Error
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
} }
// GetByID 根据ID获取订阅 // GetByID 根据ID获取订阅
func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) { func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
var sub model.UserSubscription var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("User"). Preload("User").
@@ -34,26 +36,26 @@ func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*mo
Preload("AssignedByUser"). Preload("AssignedByUser").
First(&sub, id).Error First(&sub, id).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return &sub, nil
} }
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅 // GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅
func (r *UserSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
Where("user_id = ? AND group_id = ?", userID, groupID). Where("user_id = ? AND group_id = ?", userID, groupID).
First(&sub).Error First(&sub).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return &sub, nil
} }
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅 // GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅
func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription var sub model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
@@ -61,24 +63,24 @@ func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con
userID, groupID, model.SubscriptionStatusActive, time.Now()). userID, groupID, model.SubscriptionStatusActive, time.Now()).
First(&sub).Error First(&sub).Error
if err != nil { if err != nil {
return nil, err return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
return &sub, nil return &sub, nil
} }
// Update 更新订阅 // Update 更新订阅
func (r *UserSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error { func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
sub.UpdatedAt = time.Now() sub.UpdatedAt = time.Now()
return r.db.WithContext(ctx).Save(sub).Error return r.db.WithContext(ctx).Save(sub).Error
} }
// Delete 删除订阅 // Delete 删除订阅
func (r *UserSubscriptionRepository) Delete(ctx context.Context, id int64) error { func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error
} }
// ListByUserID 获取用户的所有订阅 // ListByUserID 获取用户的所有订阅
func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
@@ -89,7 +91,7 @@ func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID in
} }
// ListActiveByUserID 获取用户的所有有效订阅 // ListActiveByUserID 获取用户的所有有效订阅
func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Preload("Group"). Preload("Group").
@@ -101,7 +103,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
} }
// ListByGroupID 获取分组的所有订阅(分页) // ListByGroupID 获取分组的所有订阅(分页)
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) { func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
var total int64 var total int64
@@ -136,7 +138,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
} }
// List 获取所有订阅(分页,支持筛选) // List 获取所有订阅(分页,支持筛选)
func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) { func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
var total int64 var total int64
@@ -182,7 +184,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination
} }
// IncrementUsage 增加使用量 // IncrementUsage 增加使用量
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -194,7 +196,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
} }
// ResetDailyUsage 重置日使用量 // ResetDailyUsage 重置日使用量
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -205,7 +207,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
} }
// ResetWeeklyUsage 重置周使用量 // ResetWeeklyUsage 重置周使用量
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -216,7 +218,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
} }
// ResetMonthlyUsage 重置月使用量 // ResetMonthlyUsage 重置月使用量
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -227,7 +229,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
} }
// ActivateWindows 激活所有窗口(首次使用时) // ActivateWindows 激活所有窗口(首次使用时)
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error { func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -239,7 +241,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int
} }
// UpdateStatus 更新订阅状态 // UpdateStatus 更新订阅状态
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error { func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -249,7 +251,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64,
} }
// ExtendExpiry 延长订阅过期时间 // ExtendExpiry 延长订阅过期时间
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error { func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -259,7 +261,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64,
} }
// UpdateNotes 更新订阅备注 // UpdateNotes 更新订阅备注
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error { func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}). return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id). Where("id = ?", id).
Updates(map[string]any{ Updates(map[string]any{
@@ -269,7 +271,7 @@ func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64,
} }
// ListExpired 获取所有已过期但状态仍为active的订阅 // ListExpired 获取所有已过期但状态仍为active的订阅
func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) { func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
var subs []model.UserSubscription var subs []model.UserSubscription
err := r.db.WithContext(ctx). err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
@@ -278,7 +280,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U
} }
// BatchUpdateExpiredStatus 批量更新过期订阅状态 // BatchUpdateExpiredStatus 批量更新过期订阅状态
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}). result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Updates(map[string]any{ Updates(map[string]any{
@@ -289,7 +291,7 @@ func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex
} }
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅 // ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅
func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("user_id = ? AND group_id = ?", userID, groupID). Where("user_id = ? AND group_id = ?", userID, groupID).
@@ -298,7 +300,7 @@ func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Contex
} }
// CountByGroupID 获取分组的订阅数量 // CountByGroupID 获取分组的订阅数量
func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("group_id = ?", groupID). Where("group_id = ?", groupID).
@@ -307,7 +309,7 @@ func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID
} }
// CountActiveByGroupID 获取分组的有效订阅数量 // CountActiveByGroupID 获取分组的有效订阅数量
func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64 var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("group_id = ? AND status = ? AND expires_at > ?", Where("group_id = ? AND status = ? AND expires_at > ?",
@@ -317,7 +319,7 @@ func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g
} }
// DeleteByGroupID 删除分组相关的所有订阅记录 // DeleteByGroupID 删除分组相关的所有订阅记录
func (r *UserSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) { func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{}) result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }

View File

@@ -17,13 +17,13 @@ type UserSubscriptionRepoSuite struct {
suite.Suite suite.Suite
ctx context.Context ctx context.Context
db *gorm.DB db *gorm.DB
repo *UserSubscriptionRepository repo *userSubscriptionRepository
} }
func (s *UserSubscriptionRepoSuite) SetupTest() { func (s *UserSubscriptionRepoSuite) SetupTest() {
s.ctx = context.Background() s.ctx = context.Background()
s.db = testTx(s.T()) s.db = testTx(s.T())
s.repo = NewUserSubscriptionRepository(s.db) s.repo = NewUserSubscriptionRepository(s.db).(*userSubscriptionRepository)
} }
func TestUserSubscriptionRepoSuite(t *testing.T) { func TestUserSubscriptionRepoSuite(t *testing.T) {

View File

@@ -1,7 +1,6 @@
package repository package repository
import ( import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire" "github.com/google/wire"
) )
@@ -37,15 +36,4 @@ var ProviderSet = wire.NewSet(
NewClaudeOAuthClient, NewClaudeOAuthClient,
NewHTTPUpstream, NewHTTPUpstream,
NewOpenAIOAuthClient, NewOpenAIOAuthClient,
// Bind concrete repositories to service port interfaces
wire.Bind(new(service.UserRepository), new(*UserRepository)),
wire.Bind(new(service.ApiKeyRepository), new(*ApiKeyRepository)),
wire.Bind(new(service.GroupRepository), new(*GroupRepository)),
wire.Bind(new(service.AccountRepository), new(*AccountRepository)),
wire.Bind(new(service.ProxyRepository), new(*ProxyRepository)),
wire.Bind(new(service.RedeemCodeRepository), new(*RedeemCodeRepository)),
wire.Bind(new(service.UsageLogRepository), new(*UsageLogRepository)),
wire.Bind(new(service.SettingRepository), new(*SettingRepository)),
wire.Bind(new(service.UserSubscriptionRepository), new(*UserSubscriptionRepository)),
) )

View File

@@ -2,17 +2,16 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
) )
var ( var (
ErrAccountNotFound = errors.New("account not found") ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
) )
type AccountRepository interface { type AccountRepository interface {
@@ -106,9 +105,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
for _, groupID := range req.GroupIDs { for _, groupID := range req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID) _, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
} }
@@ -145,9 +141,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) { func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err) return nil, fmt.Errorf("get account: %w", err)
} }
return account, nil return account, nil
@@ -184,9 +177,6 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) { func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err) return nil, fmt.Errorf("get account: %w", err)
} }
@@ -229,9 +219,6 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
for _, groupID := range *req.GroupIDs { for _, groupID := range *req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID) _, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
} }
@@ -249,9 +236,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
// 检查账号是否存在 // 检查账号是否存在
_, err := s.accountRepo.GetByID(ctx, id) _, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err) return fmt.Errorf("get account: %w", err)
} }
@@ -266,9 +250,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error { func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err) return fmt.Errorf("get account: %w", err)
} }
@@ -294,9 +275,6 @@ func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error {
func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) { func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrAccountNotFound
}
return "", fmt.Errorf("get account: %w", err) return "", fmt.Errorf("get account: %w", err)
} }
@@ -307,9 +285,6 @@ func (s *AccountService) GetCredential(ctx context.Context, id int64, key string
func (s *AccountService) TestCredentials(ctx context.Context, id int64) error { func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
account, err := s.accountRepo.GetByID(ctx, id) account, err := s.accountRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err) return fmt.Errorf("get account: %w", err)
} }

View File

@@ -9,7 +9,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
) )
// AdminService interface defines admin management operations // AdminService interface defines admin management operations
@@ -550,61 +549,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
} }
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
// 先获取分组信息,检查是否存在 affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return fmt.Errorf("group not found: %w", err)
}
// 订阅类型分组先获取受影响的用户ID列表用于事务后失效缓存
var affectedUserIDs []int64
if group.IsSubscriptionType() && s.billingCacheService != nil {
var subscriptions []model.UserSubscription
if err := s.groupRepo.DB().WithContext(ctx).
Where("group_id = ?", id).
Select("user_id").
Find(&subscriptions).Error; err == nil {
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
}
}
// 使用事务处理所有级联删除
db := s.groupRepo.DB()
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 如果是订阅类型分组,删除 user_subscriptions 中的相关记录
if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
return fmt.Errorf("delete user subscriptions: %w", err)
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil任何类型的分组都需要
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
return fmt.Errorf("clear api key group_id: %w", err)
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}).
Where("? = ANY(allowed_groups)", id).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
return fmt.Errorf("remove from allowed_groups: %w", err)
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return fmt.Errorf("delete account groups: %w", err)
}
// 5. 删除分组本身
if err := tx.Delete(&model.Group{}, id).Error; err != nil {
return fmt.Errorf("delete group: %w", err)
}
return nil
})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -9,20 +9,20 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm"
) )
var ( var (
ErrApiKeyNotFound = errors.New("api key not found") ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group") ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = errors.New("api key already exists") ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters") ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens") ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later") ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
) )
const ( const (
@@ -183,9 +183,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
// 验证用户存在 // 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
@@ -193,9 +190,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
if req.GroupID != nil { if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID) group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
@@ -269,9 +263,6 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
return apiKey, nil return apiKey, nil
@@ -285,9 +276,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
// 这里可以添加Redis缓存逻辑暂时直接查询数据库 // 这里可以添加Redis缓存逻辑暂时直接查询数据库
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key) apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
@@ -304,9 +292,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) { func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err) return nil, fmt.Errorf("get api key: %w", err)
} }
@@ -329,9 +314,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
group, err := s.groupRepo.GetByID(ctx, *req.GroupID) group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
@@ -361,9 +343,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error { func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id) apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrApiKeyNotFound
}
return fmt.Errorf("get api key: %w", err) return fmt.Errorf("get api key: %w", err)
} }
@@ -394,15 +373,12 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
// 检查API Key状态 // 检查API Key状态
if !apiKey.IsActive() { if !apiKey.IsActive() {
return nil, nil, errors.New("api key is not active") return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active")
} }
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, apiKey.UserID) user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, ErrUserNotFound
}
return nil, nil, fmt.Errorf("get user: %w", err) return nil, nil, fmt.Errorf("get user: %w", err)
} }
@@ -436,9 +412,6 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
@@ -450,7 +423,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
// 获取用户的所有有效订阅 // 获取用户的所有有效订阅
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID) activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil {
return nil, fmt.Errorf("list active subscriptions: %w", err) return nil, fmt.Errorf("list active subscriptions: %w", err)
} }

View File

@@ -8,22 +8,22 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
) )
var ( var (
ErrInvalidCredentials = errors.New("invalid email or password") ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = errors.New("user is not active") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = errors.New("email already exists") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrInvalidToken = errors.New("invalid token") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = errors.New("token has expired") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required") ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled") ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = errors.New("service temporarily unavailable") ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
) )
// JWTClaims JWT载荷数据 // JWTClaims JWT载荷数据
@@ -255,7 +255,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
// 查找用户 // 查找用户
user, err := s.userRepo.GetByEmail(ctx, email) user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrUserNotFound) {
return "", nil, ErrInvalidCredentials return "", nil, ErrInvalidCredentials
} }
// 记录数据库错误但不暴露给用户 // 记录数据库错误但不暴露给用户
@@ -357,7 +357,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 获取最新的用户信息 // 获取最新的用户信息
user, err := s.userRepo.GetByID(ctx, claims.UserID) user, err := s.userRepo.GetByID(ctx, claims.UserID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrUserNotFound) {
return "", ErrInvalidToken return "", ErrInvalidToken
} }
log.Printf("[Auth] Database error refreshing token: %v", err) log.Printf("[Auth] Database error refreshing token: %v", err)

View File

@@ -2,11 +2,11 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log" "log"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
) )
@@ -14,7 +14,7 @@ import (
// 注ErrInsufficientBalance在redeem_service.go中定义 // 注ErrInsufficientBalance在redeem_service.go中定义
// 注ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义 // 注ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var ( var (
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired") ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
) )
// subscriptionCacheData 订阅缓存数据结构(内部使用) // subscriptionCacheData 订阅缓存数据结构(内部使用)

View File

@@ -4,21 +4,21 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"math/big" "math/big"
"net/smtp" "net/smtp"
"strconv" "strconv"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
) )
var ( var (
ErrEmailNotConfigured = errors.New("email service not configured") ErrEmailNotConfigured = infraerrors.ServiceUnavailable("EMAIL_NOT_CONFIGURED", "email service not configured")
ErrInvalidVerifyCode = errors.New("invalid or expired verification code") ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
ErrVerifyCodeTooFrequent = errors.New("please wait before requesting a new code") ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = errors.New("too many failed attempts, please request a new code") ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
) )
// EmailCache defines cache operations for email service // EmailCache defines cache operations for email service

View File

@@ -2,17 +2,16 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
) )
var ( var (
ErrGroupNotFound = errors.New("group not found") ErrGroupNotFound = infraerrors.NotFound("GROUP_NOT_FOUND", "group not found")
ErrGroupExists = errors.New("group name already exists") ErrGroupExists = infraerrors.Conflict("GROUP_EXISTS", "group name already exists")
) )
type GroupRepository interface { type GroupRepository interface {
@@ -20,6 +19,7 @@ type GroupRepository interface {
GetByID(ctx context.Context, id int64) (*model.Group, error) GetByID(ctx context.Context, id int64) (*model.Group, error)
Update(ctx context.Context, group *model.Group) error Update(ctx context.Context, group *model.Group) error
Delete(ctx context.Context, id int64) error Delete(ctx context.Context, id int64) error
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
@@ -29,8 +29,6 @@ type GroupRepository interface {
ExistsByName(ctx context.Context, name string) (bool, error) ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error) GetAccountCount(ctx context.Context, groupID int64) (int64, error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
DB() *gorm.DB
} }
// CreateGroupRequest 创建分组请求 // CreateGroupRequest 创建分组请求
@@ -93,9 +91,6 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) { func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
return group, nil return group, nil
@@ -123,9 +118,6 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) { func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }
@@ -170,9 +162,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
// 检查分组是否存在 // 检查分组是否存在
_, err := s.groupRepo.GetByID(ctx, id) _, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrGroupNotFound
}
return fmt.Errorf("get group: %w", err) return fmt.Errorf("get group: %w", err)
} }
@@ -187,9 +176,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) { func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err) return nil, fmt.Errorf("get group: %w", err)
} }

View File

@@ -2,16 +2,15 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
) )
var ( var (
ErrProxyNotFound = errors.New("proxy not found") ErrProxyNotFound = infraerrors.NotFound("PROXY_NOT_FOUND", "proxy not found")
) )
type ProxyRepository interface { type ProxyRepository interface {
@@ -86,9 +85,6 @@ func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*mod
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err) return nil, fmt.Errorf("get proxy: %w", err)
} }
return proxy, nil return proxy, nil
@@ -116,9 +112,6 @@ func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) { func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err) return nil, fmt.Errorf("get proxy: %w", err)
} }
@@ -163,9 +156,6 @@ func (s *ProxyService) Delete(ctx context.Context, id int64) error {
// 检查代理是否存在 // 检查代理是否存在
_, err := s.proxyRepo.GetByID(ctx, id) _, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err) return fmt.Errorf("get proxy: %w", err)
} }
@@ -180,9 +170,6 @@ func (s *ProxyService) Delete(ctx context.Context, id int64) error {
func (s *ProxyService) TestConnection(ctx context.Context, id int64) error { func (s *ProxyService) TestConnection(ctx context.Context, id int64) error {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err) return fmt.Errorf("get proxy: %w", err)
} }
@@ -197,9 +184,6 @@ func (s *ProxyService) TestConnection(ctx context.Context, id int64) error {
func (s *ProxyService) GetURL(ctx context.Context, id int64) (string, error) { func (s *ProxyService) GetURL(ctx context.Context, id int64) (string, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id) proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrProxyNotFound
}
return "", fmt.Errorf("get proxy: %w", err) return "", fmt.Errorf("get proxy: %w", err)
} }

View File

@@ -9,19 +9,18 @@ import (
"strings" "strings"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gorm"
) )
var ( var (
ErrRedeemCodeNotFound = errors.New("redeem code not found") ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found")
ErrRedeemCodeUsed = errors.New("redeem code already used") ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used")
ErrRedeemCodeInvalid = errors.New("invalid redeem code") ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance")
ErrInsufficientBalance = errors.New("insufficient balance") ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later")
ErrRedeemRateLimited = errors.New("too many failed attempts, please try again later") ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again")
ErrRedeemCodeLocked = errors.New("redeem code is being processed, please try again")
) )
const ( const (
@@ -226,7 +225,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 查找兑换码 // 查找兑换码
redeemCode, err := s.redeemRepo.GetByCode(ctx, code) redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrRedeemCodeNotFound) {
s.incrementRedeemErrorCount(ctx, userID) s.incrementRedeemErrorCount(ctx, userID)
return nil, ErrRedeemCodeNotFound return nil, ErrRedeemCodeNotFound
} }
@@ -241,15 +240,12 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 验证兑换码类型的前置条件 // 验证兑换码类型的前置条件
if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil { if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil {
return nil, errors.New("invalid subscription redeem code: missing group_id") return nil, infraerrors.BadRequest("REDEEM_CODE_INVALID", "invalid subscription redeem code: missing group_id")
} }
// 获取用户信息 // 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
_ = user // 使用变量避免未使用错误 _ = user // 使用变量避免未使用错误
@@ -257,8 +253,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 【关键】先标记兑换码为已使用,确保并发安全 // 【关键】先标记兑换码为已使用,确保并发安全
// 利用数据库乐观锁WHERE status = 'unused')保证原子性 // 利用数据库乐观锁WHERE status = 'unused')保证原子性
if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil { if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrRedeemCodeNotFound) || errors.Is(err, ErrRedeemCodeUsed) {
// 兑换码已被其他请求使用
return nil, ErrRedeemCodeUsed return nil, ErrRedeemCodeUsed
} }
return nil, fmt.Errorf("mark code as used: %w", err) return nil, fmt.Errorf("mark code as used: %w", err)
@@ -328,9 +323,6 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) { func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
code, err := s.redeemRepo.GetByID(ctx, id) code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err) return nil, fmt.Errorf("get redeem code: %w", err)
} }
return code, nil return code, nil
@@ -340,9 +332,6 @@ func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCod
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) { func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
redeemCode, err := s.redeemRepo.GetByCode(ctx, code) redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err) return nil, fmt.Errorf("get redeem code: %w", err)
} }
return redeemCode, nil return redeemCode, nil
@@ -362,15 +351,12 @@ func (s *RedeemService) Delete(ctx context.Context, id int64) error {
// 检查兑换码是否存在 // 检查兑换码是否存在
code, err := s.redeemRepo.GetByID(ctx, id) code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrRedeemCodeNotFound
}
return fmt.Errorf("get redeem code: %w", err) return fmt.Errorf("get redeem code: %w", err)
} }
// 不允许删除已使用的兑换码 // 不允许删除已使用的兑换码
if code.IsUsed() { if code.IsUsed() {
return errors.New("cannot delete used redeem code") return infraerrors.Conflict("REDEEM_CODE_DELETE_USED", "cannot delete used redeem code")
} }
if err := s.redeemRepo.Delete(ctx, id); err != nil { if err := s.redeemRepo.Delete(ctx, id); err != nil {

View File

@@ -9,13 +9,13 @@ import (
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm"
) )
var ( var (
ErrRegistrationDisabled = errors.New("registration is currently disabled") ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
) )
type SettingRepository interface { type SettingRepository interface {
@@ -187,7 +187,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 已有设置,不需要初始化 // 已有设置,不需要初始化
return nil return nil
} }
if !errors.Is(err, gorm.ErrRecordNotFound) { if !errors.Is(err, ErrSettingNotFound) {
return fmt.Errorf("check existing settings: %w", err) return fmt.Errorf("check existing settings: %w", err)
} }
@@ -302,7 +302,7 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) { func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey) key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrSettingNotFound) {
return "", false, nil return "", false, nil
} }
return "", false, err return "", false, err
@@ -326,7 +326,7 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) { func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey) key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串 return "", nil // 未配置,返回空字符串
} }
return "", err // 数据库错误 return "", err // 数据库错误

View File

@@ -2,24 +2,24 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log" "log"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
var ( var (
ErrSubscriptionNotFound = errors.New("subscription not found") ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found")
ErrSubscriptionExpired = errors.New("subscription has expired") ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired")
ErrSubscriptionSuspended = errors.New("subscription is suspended") ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended")
ErrSubscriptionAlreadyExists = errors.New("subscription already exists for this user and group") ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
ErrGroupNotSubscriptionType = errors.New("group is not a subscription type") ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
ErrDailyLimitExceeded = errors.New("daily usage limit exceeded") ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
ErrWeeklyLimitExceeded = errors.New("weekly usage limit exceeded") ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = errors.New("monthly usage limit exceeded") ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
) )
// SubscriptionService 订阅服务 // SubscriptionService 订阅服务

View File

@@ -2,14 +2,15 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log" "log"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
) )
var ( var (
ErrTurnstileVerificationFailed = errors.New("turnstile verification failed") ErrTurnstileVerificationFailed = infraerrors.BadRequest("TURNSTILE_VERIFICATION_FAILED", "turnstile verification failed")
ErrTurnstileNotConfigured = errors.New("turnstile not configured") ErrTurnstileNotConfigured = infraerrors.ServiceUnavailable("TURNSTILE_NOT_CONFIGURED", "turnstile not configured")
) )
// TurnstileVerifier 验证 Turnstile token 的接口 // TurnstileVerifier 验证 Turnstile token 的接口

View File

@@ -2,18 +2,17 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"gorm.io/gorm"
) )
var ( var (
ErrUsageLogNotFound = errors.New("usage log not found") ErrUsageLogNotFound = infraerrors.NotFound("USAGE_LOG_NOT_FOUND", "usage log not found")
) )
// CreateUsageLogRequest 创建使用日志请求 // CreateUsageLogRequest 创建使用日志请求
@@ -71,9 +70,6 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
// 验证用户存在 // 验证用户存在
_, err := s.userRepo.GetByID(ctx, req.UserID) _, err := s.userRepo.GetByID(ctx, req.UserID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
@@ -119,9 +115,6 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) { func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
log, err := s.usageRepo.GetByID(ctx, id) log, err := s.usageRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUsageLogNotFound
}
return nil, fmt.Errorf("get usage log: %w", err) return nil, fmt.Errorf("get usage log: %w", err)
} }
return log, nil return log, nil

View File

@@ -2,19 +2,18 @@ package service
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
) )
var ( var (
ErrUserNotFound = errors.New("user not found") ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
ErrPasswordIncorrect = errors.New("current password is incorrect") ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = errors.New("insufficient permissions") ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
) )
type UserRepository interface { type UserRepository interface {
@@ -65,9 +64,6 @@ func NewUserService(userRepo UserRepository) *UserService {
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) { func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
return user, nil return user, nil
@@ -77,9 +73,6 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) { func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
@@ -119,9 +112,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error { func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err) return fmt.Errorf("get user: %w", err)
} }
@@ -149,9 +139,6 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) { func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, id) user, err := s.userRepo.GetByID(ctx, id)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err) return nil, fmt.Errorf("get user: %w", err)
} }
return user, nil return user, nil
@@ -178,9 +165,6 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error { func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err) return fmt.Errorf("get user: %w", err)
} }