refactor: 自定义业务错误

This commit is contained in:
Forest
2025-12-25 20:52:47 +08:00
parent f51ad2e126
commit eeaff85e47
60 changed files with 1222 additions and 622 deletions

View File

@@ -117,7 +117,7 @@ func (h *AccountHandler) List(c *gin.Context) {
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
if err != nil {
response.InternalError(c, "Failed to list accounts: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -156,7 +156,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
response.ErrorFrom(c, err)
return
}
@@ -184,7 +184,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
GroupIDs: req.GroupIDs,
})
if err != nil {
response.BadRequest(c, "Failed to create account: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -218,7 +218,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
GroupIDs: req.GroupIDs,
})
if err != nil {
response.InternalError(c, "Failed to update account: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -236,7 +236,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteAccount(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to delete account: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -297,7 +297,7 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
SyncProxies: syncProxies,
})
if err != nil {
response.BadRequest(c, "Sync failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -332,7 +332,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
// Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -349,7 +349,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
// Use Anthropic/Claude OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -372,7 +372,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
Credentials: newCredentials,
})
if err != nil {
response.InternalError(c, "Failed to update account credentials: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -403,7 +403,7 @@ func (h *AccountHandler) GetStats(c *gin.Context) {
stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get account stats: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -421,7 +421,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
account, err := h.adminService.ClearAccountError(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to clear error: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -570,7 +570,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
Extra: req.Extra,
})
if err != nil {
response.InternalError(c, "Failed to bulk update accounts: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -595,7 +595,7 @@ func (h *OAuthHandler) GenerateAuthURL(c *gin.Context) {
result, err := h.oauthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -613,7 +613,7 @@ func (h *OAuthHandler) GenerateSetupTokenURL(c *gin.Context) {
result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID)
if err != nil {
response.InternalError(c, "Failed to generate setup token URL: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -642,7 +642,7 @@ func (h *OAuthHandler) ExchangeCode(c *gin.Context) {
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -664,7 +664,7 @@ func (h *OAuthHandler) ExchangeSetupTokenCode(c *gin.Context) {
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -692,7 +692,7 @@ func (h *OAuthHandler) CookieAuth(c *gin.Context) {
Scope: "full",
})
if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -714,7 +714,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
Scope: "inference",
})
if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -732,7 +732,7 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to get usage: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -750,7 +750,7 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
err = h.rateLimitService.ClearRateLimit(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to clear rate limit: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -768,7 +768,7 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
stats, err := h.accountUsageService.GetTodayStats(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to get today stats: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -797,7 +797,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
account, err := h.adminService.SetAccountSchedulable(c.Request.Context(), accountID, req.Schedulable)
if err != nil {
response.InternalError(c, "Failed to update schedulable status: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -65,7 +65,7 @@ func (h *GroupHandler) List(c *gin.Context) {
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
if err != nil {
response.InternalError(c, "Failed to list groups: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -87,7 +87,7 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
}
if err != nil {
response.InternalError(c, "Failed to get groups: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -105,7 +105,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
group, err := h.adminService.GetGroup(c.Request.Context(), groupID)
if err != nil {
response.NotFound(c, "Group not found")
response.ErrorFrom(c, err)
return
}
@@ -133,7 +133,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
MonthlyLimitUSD: req.MonthlyLimitUSD,
})
if err != nil {
response.BadRequest(c, "Failed to create group: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -168,7 +168,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
MonthlyLimitUSD: req.MonthlyLimitUSD,
})
if err != nil {
response.InternalError(c, "Failed to update group: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -186,7 +186,7 @@ func (h *GroupHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteGroup(c.Request.Context(), groupID)
if err != nil {
response.InternalError(c, "Failed to delete group: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -225,7 +225,7 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get group API keys: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -40,7 +40,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -71,7 +71,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -103,7 +103,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
if err != nil {
response.BadRequest(c, "Failed to refresh token: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -122,7 +122,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
// Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
response.ErrorFrom(c, err)
return
}
@@ -141,7 +141,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
// Use OpenAI OAuth service to refresh token
tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -159,7 +159,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
Credentials: newCredentials,
})
if err != nil {
response.InternalError(c, "Failed to update account credentials: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -192,7 +192,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -220,7 +220,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
GroupIDs: req.GroupIDs,
})
if err != nil {
response.InternalError(c, "Failed to create account: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -53,7 +53,7 @@ func (h *ProxyHandler) List(c *gin.Context) {
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search)
if err != nil {
response.InternalError(c, "Failed to list proxies: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -69,7 +69,7 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
if withCount {
proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, proxies)
@@ -78,7 +78,7 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
proxies, err := h.adminService.GetAllProxies(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -96,7 +96,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID)
if err != nil {
response.NotFound(c, "Proxy not found")
response.ErrorFrom(c, err)
return
}
@@ -121,7 +121,7 @@ func (h *ProxyHandler) Create(c *gin.Context) {
Password: strings.TrimSpace(req.Password),
})
if err != nil {
response.BadRequest(c, "Failed to create proxy: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -153,7 +153,7 @@ func (h *ProxyHandler) Update(c *gin.Context) {
Status: strings.TrimSpace(req.Status),
})
if err != nil {
response.InternalError(c, "Failed to update proxy: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -171,7 +171,7 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteProxy(c.Request.Context(), proxyID)
if err != nil {
response.InternalError(c, "Failed to delete proxy: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -189,7 +189,7 @@ func (h *ProxyHandler) Test(c *gin.Context) {
result, err := h.adminService.TestProxy(c.Request.Context(), proxyID)
if err != nil {
response.InternalError(c, "Failed to test proxy: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -229,7 +229,7 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get proxy accounts: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -272,7 +272,7 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) {
// Check for duplicates (same host, port, username, password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
if err != nil {
response.InternalError(c, "Failed to check proxy existence: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -43,7 +43,7 @@ func (h *RedeemHandler) List(c *gin.Context) {
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
if err != nil {
response.InternalError(c, "Failed to list redeem codes: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -61,7 +61,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.NotFound(c, "Redeem code not found")
response.ErrorFrom(c, err)
return
}
@@ -85,7 +85,7 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
ValidityDays: req.ValidityDays,
})
if err != nil {
response.InternalError(c, "Failed to generate redeem codes: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -103,7 +103,7 @@ func (h *RedeemHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.InternalError(c, "Failed to delete redeem code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -123,7 +123,7 @@ func (h *RedeemHandler) BatchDelete(c *gin.Context) {
deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs)
if err != nil {
response.InternalError(c, "Failed to batch delete redeem codes: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -144,7 +144,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.InternalError(c, "Failed to expire redeem code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -178,7 +178,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
// Get all codes without pagination (use large page size)
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
if err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -27,7 +27,7 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
func (h *SettingHandler) GetSettings(c *gin.Context) {
settings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -111,14 +111,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
response.InternalError(c, "Failed to update settings: "+err.Error())
response.ErrorFrom(c, err)
return
}
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get updated settings: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -166,7 +166,7 @@ func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
err := h.emailService.TestSmtpConnectionWithConfig(config)
if err != nil {
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -252,7 +252,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
`
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
response.BadRequest(c, "Failed to send test email: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -264,7 +264,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get admin API key status: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -279,7 +279,7 @@ func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to generate admin API key: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -292,7 +292,7 @@ func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
// DELETE /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
response.InternalError(c, "Failed to delete admin API key: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -78,7 +78,7 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -96,7 +96,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
if err != nil {
response.NotFound(c, "Subscription not found")
response.ErrorFrom(c, err)
return
}
@@ -141,7 +141,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
Notes: req.Notes,
})
if err != nil {
response.BadRequest(c, "Failed to assign subscription: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -168,7 +168,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
Notes: req.Notes,
})
if err != nil {
response.InternalError(c, "Failed to bulk assign subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -192,7 +192,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
if err != nil {
response.InternalError(c, "Failed to extend subscription: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -210,7 +210,7 @@ func (h *SubscriptionHandler) Revoke(c *gin.Context) {
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
if err != nil {
response.InternalError(c, "Failed to revoke subscription: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -230,7 +230,7 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to list group subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -248,7 +248,7 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
if err != nil {
response.InternalError(c, "Failed to list user subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -90,7 +90,7 @@ func (h *UsageHandler) List(c *gin.Context) {
records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters)
if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -158,7 +158,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if apiKeyID > 0 {
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, stats)
@@ -168,7 +168,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
if userID > 0 {
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
response.ErrorFrom(c, err)
return
}
response.Success(c, stats)
@@ -178,7 +178,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
// Get global stats
stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -197,7 +197,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
// Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
if err != nil {
response.InternalError(c, "Failed to search users: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -236,7 +236,7 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
if err != nil {
response.InternalError(c, "Failed to search API keys: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -64,7 +64,7 @@ func (h *UserHandler) List(c *gin.Context) {
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
if err != nil {
response.InternalError(c, "Failed to list users: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -82,7 +82,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
user, err := h.adminService.GetUser(c.Request.Context(), userID)
if err != nil {
response.NotFound(c, "User not found")
response.ErrorFrom(c, err)
return
}
@@ -109,7 +109,7 @@ func (h *UserHandler) Create(c *gin.Context) {
AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.BadRequest(c, "Failed to create user: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -144,7 +144,7 @@ func (h *UserHandler) Update(c *gin.Context) {
AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.InternalError(c, "Failed to update user: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -162,7 +162,7 @@ func (h *UserHandler) Delete(c *gin.Context) {
err = h.adminService.DeleteUser(c.Request.Context(), userID)
if err != nil {
response.InternalError(c, "Failed to delete user: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -186,7 +186,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
if err != nil {
response.InternalError(c, "Failed to update balance: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -206,7 +206,7 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get user API keys: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -226,7 +226,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) {
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
if err != nil {
response.InternalError(c, "Failed to get user usage: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -57,7 +57,7 @@ func (h *APIKeyHandler) List(c *gin.Context) {
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
if err != nil {
response.InternalError(c, "Failed to list API keys: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -87,7 +87,7 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID)
if err != nil {
response.NotFound(c, "API key not found")
response.ErrorFrom(c, err)
return
}
@@ -128,7 +128,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
}
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq)
if err != nil {
response.InternalError(c, "Failed to create API key: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -173,7 +173,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq)
if err != nil {
response.InternalError(c, "Failed to update API key: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -203,7 +203,7 @@ func (h *APIKeyHandler) Delete(c *gin.Context) {
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID)
if err != nil {
response.InternalError(c, "Failed to delete API key: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -227,7 +227,7 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
if err != nil {
response.InternalError(c, "Failed to get available groups: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -66,14 +66,14 @@ func (h *AuthHandler) Register(c *gin.Context) {
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if req.VerifyCode == "" {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
}
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
if err != nil {
response.BadRequest(c, "Registration failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -95,13 +95,13 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
if err != nil {
response.BadRequest(c, "Failed to send verification code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -122,13 +122,13 @@ func (h *AuthHandler) Login(c *gin.Context) {
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
response.ErrorFrom(c, err)
return
}
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
if err != nil {
response.Unauthorized(c, "Login failed: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -57,7 +57,7 @@ func (h *RedeemHandler) Redeem(c *gin.Context) {
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code)
if err != nil {
response.BadRequest(c, "Failed to redeem code: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -84,7 +84,7 @@ func (h *RedeemHandler) GetHistory(c *gin.Context) {
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit)
if err != nil {
response.InternalError(c, "Failed to get history: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -26,7 +26,7 @@ func NewSettingHandler(settingService *service.SettingService, version string) *
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
settings, err := h.settingService.GetPublicSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -58,7 +58,7 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -82,7 +82,7 @@ func (h *SubscriptionHandler) GetActive(c *gin.Context) {
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to get active subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -107,7 +107,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
// Get all active subscriptions with progress
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -146,7 +146,7 @@ func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
// Get all active subscriptions
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -55,7 +55,7 @@ func (h *UsageHandler) List(c *gin.Context) {
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "API key not found")
response.ErrorFrom(c, err)
return
}
if apiKey.UserID != user.ID {
@@ -77,7 +77,7 @@ func (h *UsageHandler) List(c *gin.Context) {
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params)
}
if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -107,7 +107,7 @@ func (h *UsageHandler) GetByID(c *gin.Context) {
record, err := h.usageService.GetByID(c.Request.Context(), usageID)
if err != nil {
response.NotFound(c, "Usage record not found")
response.ErrorFrom(c, err)
return
}
@@ -204,7 +204,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime)
}
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -259,7 +259,7 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) {
stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID)
if err != nil {
response.InternalError(c, "Failed to get dashboard statistics")
response.ErrorFrom(c, err)
return
}
@@ -286,7 +286,7 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) {
trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
if err != nil {
response.InternalError(c, "Failed to get usage trend")
response.ErrorFrom(c, err)
return
}
@@ -317,7 +317,7 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get model statistics")
response.ErrorFrom(c, err)
return
}
@@ -362,7 +362,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
// Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
if err != nil {
response.InternalError(c, "Failed to verify API key ownership")
response.ErrorFrom(c, err)
return
}
@@ -386,7 +386,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
if err != nil {
response.InternalError(c, "Failed to get API key usage stats")
response.ErrorFrom(c, err)
return
}

View File

@@ -49,7 +49,7 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
if err != nil {
response.InternalError(c, "Failed to get user profile: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -86,7 +86,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
}
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq)
if err != nil {
response.BadRequest(c, "Failed to change password: "+err.Error())
response.ErrorFrom(c, err)
return
}
@@ -120,7 +120,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq)
if err != nil {
response.BadRequest(c, "Failed to update profile: "+err.Error())
response.ErrorFrom(c, err)
return
}

View File

@@ -0,0 +1,157 @@
package errors
import (
"errors"
"fmt"
"net/http"
)
const (
UnknownCode = http.StatusInternalServerError
UnknownReason = ""
)
type Status struct {
Code int32 `json:"code"`
Reason string `json:"reason,omitempty"`
Message string `json:"message"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// ApplicationError is the standard error type used to control HTTP responses.
//
// Code is expected to be an HTTP status code (e.g. 400/401/403/404/409/500).
type ApplicationError struct {
Status
cause error
}
// Error is kept for backwards compatibility within this package.
type Error = ApplicationError
func (e *ApplicationError) Error() string {
if e == nil {
return "<nil>"
}
if e.cause == nil {
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v", e.Code, e.Reason, e.Message, e.Metadata)
}
return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v cause=%v", e.Code, e.Reason, e.Message, e.Metadata, e.cause)
}
// Unwrap provides compatibility for Go 1.13 error chains.
func (e *ApplicationError) Unwrap() error { return e.cause }
// Is matches each error in the chain with the target value.
func (e *ApplicationError) Is(err error) bool {
if se := new(ApplicationError); errors.As(err, &se) {
return se.Code == e.Code && se.Reason == e.Reason
}
return false
}
// WithCause attaches the underlying cause of the error.
func (e *ApplicationError) WithCause(cause error) *ApplicationError {
err := Clone(e)
err.cause = cause
return err
}
// WithMetadata deep-copies the given metadata map.
func (e *ApplicationError) WithMetadata(md map[string]string) *ApplicationError {
err := Clone(e)
if md == nil {
err.Metadata = nil
return err
}
err.Metadata = make(map[string]string, len(md))
for k, v := range md {
err.Metadata[k] = v
}
return err
}
// New returns an error object for the code, message.
func New(code int, reason, message string) *ApplicationError {
return &ApplicationError{
Status: Status{
Code: int32(code),
Message: message,
Reason: reason,
},
}
}
// Newf New(code fmt.Sprintf(format, a...))
func Newf(code int, reason, format string, a ...any) *ApplicationError {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Errorf returns an error object for the code, message and error info.
func Errorf(code int, reason, format string, a ...any) error {
return New(code, reason, fmt.Sprintf(format, a...))
}
// Code returns the http code for an error.
// It supports wrapped errors.
func Code(err error) int {
if err == nil {
return http.StatusOK
}
return int(FromError(err).Code)
}
// Reason returns the reason for a particular error.
// It supports wrapped errors.
func Reason(err error) string {
if err == nil {
return UnknownReason
}
return FromError(err).Reason
}
// Message returns the message for a particular error.
// It supports wrapped errors.
func Message(err error) string {
if err == nil {
return ""
}
return FromError(err).Message
}
// Clone deep clone error to a new error.
func Clone(err *ApplicationError) *ApplicationError {
if err == nil {
return nil
}
var metadata map[string]string
if err.Metadata != nil {
metadata = make(map[string]string, len(err.Metadata))
for k, v := range err.Metadata {
metadata[k] = v
}
}
return &ApplicationError{
cause: err.cause,
Status: Status{
Code: err.Code,
Reason: err.Reason,
Message: err.Message,
Metadata: metadata,
},
}
}
// FromError tries to convert an error to *ApplicationError.
// It supports wrapped errors.
func FromError(err error) *ApplicationError {
if err == nil {
return nil
}
if se := new(ApplicationError); errors.As(err, &se) {
return se
}
// Fall back to a generic internal error.
return New(UnknownCode, UnknownReason, err.Error()).WithCause(err)
}

View File

@@ -0,0 +1,168 @@
//go:build unit
package errors
import (
stderrors "errors"
"fmt"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestApplicationError_Basics(t *testing.T) {
tests := []struct {
name string
err *ApplicationError
want Status
wantIs bool
target error
wrapped error
}{
{
name: "new",
err: New(400, "BAD_REQUEST", "invalid input"),
want: Status{
Code: 400,
Reason: "BAD_REQUEST",
Message: "invalid input",
},
},
{
name: "is_matches_code_and_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "UNAUTHORIZED", "ignored message"),
wantIs: true,
},
{
name: "is_does_not_match_reason",
err: New(401, "UNAUTHORIZED", "nope"),
want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"},
target: New(401, "DIFFERENT", "ignored message"),
wantIs: false,
},
{
name: "from_error_unwraps_wrapped_application_error",
err: New(404, "NOT_FOUND", "missing"),
wrapped: fmt.Errorf("wrap: %w", New(404, "NOT_FOUND", "missing")),
want: Status{
Code: 404,
Reason: "NOT_FOUND",
Message: "missing",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.err != nil {
require.Equal(t, tt.want, tt.err.Status)
}
if tt.target != nil {
require.Equal(t, tt.wantIs, stderrors.Is(tt.err, tt.target))
}
if tt.wrapped != nil {
got := FromError(tt.wrapped)
require.Equal(t, tt.want, got.Status)
}
})
}
}
func TestApplicationError_WithMetadataDeepCopy(t *testing.T) {
tests := []struct {
name string
md map[string]string
}{
{name: "non_nil", md: map[string]string{"a": "1"}},
{name: "nil", md: nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
appErr := BadRequest("BAD_REQUEST", "invalid input").WithMetadata(tt.md)
if tt.md == nil {
require.Nil(t, appErr.Metadata)
return
}
tt.md["a"] = "changed"
require.Equal(t, "1", appErr.Metadata["a"])
})
}
}
func TestFromError_Generic(t *testing.T) {
tests := []struct {
name string
err error
wantCode int32
wantReason string
wantMsg string
}{
{
name: "plain_error",
err: stderrors.New("boom"),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: "boom",
},
{
name: "wrapped_plain_error",
err: fmt.Errorf("wrap: %w", io.EOF),
wantCode: UnknownCode,
wantReason: UnknownReason,
wantMsg: "wrap: EOF",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FromError(tt.err)
require.Equal(t, tt.wantCode, got.Code)
require.Equal(t, tt.wantReason, got.Reason)
require.Equal(t, tt.wantMsg, got.Message)
require.Equal(t, tt.err, got.Unwrap())
})
}
}
func TestToHTTP(t *testing.T) {
tests := []struct {
name string
err error
wantStatusCode int
wantBody Status
}{
{
name: "nil_error",
err: nil,
wantStatusCode: http.StatusOK,
wantBody: Status{Code: int32(http.StatusOK)},
},
{
name: "application_error",
err: Forbidden("FORBIDDEN", "no access"),
wantStatusCode: http.StatusForbidden,
wantBody: Status{
Code: int32(http.StatusForbidden),
Reason: "FORBIDDEN",
Message: "no access",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
code, body := ToHTTP(tt.err)
require.Equal(t, tt.wantStatusCode, code)
require.Equal(t, tt.wantBody, body)
})
}
}

View File

@@ -0,0 +1,21 @@
package errors
import "net/http"
// ToHTTP converts an error into an HTTP status code and a JSON-serializable body.
//
// The returned body matches the project's Status shape:
// { code, reason, message, metadata }.
func ToHTTP(err error) (statusCode int, body Status) {
if err == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
appErr := FromError(err)
if appErr == nil {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
cloned := Clone(appErr)
return int(cloned.Code), cloned.Status
}

View File

@@ -0,0 +1,114 @@
// nolint:mnd
package errors
import "net/http"
// BadRequest new BadRequest error that is mapped to a 400 response.
func BadRequest(reason, message string) *ApplicationError {
return New(http.StatusBadRequest, reason, message)
}
// IsBadRequest determines if err is an error which indicates a BadRequest error.
// It supports wrapped errors.
func IsBadRequest(err error) bool {
return Code(err) == http.StatusBadRequest
}
// TooManyRequests new TooManyRequests error that is mapped to a 429 response.
func TooManyRequests(reason, message string) *ApplicationError {
return New(http.StatusTooManyRequests, reason, message)
}
// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error.
// It supports wrapped errors.
func IsTooManyRequests(err error) bool {
return Code(err) == http.StatusTooManyRequests
}
// Unauthorized new Unauthorized error that is mapped to a 401 response.
func Unauthorized(reason, message string) *ApplicationError {
return New(http.StatusUnauthorized, reason, message)
}
// IsUnauthorized determines if err is an error which indicates an Unauthorized error.
// It supports wrapped errors.
func IsUnauthorized(err error) bool {
return Code(err) == http.StatusUnauthorized
}
// Forbidden new Forbidden error that is mapped to a 403 response.
func Forbidden(reason, message string) *ApplicationError {
return New(http.StatusForbidden, reason, message)
}
// IsForbidden determines if err is an error which indicates a Forbidden error.
// It supports wrapped errors.
func IsForbidden(err error) bool {
return Code(err) == http.StatusForbidden
}
// NotFound new NotFound error that is mapped to a 404 response.
func NotFound(reason, message string) *ApplicationError {
return New(http.StatusNotFound, reason, message)
}
// IsNotFound determines if err is an error which indicates an NotFound error.
// It supports wrapped errors.
func IsNotFound(err error) bool {
return Code(err) == http.StatusNotFound
}
// Conflict new Conflict error that is mapped to a 409 response.
func Conflict(reason, message string) *ApplicationError {
return New(http.StatusConflict, reason, message)
}
// IsConflict determines if err is an error which indicates a Conflict error.
// It supports wrapped errors.
func IsConflict(err error) bool {
return Code(err) == http.StatusConflict
}
// InternalServer new InternalServer error that is mapped to a 500 response.
func InternalServer(reason, message string) *ApplicationError {
return New(http.StatusInternalServerError, reason, message)
}
// IsInternalServer determines if err is an error which indicates an Internal error.
// It supports wrapped errors.
func IsInternalServer(err error) bool {
return Code(err) == http.StatusInternalServerError
}
// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response.
func ServiceUnavailable(reason, message string) *ApplicationError {
return New(http.StatusServiceUnavailable, reason, message)
}
// IsServiceUnavailable determines if err is an error which indicates an Unavailable error.
// It supports wrapped errors.
func IsServiceUnavailable(err error) bool {
return Code(err) == http.StatusServiceUnavailable
}
// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response.
func GatewayTimeout(reason, message string) *ApplicationError {
return New(http.StatusGatewayTimeout, reason, message)
}
// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error.
// It supports wrapped errors.
func IsGatewayTimeout(err error) bool {
return Code(err) == http.StatusGatewayTimeout
}
// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response.
func ClientClosed(reason, message string) *ApplicationError {
return New(499, reason, message)
}
// IsClientClosed determines if err is an error which indicates a IsClientClosed error.
// It supports wrapped errors.
func IsClientClosed(err error) bool {
return Code(err) == 499
}

View File

@@ -3,9 +3,11 @@ package middleware
import (
"context"
"crypto/subtle"
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"strings"
"github.com/gin-gonic/gin"
)
@@ -96,7 +98,7 @@ func validateJWTForAdmin(
// 验证 JWT token
claims, err := authService.ValidateToken(token)
if err != nil {
if err == service.ErrTokenExpired {
if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return false
}

View File

@@ -2,9 +2,11 @@ package middleware
import (
"context"
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"strings"
"github.com/gin-gonic/gin"
)
@@ -37,7 +39,7 @@ func JWTAuth(authService *service.AuthService, userRepo interface {
// 验证token
claims, err := authService.ValidateToken(tokenString)
if err != nil {
if err == service.ErrTokenExpired {
if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return
}

View File

@@ -4,14 +4,17 @@ import (
"math"
"net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin"
)
// Response 标准API响应格式
type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data any `json:"data,omitempty"`
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
Data any `json:"data,omitempty"`
}
// PaginatedData 分页数据格式(匹配前端期望)
@@ -44,11 +47,36 @@ func Created(c *gin.Context, data any) {
// Error 返回错误响应
func Error(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
Code: statusCode,
Message: message,
Reason: "",
Metadata: nil,
})
}
// ErrorWithDetails returns an error response compatible with the existing envelope while
// optionally providing structured error fields (reason/metadata).
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
c.JSON(statusCode, Response{
Code: statusCode,
Message: message,
Reason: reason,
Metadata: metadata,
})
}
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
// It returns true if an error was written.
func ErrorFrom(c *gin.Context, err error) bool {
if err == nil {
return false
}
statusCode, status := infraerrors.ToHTTP(err)
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}
// BadRequest 返回400错误
func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, message)

View File

@@ -0,0 +1,171 @@
//go:build unit
package response
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestErrorWithDetails(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
message string
reason string
metadata map[string]string
want Response
}{
{
name: "plain_error",
statusCode: http.StatusBadRequest,
message: "invalid request",
want: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
},
},
{
name: "structured_error",
statusCode: http.StatusForbidden,
message: "no access",
reason: "FORBIDDEN",
metadata: map[string]string{"k": "v"},
want: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"k": "v"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
require.Equal(t, tt.statusCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.want, got)
})
}
}
func TestErrorFrom(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
err error
wantWritten bool
wantHTTPCode int
wantBody Response
}{
{
name: "nil_error",
err: nil,
wantWritten: false,
},
{
name: "application_error",
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
wantWritten: true,
wantHTTPCode: http.StatusForbidden,
wantBody: Response{
Code: http.StatusForbidden,
Message: "no access",
Reason: "FORBIDDEN",
Metadata: map[string]string{"scope": "admin"},
},
},
{
name: "bad_request_error",
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"),
wantWritten: true,
wantHTTPCode: http.StatusBadRequest,
wantBody: Response{
Code: http.StatusBadRequest,
Message: "invalid request",
Reason: "INVALID_REQUEST",
},
},
{
name: "unauthorized_error",
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"),
wantWritten: true,
wantHTTPCode: http.StatusUnauthorized,
wantBody: Response{
Code: http.StatusUnauthorized,
Message: "unauthorized",
Reason: "UNAUTHORIZED",
},
},
{
name: "not_found_error",
err: infraerrors.NotFound("NOT_FOUND", "not found"),
wantWritten: true,
wantHTTPCode: http.StatusNotFound,
wantBody: Response{
Code: http.StatusNotFound,
Message: "not found",
Reason: "NOT_FOUND",
},
},
{
name: "conflict_error",
err: infraerrors.Conflict("CONFLICT", "conflict"),
wantWritten: true,
wantHTTPCode: http.StatusConflict,
wantBody: Response{
Code: http.StatusConflict,
Message: "conflict",
Reason: "CONFLICT",
},
},
{
name: "unknown_error_defaults_to_500",
err: errors.New("boom"),
wantWritten: true,
wantHTTPCode: http.StatusInternalServerError,
wantBody: Response{
Code: http.StatusInternalServerError,
Message: "boom",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
written := ErrorFrom(c, tt.err)
require.Equal(t, tt.wantWritten, written)
if !tt.wantWritten {
require.Equal(t, 200, w.Code)
require.Empty(t, w.Body.String())
return
}
require.Equal(t, tt.wantHTTPCode, w.Code)
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.wantBody, got)
})
}
}

View File

@@ -13,23 +13,23 @@ import (
"gorm.io/gorm/clause"
)
type AccountRepository struct {
type accountRepository struct {
db *gorm.DB
}
func NewAccountRepository(db *gorm.DB) *AccountRepository {
return &AccountRepository{db: db}
func NewAccountRepository(db *gorm.DB) service.AccountRepository {
return &accountRepository{db: db}
}
func (r *AccountRepository) Create(ctx context.Context, account *model.Account) error {
func (r *accountRepository) Create(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Create(account).Error
}
func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
func (r *accountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) {
var account model.Account
err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
// 填充 GroupIDs 和 Groups 虚拟字段
account.GroupIDs = make([]int64, 0, len(account.AccountGroups))
@@ -43,7 +43,7 @@ func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Accou
return &account, nil
}
func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) {
if crsAccountID == "" {
return nil, nil
}
@@ -59,11 +59,11 @@ func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
return &account, nil
}
func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error {
func (r *accountRepository) Update(ctx context.Context, account *model.Account) error {
return r.db.WithContext(ctx).Save(account).Error
}
func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
// 先删除账号与分组的绑定关系
if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return err
@@ -72,12 +72,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
}
func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "")
}
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
var accounts []model.Account
var total int64
@@ -131,7 +131,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati
}, nil
}
func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.account_id = accounts.id").
@@ -142,7 +142,7 @@ func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]m
return accounts, err
}
func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
func (r *accountRepository) ListActive(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive).
@@ -152,12 +152,12 @@ func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, er
return accounts, err
}
func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error
}
func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]any{
"status": model.StatusError,
@@ -165,7 +165,7 @@ func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg str
}).Error
}
func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error {
ag := &model.AccountGroup{
AccountID: accountID,
GroupID: groupID,
@@ -174,12 +174,12 @@ func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID i
return r.db.WithContext(ctx).Create(ag).Error
}
func (r *AccountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID).
Delete(&model.AccountGroup{}).Error
}
func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).
Joins("JOIN account_groups ON account_groups.group_id = groups.id").
@@ -188,7 +188,7 @@ func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]m
return groups, err
}
func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account
err := r.db.WithContext(ctx).
Where("platform = ? AND status = ?", platform, model.StatusActive).
@@ -198,7 +198,7 @@ func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string)
return accounts, err
}
func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
// 删除现有绑定
if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil {
return err
@@ -221,7 +221,7 @@ func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, gro
}
// ListSchedulable 获取所有可调度的账号
func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
func (r *accountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) {
var accounts []model.Account
now := time.Now()
err := r.db.WithContext(ctx).
@@ -235,7 +235,7 @@ func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Accoun
}
// ListSchedulableByGroupID 按组获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) {
var accounts []model.Account
now := time.Now()
err := r.db.WithContext(ctx).
@@ -251,7 +251,7 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI
}
// ListSchedulableByPlatform 按平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
var accounts []model.Account
now := time.Now()
err := r.db.WithContext(ctx).
@@ -266,7 +266,7 @@ func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platf
}
// ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号
func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) {
var accounts []model.Account
now := time.Now()
err := r.db.WithContext(ctx).
@@ -283,7 +283,7 @@ func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont
}
// SetRateLimited 标记账号为限流状态(429)
func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]any{
@@ -293,13 +293,13 @@ func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetA
}
// SetOverloaded 标记账号为过载状态(529)
func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("overload_until", until).Error
}
// ClearRateLimit 清除账号的限流状态
func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error {
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]any{
"rate_limited_at": nil,
@@ -309,7 +309,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error
}
// UpdateSessionWindow 更新账号的5小时时间窗口信息
func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
updates := map[string]any{
"session_window_status": status,
}
@@ -323,14 +323,14 @@ func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
}
// SetSchedulable 设置账号的调度开关
func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Update("schedulable", schedulable).Error
}
// UpdateExtra updates specific fields in account's Extra JSONB field
// It merges the updates into existing Extra data without overwriting other fields
func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 {
return nil
}
@@ -358,7 +358,7 @@ func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates m
// BulkUpdate updates multiple accounts with the provided fields.
// It merges credentials/extra JSONB fields instead of overwriting them.
func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
if len(ids) == 0 {
return 0, nil
}

View File

@@ -18,13 +18,13 @@ type AccountRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *AccountRepository
repo *accountRepository
}
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewAccountRepository(s.db)
s.repo = NewAccountRepository(s.db).(*accountRepository)
}
func TestAccountRepoSuite(t *testing.T) {
@@ -167,7 +167,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
db := testTx(s.T())
repo := NewAccountRepository(db)
repo := NewAccountRepository(db).(*accountRepository)
ctx := context.Background()
tt.setup(db)

View File

@@ -2,51 +2,55 @@ package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type ApiKeyRepository struct {
type apiKeyRepository struct {
db *gorm.DB
}
func NewApiKeyRepository(db *gorm.DB) *ApiKeyRepository {
return &ApiKeyRepository{db: db}
func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository {
return &apiKeyRepository{db: db}
}
func (r *ApiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Create(key).Error
func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error {
err := r.db.WithContext(ctx).Create(key).Error
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
}
func (r *ApiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
var key model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
}
return &key, nil
}
func (r *ApiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
var apiKey model.ApiKey
err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil)
}
return &apiKey, nil
}
func (r *ApiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error {
return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error
}
func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
}
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey
var total int64
@@ -73,19 +77,19 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
}, nil
}
func (r *ApiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error
return count, err
}
func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error
return count > 0, err
}
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey
var total int64
@@ -113,7 +117,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
}
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
var keys []model.ApiKey
db := r.db.WithContext(ctx).Model(&model.ApiKey{})
@@ -135,7 +139,7 @@ func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.ApiKey{}).
Where("group_id = ?", groupID).
Update("group_id", nil)
@@ -143,7 +147,7 @@ func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in
}
// CountByGroupID 获取分组的 API Key 数量
func (r *ApiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err

View File

@@ -16,13 +16,13 @@ type ApiKeyRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *ApiKeyRepository
repo *apiKeyRepository
}
func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewApiKeyRepository(s.db)
s.repo = NewApiKeyRepository(s.db).(*apiKeyRepository)
}
func TestApiKeyRepoSuite(t *testing.T) {

View File

@@ -0,0 +1,40 @@
package repository
import (
"errors"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"gorm.io/gorm"
)
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
if err == nil {
return nil
}
if notFound != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return notFound.WithCause(err)
}
if conflict != nil && isUniqueConstraintViolation(err) {
return conflict.WithCause(err)
}
return err
}
func isUniqueConstraintViolation(err error) bool {
if err == nil {
return false
}
if errors.Is(err, gorm.ErrDuplicatedKey) {
return true
}
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "duplicate key") ||
strings.Contains(msg, "unique constraint") ||
strings.Contains(msg, "duplicate entry")
}

View File

@@ -2,47 +2,52 @@ package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type GroupRepository struct {
type groupRepository struct {
db *gorm.DB
}
func NewGroupRepository(db *gorm.DB) *GroupRepository {
return &GroupRepository{db: db}
func NewGroupRepository(db *gorm.DB) service.GroupRepository {
return &groupRepository{db: db}
}
func (r *GroupRepository) Create(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Create(group).Error
func (r *groupRepository) Create(ctx context.Context, group *model.Group) error {
err := r.db.WithContext(ctx).Create(group).Error
return translatePersistenceError(err, nil, service.ErrGroupExists)
}
func (r *GroupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) {
var group model.Group
err := r.db.WithContext(ctx).First(&group, id).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
return &group, nil
}
func (r *GroupRepository) Update(ctx context.Context, group *model.Group) error {
func (r *groupRepository) Update(ctx context.Context, group *model.Group) error {
return r.db.WithContext(ctx).Save(group).Error
}
func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
}
func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil)
}
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
var groups []model.Group
var total int64
@@ -86,7 +91,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination
}, nil
}
func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error
if err != nil {
@@ -100,7 +105,7 @@ func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error)
return groups, nil
}
func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
var groups []model.Group
err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error
if err != nil {
@@ -114,25 +119,80 @@ func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform str
return groups, nil
}
func (r *GroupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error
return count > 0, err
}
func (r *GroupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error
return count, err
}
// DeleteAccountGroupsByGroupID 删除分组与账号的关联关系
func (r *GroupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{})
return result.RowsAffected, result.Error
}
// DB 返回底层数据库连接,用于事务处理
func (r *GroupRepository) DB() *gorm.DB {
return r.db
func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
group, err := r.GetByID(ctx, id)
if err != nil {
return nil, err
}
var affectedUserIDs []int64
if group.IsSubscriptionType() {
var subscriptions []model.UserSubscription
if err := r.db.WithContext(ctx).
Model(&model.UserSubscription{}).
Where("group_id = ?", id).
Select("user_id").
Find(&subscriptions).Error; err != nil {
return nil, err
}
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
}
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 删除订阅类型分组的订阅记录
if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
return err
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
return err
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}).
Where("? = ANY(allowed_groups)", id).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
return err
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return err
}
// 5. 删除分组本身(带锁,避免并发写)
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return affectedUserIDs, nil
}

View File

@@ -16,13 +16,13 @@ type GroupRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *GroupRepository
repo *groupRepository
}
func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewGroupRepository(s.db)
s.repo = NewGroupRepository(s.db).(*groupRepository)
}
func TestGroupRepoSuite(t *testing.T) {
@@ -234,11 +234,3 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count)
}
// --- DB ---
func (s *GroupRepoSuite) TestDB() {
db := s.repo.DB()
s.Require().NotNil(db, "DB should return non-nil")
s.Require().Equal(s.db, db, "DB should return the underlying gorm.DB")
}

View File

@@ -2,47 +2,50 @@ package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type ProxyRepository struct {
type proxyRepository struct {
db *gorm.DB
}
func NewProxyRepository(db *gorm.DB) *ProxyRepository {
return &ProxyRepository{db: db}
func NewProxyRepository(db *gorm.DB) service.ProxyRepository {
return &proxyRepository{db: db}
}
func (r *ProxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Create(proxy).Error
}
func (r *ProxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
var proxy model.Proxy
err := r.db.WithContext(ctx).First(&proxy, id).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil)
}
return &proxy, nil
}
func (r *ProxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error {
return r.db.WithContext(ctx).Save(proxy).Error
}
func (r *ProxyRepository) Delete(ctx context.Context, id int64) error {
func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
}
func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
var proxies []model.Proxy
var total int64
@@ -81,14 +84,14 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination
}, nil
}
func (r *ProxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) {
var proxies []model.Proxy
err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error
return proxies, err
}
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Proxy{}).
Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password).
@@ -100,7 +103,7 @@ func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
}
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.Account{}).
Where("proxy_id = ?", proxyID).
@@ -109,7 +112,7 @@ func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
type result struct {
ProxyID int64 `gorm:"column:proxy_id"`
Count int64 `gorm:"column:count"`
@@ -133,7 +136,7 @@ func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i
}
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *ProxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
var proxies []model.Proxy
err := r.db.WithContext(ctx).
Where("status = ?", model.StatusActive).

View File

@@ -17,13 +17,13 @@ type ProxyRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *ProxyRepository
repo *proxyRepository
}
func (s *ProxyRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewProxyRepository(s.db)
s.repo = NewProxyRepository(s.db).(*proxyRepository)
}
func TestProxyRepoSuite(t *testing.T) {

View File

@@ -2,57 +2,60 @@ package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"time"
"gorm.io/gorm"
)
type RedeemCodeRepository struct {
type redeemCodeRepository struct {
db *gorm.DB
}
func NewRedeemCodeRepository(db *gorm.DB) *RedeemCodeRepository {
return &RedeemCodeRepository{db: db}
func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository {
return &redeemCodeRepository{db: db}
}
func (r *RedeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error {
func (r *redeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Create(code).Error
}
func (r *RedeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error {
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error {
return r.db.WithContext(ctx).Create(&codes).Error
}
func (r *RedeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
var code model.RedeemCode
err := r.db.WithContext(ctx).First(&code, id).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
}
return &code, nil
}
func (r *RedeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
var redeemCode model.RedeemCode
err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil)
}
return &redeemCode, nil
}
func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error {
func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
}
func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
var codes []model.RedeemCode
var total int64
@@ -91,11 +94,11 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
}, nil
}
func (r *RedeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error {
func (r *redeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error {
return r.db.WithContext(ctx).Save(code).Error
}
func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now()
result := r.db.WithContext(ctx).Model(&model.RedeemCode{}).
Where("id = ? AND status = ?", id, model.StatusUnused).
@@ -108,13 +111,13 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 兑换码不存在或已被使用
return service.ErrRedeemCodeUsed.WithCause(gorm.ErrRecordNotFound)
}
return nil
}
// ListByUser returns all redeem codes used by a specific user
func (r *RedeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
var codes []model.RedeemCode
if limit <= 0 {
limit = 10

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
@@ -17,13 +18,13 @@ type RedeemCodeRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *RedeemCodeRepository
repo *redeemCodeRepository
}
func (s *RedeemCodeRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewRedeemCodeRepository(s.db)
s.repo = NewRedeemCodeRepository(s.db).(*redeemCodeRepository)
}
func TestRedeemCodeRepoSuite(t *testing.T) {
@@ -195,7 +196,7 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
// Second use should fail
err = s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
}
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
@@ -204,7 +205,7 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
}
// --- ListByUser ---
@@ -298,7 +299,7 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser
s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use")
err = s.repo.Use(s.ctx, codeB.ID, user.ID)
s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
s.Require().NoError(err, "GetByCode")

View File

@@ -1,14 +1,16 @@
package repository
import "github.com/Wei-Shaw/sub2api/internal/service"
// Repositories 所有仓库的集合
type Repositories struct {
User *UserRepository
ApiKey *ApiKeyRepository
Group *GroupRepository
Account *AccountRepository
Proxy *ProxyRepository
RedeemCode *RedeemCodeRepository
UsageLog *UsageLogRepository
Setting *SettingRepository
UserSubscription *UserSubscriptionRepository
User service.UserRepository
ApiKey service.ApiKeyRepository
Group service.GroupRepository
Account service.AccountRepository
Proxy service.ProxyRepository
RedeemCode service.RedeemCodeRepository
UsageLog service.UsageLogRepository
Setting service.SettingRepository
UserSubscription service.UserSubscriptionRepository
}

View File

@@ -2,35 +2,38 @@ package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/model"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// SettingRepository 系统设置数据访问层
type SettingRepository struct {
type settingRepository struct {
db *gorm.DB
}
// NewSettingRepository 创建系统设置仓库实例
func NewSettingRepository(db *gorm.DB) *SettingRepository {
return &SettingRepository{db: db}
func NewSettingRepository(db *gorm.DB) service.SettingRepository {
return &settingRepository{db: db}
}
// Get 根据Key获取设置值
func (r *SettingRepository) Get(ctx context.Context, key string) (*model.Setting, error) {
func (r *settingRepository) Get(ctx context.Context, key string) (*model.Setting, error) {
var setting model.Setting
err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil)
}
return &setting, nil
}
// GetValue 获取设置值字符串
func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, error) {
func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
setting, err := r.Get(ctx, key)
if err != nil {
return "", err
@@ -39,7 +42,7 @@ func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, e
}
// Set 设置值(存在则更新,不存在则创建)
func (r *SettingRepository) Set(ctx context.Context, key, value string) error {
func (r *settingRepository) Set(ctx context.Context, key, value string) error {
setting := &model.Setting{
Key: key,
Value: value,
@@ -53,7 +56,7 @@ func (r *SettingRepository) Set(ctx context.Context, key, value string) error {
}
// GetMultiple 批量获取设置
func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
var settings []model.Setting
err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error
if err != nil {
@@ -68,7 +71,7 @@ func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map
}
// SetMultiple 批量设置值
func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for key, value := range settings {
setting := &model.Setting{
@@ -88,7 +91,7 @@ func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string
}
// GetAll 获取所有设置
func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, error) {
func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
var settings []model.Setting
err := r.db.WithContext(ctx).Find(&settings).Error
if err != nil {
@@ -103,6 +106,6 @@ func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, erro
}
// Delete 删除设置
func (r *SettingRepository) Delete(ctx context.Context, key string) error {
func (r *settingRepository) Delete(ctx context.Context, key string) error {
return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error
}

View File

@@ -6,6 +6,7 @@ import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
@@ -14,13 +15,13 @@ type SettingRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *SettingRepository
repo *settingRepository
}
func (s *SettingRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewSettingRepository(s.db)
s.repo = NewSettingRepository(s.db).(*settingRepository)
}
func TestSettingRepoSuite(t *testing.T) {
@@ -45,7 +46,7 @@ func (s *SettingRepoSuite) TestSet_Upsert() {
func (s *SettingRepoSuite) TestGetValue_Missing() {
_, err := s.repo.GetValue(s.ctx, "nonexistent")
s.Require().Error(err, "expected error for missing key")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
s.Require().ErrorIs(err, service.ErrSettingNotFound)
}
func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() {
@@ -86,7 +87,7 @@ func (s *SettingRepoSuite) TestDelete() {
s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete")
_, err := s.repo.GetValue(s.ctx, "todelete")
s.Require().Error(err, "expected missing key error after Delete")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
s.Require().ErrorIs(err, service.ErrSettingNotFound)
}
func (s *SettingRepoSuite) TestDelete_Idempotent() {

View File

@@ -2,25 +2,28 @@ package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"time"
"gorm.io/gorm"
)
type UsageLogRepository struct {
type usageLogRepository struct {
db *gorm.DB
}
func NewUsageLogRepository(db *gorm.DB) *UsageLogRepository {
return &UsageLogRepository{db: db}
func NewUsageLogRepository(db *gorm.DB) service.UsageLogRepository {
return &usageLogRepository{db: db}
}
// getPerformanceStats 获取 RPM 和 TPM近5分钟平均值可选按用户过滤
func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) {
func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) {
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
var perfStats struct {
RequestCount int64 `gorm:"column:request_count"`
@@ -43,20 +46,20 @@ func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int
return perfStats.RequestCount / 5, perfStats.TokenCount / 5
}
func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
func (r *usageLogRepository) Create(ctx context.Context, log *model.UsageLog) error {
return r.db.WithContext(ctx).Create(log).Error
}
func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
var log model.UsageLog
err := r.db.WithContext(ctx).First(&log, id).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil)
}
return &log, nil
}
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
var total int64
@@ -83,7 +86,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
}, nil
}
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
var total int64
@@ -120,7 +123,7 @@ type UserStats struct {
CacheReadTokens int64 `json:"cache_read_tokens"`
}
func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) {
var stats UserStats
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
Select(`
@@ -139,7 +142,7 @@ func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
// DashboardStats 仪表盘统计
type DashboardStats = usagestats.DashboardStats
func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
var stats DashboardStats
today := timezone.Today()
@@ -260,7 +263,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
return &stats, nil
}
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
var total int64
@@ -287,7 +290,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
}, nil
}
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
err := r.db.WithContext(ctx).
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
@@ -296,7 +299,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
return logs, nil, err
}
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
err := r.db.WithContext(ctx).
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
@@ -305,7 +308,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe
return logs, nil, err
}
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
err := r.db.WithContext(ctx).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
@@ -314,7 +317,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
return logs, nil, err
}
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
err := r.db.WithContext(ctx).
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
@@ -323,12 +326,12 @@ func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelN
return logs, nil, err
}
func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error {
func (r *usageLogRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
}
// GetAccountTodayStats 获取账号今日统计
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
today := timezone.Today()
var stats struct {
@@ -358,7 +361,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
}
// GetAccountWindowStats 获取账号时间窗口内的统计
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
var stats struct {
Requests int64 `gorm:"column:requests"`
Tokens int64 `gorm:"column:tokens"`
@@ -398,7 +401,7 @@ type UserUsageTrendPoint = usagestats.UserUsageTrendPoint
type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) {
var results []ApiKeyUsageTrendPoint
// Choose date format based on granularity
@@ -442,7 +445,7 @@ func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime,
}
// GetUserUsageTrend returns usage trend data grouped by user and date
func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) {
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) {
var results []UserUsageTrendPoint
// Choose date format based on granularity
@@ -491,7 +494,7 @@ func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e
type UserDashboardStats = usagestats.UserDashboardStats
// GetUserDashboardStats 获取用户专属的仪表盘统计
func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
var stats UserDashboardStats
today := timezone.Today()
@@ -578,7 +581,7 @@ func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
}
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) {
var results []TrendDataPoint
var dateFormat string
@@ -612,7 +615,7 @@ func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user
}
// GetUserModelStats 获取指定用户的模型统计
func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) {
var results []ModelStat
err := r.db.WithContext(ctx).Model(&model.UsageLog{}).
@@ -641,7 +644,7 @@ func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64
type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin)
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
var total int64
@@ -692,7 +695,7 @@ type UsageStats = usagestats.UsageStats
type BatchUserUsageStats = usagestats.BatchUserUsageStats
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
if len(userIDs) == 0 {
return make(map[int64]*BatchUserUsageStats), nil
}
@@ -752,7 +755,7 @@ func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats
// GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys
func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) {
if len(apiKeyIDs) == 0 {
return make(map[int64]*BatchApiKeyUsageStats), nil
}
@@ -809,7 +812,7 @@ func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe
}
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) {
var results []TrendDataPoint
var dateFormat string
@@ -848,7 +851,7 @@ func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
}
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) {
var results []ModelStat
db := r.db.WithContext(ctx).Model(&model.UsageLog{}).
@@ -882,7 +885,7 @@ func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
}
// GetGlobalStats gets usage statistics for all users within a time range
func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
var stats struct {
TotalRequests int64 `gorm:"column:total_requests"`
TotalInputTokens int64 `gorm:"column:total_input_tokens"`
@@ -932,7 +935,7 @@ type AccountUsageSummary = usagestats.AccountUsageSummary
type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse
// GetAccountUsageStats returns comprehensive usage statistics for an account over a time range
func (r *UsageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) {
daysCount := int(endTime.Sub(startTime).Hours()/24) + 1
if daysCount <= 0 {
daysCount = 30

View File

@@ -19,13 +19,13 @@ type UsageLogRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UsageLogRepository
repo *usageLogRepository
}
func (s *UsageLogRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUsageLogRepository(s.db)
s.repo = NewUsageLogRepository(s.db).(*usageLogRepository)
}
func TestUsageLogRepoSuite(t *testing.T) {

View File

@@ -2,56 +2,61 @@ package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type UserRepository struct {
type userRepository struct {
db *gorm.DB
}
func NewUserRepository(db *gorm.DB) *UserRepository {
return &UserRepository{db: db}
func NewUserRepository(db *gorm.DB) service.UserRepository {
return &userRepository{db: db}
}
func (r *UserRepository) Create(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Create(user).Error
func (r *userRepository) Create(ctx context.Context, user *model.User) error {
err := r.db.WithContext(ctx).Create(user).Error
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).First(&user, id).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
}
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
}
func (r *UserRepository) Update(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Save(user).Error
func (r *userRepository) Update(ctx context.Context, user *model.User) error {
err := r.db.WithContext(ctx).Save(user).Error
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
func (r *UserRepository) Delete(ctx context.Context, id int64) error {
func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
}
func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists users with optional filtering by status, role, and search query
func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
var users []model.User
var total int64
@@ -120,13 +125,13 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.
}, nil
}
func (r *UserRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("balance", gorm.Expr("balance + ?", amount)).Error
}
// DeductBalance 扣减用户余额,仅当余额充足时执行
func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
result := r.db.WithContext(ctx).Model(&model.User{}).
Where("id = ? AND balance >= ?", id, amount).
Update("balance", gorm.Expr("balance - ?", amount))
@@ -134,17 +139,17 @@ func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount flo
return result.Error
}
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound // 余额不足或用户不存在
return service.ErrInsufficientBalance
}
return nil
}
func (r *UserRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).
Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error
}
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error
return count > 0, err
@@ -152,7 +157,7 @@ func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool,
// RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID
// 使用 PostgreSQL 的 array_remove 函数
func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.User{}).
Where("? = ANY(allowed_groups)", groupID).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
@@ -160,14 +165,14 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
}
// GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证)
func (r *UserRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) {
var user model.User
err := r.db.WithContext(ctx).
Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive).
Order("id ASC").
First(&user).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
return &user, nil
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
@@ -18,13 +19,13 @@ type UserRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UserRepository
repo *userRepository
}
func (s *UserRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUserRepository(s.db)
s.repo = NewUserRepository(s.db).(*userRepository)
}
func TestUserRepoSuite(t *testing.T) {
@@ -247,7 +248,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
s.Require().ErrorIs(err, service.ErrInsufficientBalance)
}
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
@@ -432,7 +433,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().Error(err, "DeductBalance expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound, "DeductBalance unexpected error")
s.Require().ErrorIs(err, service.ErrInsufficientBalance, "DeductBalance unexpected error")
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID)

View File

@@ -6,27 +6,29 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"gorm.io/gorm"
)
// UserSubscriptionRepository 用户订阅仓库
type UserSubscriptionRepository struct {
type userSubscriptionRepository struct {
db *gorm.DB
}
// NewUserSubscriptionRepository 创建用户订阅仓库
func NewUserSubscriptionRepository(db *gorm.DB) *UserSubscriptionRepository {
return &UserSubscriptionRepository{db: db}
func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository {
return &userSubscriptionRepository{db: db}
}
// Create 创建订阅
func (r *UserSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
return r.db.WithContext(ctx).Create(sub).Error
func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error {
err := r.db.WithContext(ctx).Create(sub).Error
return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists)
}
// GetByID 根据ID获取订阅
func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
err := r.db.WithContext(ctx).
Preload("User").
@@ -34,26 +36,26 @@ func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*mo
Preload("AssignedByUser").
First(&sub, id).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return &sub, nil
}
// GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅
func (r *UserSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
err := r.db.WithContext(ctx).
Preload("Group").
Where("user_id = ? AND group_id = ?", userID, groupID).
First(&sub).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return &sub, nil
}
// GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅
func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
var sub model.UserSubscription
err := r.db.WithContext(ctx).
Preload("Group").
@@ -61,24 +63,24 @@ func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con
userID, groupID, model.SubscriptionStatusActive, time.Now()).
First(&sub).Error
if err != nil {
return nil, err
return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
return &sub, nil
}
// Update 更新订阅
func (r *UserSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error {
sub.UpdatedAt = time.Now()
return r.db.WithContext(ctx).Save(sub).Error
}
// Delete 删除订阅
func (r *UserSubscriptionRepository) Delete(ctx context.Context, id int64) error {
func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error
}
// ListByUserID 获取用户的所有订阅
func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Preload("Group").
@@ -89,7 +91,7 @@ func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID in
}
// ListActiveByUserID 获取用户的所有有效订阅
func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Preload("Group").
@@ -101,7 +103,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
}
// ListByGroupID 获取分组的所有订阅(分页)
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription
var total int64
@@ -136,7 +138,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
}
// List 获取所有订阅(分页,支持筛选)
func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription
var total int64
@@ -182,7 +184,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination
}
// IncrementUsage 增加使用量
func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
@@ -194,7 +196,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
}
// ResetDailyUsage 重置日使用量
func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
@@ -205,7 +207,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int
}
// ResetWeeklyUsage 重置周使用量
func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
@@ -216,7 +218,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in
}
// ResetMonthlyUsage 重置月使用量
func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
@@ -227,7 +229,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
}
// ActivateWindows 激活所有窗口(首次使用时)
func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
@@ -239,7 +241,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int
}
// UpdateStatus 更新订阅状态
func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
@@ -249,7 +251,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64,
}
// ExtendExpiry 延长订阅过期时间
func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
@@ -259,7 +261,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64,
}
// UpdateNotes 更新订阅备注
func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error {
return r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("id = ?", id).
Updates(map[string]any{
@@ -269,7 +271,7 @@ func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64,
}
// ListExpired 获取所有已过期但状态仍为active的订阅
func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) {
var subs []model.UserSubscription
err := r.db.WithContext(ctx).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
@@ -278,7 +280,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U
}
// BatchUpdateExpiredStatus 批量更新过期订阅状态
func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
result := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()).
Updates(map[string]any{
@@ -289,7 +291,7 @@ func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex
}
// ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅
func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("user_id = ? AND group_id = ?", userID, groupID).
@@ -298,7 +300,7 @@ func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Contex
}
// CountByGroupID 获取分组的订阅数量
func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("group_id = ?", groupID).
@@ -307,7 +309,7 @@ func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID
}
// CountActiveByGroupID 获取分组的有效订阅数量
func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&model.UserSubscription{}).
Where("group_id = ? AND status = ? AND expires_at > ?",
@@ -317,7 +319,7 @@ func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g
}
// DeleteByGroupID 删除分组相关的所有订阅记录
func (r *UserSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) {
result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{})
return result.RowsAffected, result.Error
}

View File

@@ -17,13 +17,13 @@ type UserSubscriptionRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UserSubscriptionRepository
repo *userSubscriptionRepository
}
func (s *UserSubscriptionRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUserSubscriptionRepository(s.db)
s.repo = NewUserSubscriptionRepository(s.db).(*userSubscriptionRepository)
}
func TestUserSubscriptionRepoSuite(t *testing.T) {

View File

@@ -1,7 +1,6 @@
package repository
import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire"
)
@@ -37,15 +36,4 @@ var ProviderSet = wire.NewSet(
NewClaudeOAuthClient,
NewHTTPUpstream,
NewOpenAIOAuthClient,
// Bind concrete repositories to service port interfaces
wire.Bind(new(service.UserRepository), new(*UserRepository)),
wire.Bind(new(service.ApiKeyRepository), new(*ApiKeyRepository)),
wire.Bind(new(service.GroupRepository), new(*GroupRepository)),
wire.Bind(new(service.AccountRepository), new(*AccountRepository)),
wire.Bind(new(service.ProxyRepository), new(*ProxyRepository)),
wire.Bind(new(service.RedeemCodeRepository), new(*RedeemCodeRepository)),
wire.Bind(new(service.UsageLogRepository), new(*UsageLogRepository)),
wire.Bind(new(service.SettingRepository), new(*SettingRepository)),
wire.Bind(new(service.UserSubscriptionRepository), new(*UserSubscriptionRepository)),
)

View File

@@ -2,17 +2,16 @@ package service
import (
"context"
"errors"
"fmt"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
var (
ErrAccountNotFound = errors.New("account not found")
ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
)
type AccountRepository interface {
@@ -106,9 +105,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
for _, groupID := range req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err)
}
}
@@ -145,9 +141,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err)
}
return account, nil
@@ -184,9 +177,6 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err)
}
@@ -229,9 +219,6 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
for _, groupID := range *req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err)
}
}
@@ -249,9 +236,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
// 检查账号是否存在
_, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err)
}
@@ -266,9 +250,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error {
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err)
}
@@ -294,9 +275,6 @@ func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error {
func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrAccountNotFound
}
return "", fmt.Errorf("get account: %w", err)
}
@@ -307,9 +285,6 @@ func (s *AccountService) GetCredential(ctx context.Context, id int64, key string
func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err)
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
// AdminService interface defines admin management operations
@@ -550,61 +549,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
// 先获取分组信息,检查是否存在
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return fmt.Errorf("group not found: %w", err)
}
// 订阅类型分组先获取受影响的用户ID列表用于事务后失效缓存
var affectedUserIDs []int64
if group.IsSubscriptionType() && s.billingCacheService != nil {
var subscriptions []model.UserSubscription
if err := s.groupRepo.DB().WithContext(ctx).
Where("group_id = ?", id).
Select("user_id").
Find(&subscriptions).Error; err == nil {
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
}
}
// 使用事务处理所有级联删除
db := s.groupRepo.DB()
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 如果是订阅类型分组,删除 user_subscriptions 中的相关记录
if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
return fmt.Errorf("delete user subscriptions: %w", err)
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil任何类型的分组都需要
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
return fmt.Errorf("clear api key group_id: %w", err)
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}).
Where("? = ANY(allowed_groups)", id).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
return fmt.Errorf("remove from allowed_groups: %w", err)
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return fmt.Errorf("delete account groups: %w", err)
}
// 5. 删除分组本身
if err := tx.Delete(&model.Group{}, id).Error; err != nil {
return fmt.Errorf("delete group: %w", err)
}
return nil
})
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
if err != nil {
return err
}

View File

@@ -9,20 +9,20 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
var (
ErrApiKeyNotFound = errors.New("api key not found")
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
ErrApiKeyExists = errors.New("api key already exists")
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
)
const (
@@ -183,9 +183,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
@@ -193,9 +190,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err)
}
@@ -269,9 +263,6 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
@@ -285,9 +276,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
// 这里可以添加Redis缓存逻辑暂时直接查询数据库
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err)
}
@@ -304,9 +292,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err)
}
@@ -329,9 +314,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err)
}
@@ -361,9 +343,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrApiKeyNotFound
}
return fmt.Errorf("get api key: %w", err)
}
@@ -394,15 +373,12 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api
// 检查API Key状态
if !apiKey.IsActive() {
return nil, nil, errors.New("api key is not active")
return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active")
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, ErrUserNotFound
}
return nil, nil, fmt.Errorf("get user: %w", err)
}
@@ -436,9 +412,6 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
@@ -450,7 +423,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
// 获取用户的所有有效订阅
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
if err != nil {
return nil, fmt.Errorf("list active subscriptions: %w", err)
}

View File

@@ -8,22 +8,22 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var (
ErrInvalidCredentials = errors.New("invalid email or password")
ErrUserNotActive = errors.New("user is not active")
ErrEmailExists = errors.New("email already exists")
ErrInvalidToken = errors.New("invalid token")
ErrTokenExpired = errors.New("token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled")
ErrServiceUnavailable = errors.New("service temporarily unavailable")
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
)
// JWTClaims JWT载荷数据
@@ -255,7 +255,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
// 查找用户
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if errors.Is(err, ErrUserNotFound) {
return "", nil, ErrInvalidCredentials
}
// 记录数据库错误但不暴露给用户
@@ -357,7 +357,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// 获取最新的用户信息
user, err := s.userRepo.GetByID(ctx, claims.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if errors.Is(err, ErrUserNotFound) {
return "", ErrInvalidToken
}
log.Printf("[Auth] Database error refreshing token: %v", err)

View File

@@ -2,11 +2,11 @@ package service
import (
"context"
"errors"
"fmt"
"log"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
)
@@ -14,7 +14,7 @@ import (
// 注ErrInsufficientBalance在redeem_service.go中定义
// 注ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var (
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)

View File

@@ -4,21 +4,21 @@ import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"math/big"
"net/smtp"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
)
var (
ErrEmailNotConfigured = errors.New("email service not configured")
ErrInvalidVerifyCode = errors.New("invalid or expired verification code")
ErrVerifyCodeTooFrequent = errors.New("please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = errors.New("too many failed attempts, please request a new code")
ErrEmailNotConfigured = infraerrors.ServiceUnavailable("EMAIL_NOT_CONFIGURED", "email service not configured")
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
)
// EmailCache defines cache operations for email service

View File

@@ -2,17 +2,16 @@ package service
import (
"context"
"errors"
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
var (
ErrGroupNotFound = errors.New("group not found")
ErrGroupExists = errors.New("group name already exists")
ErrGroupNotFound = infraerrors.NotFound("GROUP_NOT_FOUND", "group not found")
ErrGroupExists = infraerrors.Conflict("GROUP_EXISTS", "group name already exists")
)
type GroupRepository interface {
@@ -20,6 +19,7 @@ type GroupRepository interface {
GetByID(ctx context.Context, id int64) (*model.Group, error)
Update(ctx context.Context, group *model.Group) error
Delete(ctx context.Context, id int64) error
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
@@ -29,8 +29,6 @@ type GroupRepository interface {
ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
DB() *gorm.DB
}
// CreateGroupRequest 创建分组请求
@@ -93,9 +91,6 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err)
}
return group, nil
@@ -123,9 +118,6 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err)
}
@@ -170,9 +162,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
// 检查分组是否存在
_, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrGroupNotFound
}
return fmt.Errorf("get group: %w", err)
}
@@ -187,9 +176,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err)
}

View File

@@ -2,16 +2,15 @@ package service
import (
"context"
"errors"
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
var (
ErrProxyNotFound = errors.New("proxy not found")
ErrProxyNotFound = infraerrors.NotFound("PROXY_NOT_FOUND", "proxy not found")
)
type ProxyRepository interface {
@@ -86,9 +85,6 @@ func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*mod
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err)
}
return proxy, nil
@@ -116,9 +112,6 @@ func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err)
}
@@ -163,9 +156,6 @@ func (s *ProxyService) Delete(ctx context.Context, id int64) error {
// 检查代理是否存在
_, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err)
}
@@ -180,9 +170,6 @@ func (s *ProxyService) Delete(ctx context.Context, id int64) error {
func (s *ProxyService) TestConnection(ctx context.Context, id int64) error {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err)
}
@@ -197,9 +184,6 @@ func (s *ProxyService) TestConnection(ctx context.Context, id int64) error {
func (s *ProxyService) GetURL(ctx context.Context, id int64) (string, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrProxyNotFound
}
return "", fmt.Errorf("get proxy: %w", err)
}

View File

@@ -9,19 +9,18 @@ import (
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
var (
ErrRedeemCodeNotFound = errors.New("redeem code not found")
ErrRedeemCodeUsed = errors.New("redeem code already used")
ErrRedeemCodeInvalid = errors.New("invalid redeem code")
ErrInsufficientBalance = errors.New("insufficient balance")
ErrRedeemRateLimited = errors.New("too many failed attempts, please try again later")
ErrRedeemCodeLocked = errors.New("redeem code is being processed, please try again")
ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found")
ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used")
ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance")
ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later")
ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again")
)
const (
@@ -226,7 +225,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 查找兑换码
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if errors.Is(err, ErrRedeemCodeNotFound) {
s.incrementRedeemErrorCount(ctx, userID)
return nil, ErrRedeemCodeNotFound
}
@@ -241,15 +240,12 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 验证兑换码类型的前置条件
if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil {
return nil, errors.New("invalid subscription redeem code: missing group_id")
return nil, infraerrors.BadRequest("REDEEM_CODE_INVALID", "invalid subscription redeem code: missing group_id")
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
_ = user // 使用变量避免未使用错误
@@ -257,8 +253,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// 【关键】先标记兑换码为已使用,确保并发安全
// 利用数据库乐观锁WHERE status = 'unused')保证原子性
if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// 兑换码已被其他请求使用
if errors.Is(err, ErrRedeemCodeNotFound) || errors.Is(err, ErrRedeemCodeUsed) {
return nil, ErrRedeemCodeUsed
}
return nil, fmt.Errorf("mark code as used: %w", err)
@@ -328,9 +323,6 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err)
}
return code, nil
@@ -340,9 +332,6 @@ func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCod
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err)
}
return redeemCode, nil
@@ -362,15 +351,12 @@ func (s *RedeemService) Delete(ctx context.Context, id int64) error {
// 检查兑换码是否存在
code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrRedeemCodeNotFound
}
return fmt.Errorf("get redeem code: %w", err)
}
// 不允许删除已使用的兑换码
if code.IsUsed() {
return errors.New("cannot delete used redeem code")
return infraerrors.Conflict("REDEEM_CODE_DELETE_USED", "cannot delete used redeem code")
}
if err := s.redeemRepo.Delete(ctx, id); err != nil {

View File

@@ -9,13 +9,13 @@ import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"gorm.io/gorm"
)
var (
ErrRegistrationDisabled = errors.New("registration is currently disabled")
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
)
type SettingRepository interface {
@@ -187,7 +187,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 已有设置,不需要初始化
return nil
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
if !errors.Is(err, ErrSettingNotFound) {
return fmt.Errorf("check existing settings: %w", err)
}
@@ -302,7 +302,7 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if errors.Is(err, ErrSettingNotFound) {
return "", false, nil
}
return "", false, err
@@ -326,7 +326,7 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串
}
return "", err // 数据库错误

View File

@@ -2,24 +2,24 @@ package service
import (
"context"
"errors"
"fmt"
"log"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
var (
ErrSubscriptionNotFound = errors.New("subscription not found")
ErrSubscriptionExpired = errors.New("subscription has expired")
ErrSubscriptionSuspended = errors.New("subscription is suspended")
ErrSubscriptionAlreadyExists = errors.New("subscription already exists for this user and group")
ErrGroupNotSubscriptionType = errors.New("group is not a subscription type")
ErrDailyLimitExceeded = errors.New("daily usage limit exceeded")
ErrWeeklyLimitExceeded = errors.New("weekly usage limit exceeded")
ErrMonthlyLimitExceeded = errors.New("monthly usage limit exceeded")
ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found")
ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired")
ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended")
ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group")
ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type")
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
)
// SubscriptionService 订阅服务

View File

@@ -2,14 +2,15 @@ package service
import (
"context"
"errors"
"fmt"
"log"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
)
var (
ErrTurnstileVerificationFailed = errors.New("turnstile verification failed")
ErrTurnstileNotConfigured = errors.New("turnstile not configured")
ErrTurnstileVerificationFailed = infraerrors.BadRequest("TURNSTILE_VERIFICATION_FAILED", "turnstile verification failed")
ErrTurnstileNotConfigured = infraerrors.ServiceUnavailable("TURNSTILE_NOT_CONFIGURED", "turnstile not configured")
)
// TurnstileVerifier 验证 Turnstile token 的接口

View File

@@ -2,18 +2,17 @@ package service
import (
"context"
"errors"
"fmt"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"gorm.io/gorm"
)
var (
ErrUsageLogNotFound = errors.New("usage log not found")
ErrUsageLogNotFound = infraerrors.NotFound("USAGE_LOG_NOT_FOUND", "usage log not found")
)
// CreateUsageLogRequest 创建使用日志请求
@@ -71,9 +70,6 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
// 验证用户存在
_, err := s.userRepo.GetByID(ctx, req.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
@@ -119,9 +115,6 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
log, err := s.usageRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUsageLogNotFound
}
return nil, fmt.Errorf("get usage log: %w", err)
}
return log, nil

View File

@@ -2,19 +2,18 @@ package service
import (
"context"
"errors"
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var (
ErrUserNotFound = errors.New("user not found")
ErrPasswordIncorrect = errors.New("current password is incorrect")
ErrInsufficientPerms = errors.New("insufficient permissions")
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
)
type UserRepository interface {
@@ -65,9 +64,6 @@ func NewUserService(userRepo UserRepository) *UserService {
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
return user, nil
@@ -77,9 +73,6 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
@@ -119,9 +112,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err)
}
@@ -149,9 +139,6 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
return user, nil
@@ -178,9 +165,6 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err)
}