From eeaff85e471f99c55b0d446c88a73f89425da9a7 Mon Sep 17 00:00:00 2001 From: Forest Date: Thu, 25 Dec 2025 20:52:47 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E8=87=AA=E5=AE=9A=E4=B9=89?= =?UTF-8?q?=E4=B8=9A=E5=8A=A1=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/.golangci.yml | 6 +- .../internal/handler/admin/account_handler.go | 44 ++--- .../internal/handler/admin/group_handler.go | 14 +- .../handler/admin/openai_oauth_handler.go | 16 +- .../internal/handler/admin/proxy_handler.go | 20 +- .../internal/handler/admin/redeem_handler.go | 14 +- .../internal/handler/admin/setting_handler.go | 16 +- .../handler/admin/subscription_handler.go | 16 +- .../internal/handler/admin/usage_handler.go | 12 +- .../internal/handler/admin/user_handler.go | 16 +- backend/internal/handler/api_key_handler.go | 12 +- backend/internal/handler/auth_handler.go | 12 +- backend/internal/handler/redeem_handler.go | 4 +- backend/internal/handler/setting_handler.go | 2 +- .../internal/handler/subscription_handler.go | 8 +- backend/internal/handler/usage_handler.go | 18 +- backend/internal/handler/user_handler.go | 6 +- .../internal/infrastructure/errors/errors.go | 157 ++++++++++++++++ .../infrastructure/errors/errors_test.go | 168 +++++++++++++++++ .../internal/infrastructure/errors/http.go | 21 +++ .../internal/infrastructure/errors/types.go | 114 ++++++++++++ backend/internal/middleware/admin_auth.go | 6 +- backend/internal/middleware/jwt_auth.go | 6 +- backend/internal/pkg/response/response.go | 38 +++- .../internal/pkg/response/response_test.go | 171 ++++++++++++++++++ backend/internal/repository/account_repo.go | 62 +++---- .../account_repo_integration_test.go | 6 +- backend/internal/repository/api_key_repo.go | 40 ++-- .../api_key_repo_integration_test.go | 4 +- .../internal/repository/error_translate.go | 40 ++++ backend/internal/repository/group_repo.go | 98 ++++++++-- .../repository/group_repo_integration_test.go | 12 +- backend/internal/repository/proxy_repo.go | 33 ++-- .../repository/proxy_repo_integration_test.go | 4 +- .../internal/repository/redeem_code_repo.go | 37 ++-- .../redeem_code_repo_integration_test.go | 11 +- backend/internal/repository/repository.go | 20 +- backend/internal/repository/setting_repo.go | 27 +-- .../setting_repo_integration_test.go | 9 +- backend/internal/repository/usage_log_repo.go | 67 +++---- .../usage_log_repo_integration_test.go | 4 +- backend/internal/repository/user_repo.go | 49 ++--- .../repository/user_repo_integration_test.go | 9 +- .../repository/user_subscription_repo.go | 64 +++---- ...user_subscription_repo_integration_test.go | 4 +- backend/internal/repository/wire.go | 12 -- backend/internal/service/account_service.go | 29 +-- backend/internal/service/admin_service.go | 57 +----- backend/internal/service/api_key_service.go | 45 +---- backend/internal/service/auth_service.go | 22 +-- .../internal/service/billing_cache_service.go | 4 +- backend/internal/service/email_service.go | 10 +- backend/internal/service/group_service.go | 22 +-- backend/internal/service/proxy_service.go | 20 +- backend/internal/service/redeem_service.go | 34 +--- backend/internal/service/setting_service.go | 12 +- .../internal/service/subscription_service.go | 18 +- backend/internal/service/turnstile_service.go | 7 +- backend/internal/service/usage_service.go | 11 +- backend/internal/service/user_service.go | 24 +-- 60 files changed, 1222 insertions(+), 622 deletions(-) create mode 100644 backend/internal/infrastructure/errors/errors.go create mode 100644 backend/internal/infrastructure/errors/errors_test.go create mode 100644 backend/internal/infrastructure/errors/http.go create mode 100644 backend/internal/infrastructure/errors/types.go create mode 100644 backend/internal/pkg/response/response_test.go create mode 100644 backend/internal/repository/error_translate.go diff --git a/backend/.golangci.yml b/backend/.golangci.yml index ec16bc0f..6d078f1f 100644 --- a/backend/.golangci.yml +++ b/backend/.golangci.yml @@ -19,14 +19,16 @@ linters: files: - "**/internal/service/**" deny: - - pkg: sub2api/internal/repository + - pkg: github.com/Wei-Shaw/sub2api/internal/repository desc: "service must not import repository" + - pkg: gorm.io/gorm + desc: "service must not import gorm" handler-no-repository: list-mode: original files: - "**/internal/handler/**" deny: - - pkg: sub2api/internal/repository + - pkg: github.com/Wei-Shaw/sub2api/internal/repository desc: "handler must not import repository" errcheck: # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 25f69588..8ecb4326 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -117,7 +117,7 @@ func (h *AccountHandler) List(c *gin.Context) { accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search) if err != nil { - response.InternalError(c, "Failed to list accounts: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -156,7 +156,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) { account, err := h.adminService.GetAccount(c.Request.Context(), accountID) if err != nil { - response.NotFound(c, "Account not found") + response.ErrorFrom(c, err) return } @@ -184,7 +184,7 @@ func (h *AccountHandler) Create(c *gin.Context) { GroupIDs: req.GroupIDs, }) if err != nil { - response.BadRequest(c, "Failed to create account: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -218,7 +218,7 @@ func (h *AccountHandler) Update(c *gin.Context) { GroupIDs: req.GroupIDs, }) if err != nil { - response.InternalError(c, "Failed to update account: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -236,7 +236,7 @@ func (h *AccountHandler) Delete(c *gin.Context) { err = h.adminService.DeleteAccount(c.Request.Context(), accountID) if err != nil { - response.InternalError(c, "Failed to delete account: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -297,7 +297,7 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) { SyncProxies: syncProxies, }) if err != nil { - response.BadRequest(c, "Sync failed: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -332,7 +332,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) { // Use OpenAI OAuth service to refresh token tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account) if err != nil { - response.InternalError(c, "Failed to refresh credentials: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -349,7 +349,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) { // Use Anthropic/Claude OAuth service to refresh token tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account) if err != nil { - response.InternalError(c, "Failed to refresh credentials: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -372,7 +372,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) { Credentials: newCredentials, }) if err != nil { - response.InternalError(c, "Failed to update account credentials: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -403,7 +403,7 @@ func (h *AccountHandler) GetStats(c *gin.Context) { stats, err := h.accountUsageService.GetAccountUsageStats(c.Request.Context(), accountID, startTime, endTime) if err != nil { - response.InternalError(c, "Failed to get account stats: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -421,7 +421,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) { account, err := h.adminService.ClearAccountError(c.Request.Context(), accountID) if err != nil { - response.InternalError(c, "Failed to clear error: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -570,7 +570,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { Extra: req.Extra, }) if err != nil { - response.InternalError(c, "Failed to bulk update accounts: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -595,7 +595,7 @@ func (h *OAuthHandler) GenerateAuthURL(c *gin.Context) { result, err := h.oauthService.GenerateAuthURL(c.Request.Context(), req.ProxyID) if err != nil { - response.InternalError(c, "Failed to generate auth URL: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -613,7 +613,7 @@ func (h *OAuthHandler) GenerateSetupTokenURL(c *gin.Context) { result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID) if err != nil { - response.InternalError(c, "Failed to generate setup token URL: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -642,7 +642,7 @@ func (h *OAuthHandler) ExchangeCode(c *gin.Context) { ProxyID: req.ProxyID, }) if err != nil { - response.BadRequest(c, "Failed to exchange code: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -664,7 +664,7 @@ func (h *OAuthHandler) ExchangeSetupTokenCode(c *gin.Context) { ProxyID: req.ProxyID, }) if err != nil { - response.BadRequest(c, "Failed to exchange code: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -692,7 +692,7 @@ func (h *OAuthHandler) CookieAuth(c *gin.Context) { Scope: "full", }) if err != nil { - response.BadRequest(c, "Cookie auth failed: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -714,7 +714,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) { Scope: "inference", }) if err != nil { - response.BadRequest(c, "Cookie auth failed: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -732,7 +732,7 @@ func (h *AccountHandler) GetUsage(c *gin.Context) { usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID) if err != nil { - response.InternalError(c, "Failed to get usage: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -750,7 +750,7 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) { err = h.rateLimitService.ClearRateLimit(c.Request.Context(), accountID) if err != nil { - response.InternalError(c, "Failed to clear rate limit: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -768,7 +768,7 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) { stats, err := h.accountUsageService.GetTodayStats(c.Request.Context(), accountID) if err != nil { - response.InternalError(c, "Failed to get today stats: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -797,7 +797,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) { account, err := h.adminService.SetAccountSchedulable(c.Request.Context(), accountID, req.Schedulable) if err != nil { - response.InternalError(c, "Failed to update schedulable status: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 00ede072..26d0715f 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -65,7 +65,7 @@ func (h *GroupHandler) List(c *gin.Context) { groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive) if err != nil { - response.InternalError(c, "Failed to list groups: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -87,7 +87,7 @@ func (h *GroupHandler) GetAll(c *gin.Context) { } if err != nil { - response.InternalError(c, "Failed to get groups: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -105,7 +105,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) { group, err := h.adminService.GetGroup(c.Request.Context(), groupID) if err != nil { - response.NotFound(c, "Group not found") + response.ErrorFrom(c, err) return } @@ -133,7 +133,7 @@ func (h *GroupHandler) Create(c *gin.Context) { MonthlyLimitUSD: req.MonthlyLimitUSD, }) if err != nil { - response.BadRequest(c, "Failed to create group: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -168,7 +168,7 @@ func (h *GroupHandler) Update(c *gin.Context) { MonthlyLimitUSD: req.MonthlyLimitUSD, }) if err != nil { - response.InternalError(c, "Failed to update group: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -186,7 +186,7 @@ func (h *GroupHandler) Delete(c *gin.Context) { err = h.adminService.DeleteGroup(c.Request.Context(), groupID) if err != nil { - response.InternalError(c, "Failed to delete group: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -225,7 +225,7 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) { keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize) if err != nil { - response.InternalError(c, "Failed to get group API keys: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go index 5b9ff39a..60285fe3 100644 --- a/backend/internal/handler/admin/openai_oauth_handler.go +++ b/backend/internal/handler/admin/openai_oauth_handler.go @@ -40,7 +40,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI) if err != nil { - response.InternalError(c, "Failed to generate auth URL: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -71,7 +71,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { ProxyID: req.ProxyID, }) if err != nil { - response.BadRequest(c, "Failed to exchange code: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -103,7 +103,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL) if err != nil { - response.BadRequest(c, "Failed to refresh token: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -122,7 +122,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { // Get account account, err := h.adminService.GetAccount(c.Request.Context(), accountID) if err != nil { - response.NotFound(c, "Account not found") + response.ErrorFrom(c, err) return } @@ -141,7 +141,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { // Use OpenAI OAuth service to refresh token tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account) if err != nil { - response.InternalError(c, "Failed to refresh credentials: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -159,7 +159,7 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { Credentials: newCredentials, }) if err != nil { - response.InternalError(c, "Failed to update account credentials: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -192,7 +192,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { ProxyID: req.ProxyID, }) if err != nil { - response.BadRequest(c, "Failed to exchange code: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -220,7 +220,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { GroupIDs: req.GroupIDs, }) if err != nil { - response.InternalError(c, "Failed to create account: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index 3acca977..99937fb9 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -53,7 +53,7 @@ func (h *ProxyHandler) List(c *gin.Context) { proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search) if err != nil { - response.InternalError(c, "Failed to list proxies: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -69,7 +69,7 @@ func (h *ProxyHandler) GetAll(c *gin.Context) { if withCount { proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context()) if err != nil { - response.InternalError(c, "Failed to get proxies: "+err.Error()) + response.ErrorFrom(c, err) return } response.Success(c, proxies) @@ -78,7 +78,7 @@ func (h *ProxyHandler) GetAll(c *gin.Context) { proxies, err := h.adminService.GetAllProxies(c.Request.Context()) if err != nil { - response.InternalError(c, "Failed to get proxies: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -96,7 +96,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) { proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID) if err != nil { - response.NotFound(c, "Proxy not found") + response.ErrorFrom(c, err) return } @@ -121,7 +121,7 @@ func (h *ProxyHandler) Create(c *gin.Context) { Password: strings.TrimSpace(req.Password), }) if err != nil { - response.BadRequest(c, "Failed to create proxy: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -153,7 +153,7 @@ func (h *ProxyHandler) Update(c *gin.Context) { Status: strings.TrimSpace(req.Status), }) if err != nil { - response.InternalError(c, "Failed to update proxy: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -171,7 +171,7 @@ func (h *ProxyHandler) Delete(c *gin.Context) { err = h.adminService.DeleteProxy(c.Request.Context(), proxyID) if err != nil { - response.InternalError(c, "Failed to delete proxy: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -189,7 +189,7 @@ func (h *ProxyHandler) Test(c *gin.Context) { result, err := h.adminService.TestProxy(c.Request.Context(), proxyID) if err != nil { - response.InternalError(c, "Failed to test proxy: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -229,7 +229,7 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) { accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize) if err != nil { - response.InternalError(c, "Failed to get proxy accounts: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -272,7 +272,7 @@ func (h *ProxyHandler) BatchCreate(c *gin.Context) { // Check for duplicates (same host, port, username, password) exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password) if err != nil { - response.InternalError(c, "Failed to check proxy existence: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 7dffc363..98a4c242 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -43,7 +43,7 @@ func (h *RedeemHandler) List(c *gin.Context) { codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search) if err != nil { - response.InternalError(c, "Failed to list redeem codes: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -61,7 +61,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) { code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID) if err != nil { - response.NotFound(c, "Redeem code not found") + response.ErrorFrom(c, err) return } @@ -85,7 +85,7 @@ func (h *RedeemHandler) Generate(c *gin.Context) { ValidityDays: req.ValidityDays, }) if err != nil { - response.InternalError(c, "Failed to generate redeem codes: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -103,7 +103,7 @@ func (h *RedeemHandler) Delete(c *gin.Context) { err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID) if err != nil { - response.InternalError(c, "Failed to delete redeem code: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -123,7 +123,7 @@ func (h *RedeemHandler) BatchDelete(c *gin.Context) { deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs) if err != nil { - response.InternalError(c, "Failed to batch delete redeem codes: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -144,7 +144,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) { code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID) if err != nil { - response.InternalError(c, "Failed to expire redeem code: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -178,7 +178,7 @@ func (h *RedeemHandler) Export(c *gin.Context) { // Get all codes without pagination (use large page size) codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "") if err != nil { - response.InternalError(c, "Failed to export redeem codes: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 851dbc88..3cdd6a9d 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -27,7 +27,7 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser func (h *SettingHandler) GetSettings(c *gin.Context) { settings, err := h.settingService.GetAllSettings(c.Request.Context()) if err != nil { - response.InternalError(c, "Failed to get settings: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -111,14 +111,14 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { - response.InternalError(c, "Failed to update settings: "+err.Error()) + response.ErrorFrom(c, err) return } // 重新获取设置返回 updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context()) if err != nil { - response.InternalError(c, "Failed to get updated settings: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -166,7 +166,7 @@ func (h *SettingHandler) TestSmtpConnection(c *gin.Context) { err := h.emailService.TestSmtpConnectionWithConfig(config) if err != nil { - response.BadRequest(c, "SMTP connection test failed: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -252,7 +252,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) { ` if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil { - response.BadRequest(c, "Failed to send test email: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -264,7 +264,7 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) { func (h *SettingHandler) GetAdminApiKey(c *gin.Context) { maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context()) if err != nil { - response.InternalError(c, "Failed to get admin API key status: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -279,7 +279,7 @@ func (h *SettingHandler) GetAdminApiKey(c *gin.Context) { func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) { key, err := h.settingService.GenerateAdminApiKey(c.Request.Context()) if err != nil { - response.InternalError(c, "Failed to generate admin API key: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -292,7 +292,7 @@ func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) { // DELETE /api/v1/admin/settings/admin-api-key func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) { if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil { - response.InternalError(c, "Failed to delete admin API key: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index d6f0a4e5..d101a6e6 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -78,7 +78,7 @@ func (h *SubscriptionHandler) List(c *gin.Context) { subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status) if err != nil { - response.InternalError(c, "Failed to list subscriptions: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -96,7 +96,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) { subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID) if err != nil { - response.NotFound(c, "Subscription not found") + response.ErrorFrom(c, err) return } @@ -141,7 +141,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) { Notes: req.Notes, }) if err != nil { - response.BadRequest(c, "Failed to assign subscription: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -168,7 +168,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) { Notes: req.Notes, }) if err != nil { - response.InternalError(c, "Failed to bulk assign subscriptions: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -192,7 +192,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) { subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days) if err != nil { - response.InternalError(c, "Failed to extend subscription: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -210,7 +210,7 @@ func (h *SubscriptionHandler) Revoke(c *gin.Context) { err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID) if err != nil { - response.InternalError(c, "Failed to revoke subscription: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -230,7 +230,7 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) { subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize) if err != nil { - response.InternalError(c, "Failed to list group subscriptions: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -248,7 +248,7 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) { subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID) if err != nil { - response.InternalError(c, "Failed to list user subscriptions: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index 523dc689..6c230c56 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -90,7 +90,7 @@ func (h *UsageHandler) List(c *gin.Context) { records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters) if err != nil { - response.InternalError(c, "Failed to list usage records: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -158,7 +158,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { if apiKeyID > 0 { stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime) if err != nil { - response.InternalError(c, "Failed to get usage statistics: "+err.Error()) + response.ErrorFrom(c, err) return } response.Success(c, stats) @@ -168,7 +168,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { if userID > 0 { stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime) if err != nil { - response.InternalError(c, "Failed to get usage statistics: "+err.Error()) + response.ErrorFrom(c, err) return } response.Success(c, stats) @@ -178,7 +178,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { // Get global stats stats, err := h.usageService.GetGlobalStats(c.Request.Context(), startTime, endTime) if err != nil { - response.InternalError(c, "Failed to get usage statistics: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -197,7 +197,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) { // Limit to 30 results users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword) if err != nil { - response.InternalError(c, "Failed to search users: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -236,7 +236,7 @@ func (h *UsageHandler) SearchApiKeys(c *gin.Context) { keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30) if err != nil { - response.InternalError(c, "Failed to search API keys: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 64afd886..b8842ae0 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -64,7 +64,7 @@ func (h *UserHandler) List(c *gin.Context) { users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search) if err != nil { - response.InternalError(c, "Failed to list users: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -82,7 +82,7 @@ func (h *UserHandler) GetByID(c *gin.Context) { user, err := h.adminService.GetUser(c.Request.Context(), userID) if err != nil { - response.NotFound(c, "User not found") + response.ErrorFrom(c, err) return } @@ -109,7 +109,7 @@ func (h *UserHandler) Create(c *gin.Context) { AllowedGroups: req.AllowedGroups, }) if err != nil { - response.BadRequest(c, "Failed to create user: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -144,7 +144,7 @@ func (h *UserHandler) Update(c *gin.Context) { AllowedGroups: req.AllowedGroups, }) if err != nil { - response.InternalError(c, "Failed to update user: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -162,7 +162,7 @@ func (h *UserHandler) Delete(c *gin.Context) { err = h.adminService.DeleteUser(c.Request.Context(), userID) if err != nil { - response.InternalError(c, "Failed to delete user: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -186,7 +186,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) { user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes) if err != nil { - response.InternalError(c, "Failed to update balance: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -206,7 +206,7 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) { keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize) if err != nil { - response.InternalError(c, "Failed to get user API keys: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -226,7 +226,7 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) { stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period) if err != nil { - response.InternalError(c, "Failed to get user usage: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index 83e4d13d..8592e38b 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -57,7 +57,7 @@ func (h *APIKeyHandler) List(c *gin.Context) { keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params) if err != nil { - response.InternalError(c, "Failed to list API keys: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -87,7 +87,7 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) { key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID) if err != nil { - response.NotFound(c, "API key not found") + response.ErrorFrom(c, err) return } @@ -128,7 +128,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) { } key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq) if err != nil { - response.InternalError(c, "Failed to create API key: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -173,7 +173,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) { key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq) if err != nil { - response.InternalError(c, "Failed to update API key: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -203,7 +203,7 @@ func (h *APIKeyHandler) Delete(c *gin.Context) { err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID) if err != nil { - response.InternalError(c, "Failed to delete API key: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -227,7 +227,7 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) { groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID) if err != nil { - response.InternalError(c, "Failed to get available groups: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index a5347d0c..efb7584d 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -66,14 +66,14 @@ func (h *AuthHandler) Register(c *gin.Context) { // Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过) if req.VerifyCode == "" { if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil { - response.BadRequest(c, "Turnstile verification failed: "+err.Error()) + response.ErrorFrom(c, err) return } } token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode) if err != nil { - response.BadRequest(c, "Registration failed: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -95,13 +95,13 @@ func (h *AuthHandler) SendVerifyCode(c *gin.Context) { // Turnstile 验证 if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil { - response.BadRequest(c, "Turnstile verification failed: "+err.Error()) + response.ErrorFrom(c, err) return } result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email) if err != nil { - response.BadRequest(c, "Failed to send verification code: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -122,13 +122,13 @@ func (h *AuthHandler) Login(c *gin.Context) { // Turnstile 验证 if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil { - response.BadRequest(c, "Turnstile verification failed: "+err.Error()) + response.ErrorFrom(c, err) return } token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password) if err != nil { - response.Unauthorized(c, "Login failed: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/redeem_handler.go b/backend/internal/handler/redeem_handler.go index 8a4399ae..765d2e26 100644 --- a/backend/internal/handler/redeem_handler.go +++ b/backend/internal/handler/redeem_handler.go @@ -57,7 +57,7 @@ func (h *RedeemHandler) Redeem(c *gin.Context) { result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code) if err != nil { - response.BadRequest(c, "Failed to redeem code: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -84,7 +84,7 @@ func (h *RedeemHandler) GetHistory(c *gin.Context) { codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit) if err != nil { - response.InternalError(c, "Failed to get history: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 94072fe1..d9804865 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -26,7 +26,7 @@ func NewSettingHandler(settingService *service.SettingService, version string) * func (h *SettingHandler) GetPublicSettings(c *gin.Context) { settings, err := h.settingService.GetPublicSettings(c.Request.Context()) if err != nil { - response.InternalError(c, "Failed to get settings: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/subscription_handler.go b/backend/internal/handler/subscription_handler.go index eb597d9e..fd67e529 100644 --- a/backend/internal/handler/subscription_handler.go +++ b/backend/internal/handler/subscription_handler.go @@ -58,7 +58,7 @@ func (h *SubscriptionHandler) List(c *gin.Context) { subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID) if err != nil { - response.InternalError(c, "Failed to list subscriptions: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -82,7 +82,7 @@ func (h *SubscriptionHandler) GetActive(c *gin.Context) { subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) if err != nil { - response.InternalError(c, "Failed to get active subscriptions: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -107,7 +107,7 @@ func (h *SubscriptionHandler) GetProgress(c *gin.Context) { // Get all active subscriptions with progress subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) if err != nil { - response.InternalError(c, "Failed to get subscriptions: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -146,7 +146,7 @@ func (h *SubscriptionHandler) GetSummary(c *gin.Context) { // Get all active subscriptions subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID) if err != nil { - response.InternalError(c, "Failed to get subscriptions: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index e1fd332e..d73df209 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -55,7 +55,7 @@ func (h *UsageHandler) List(c *gin.Context) { // [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id) if err != nil { - response.NotFound(c, "API key not found") + response.ErrorFrom(c, err) return } if apiKey.UserID != user.ID { @@ -77,7 +77,7 @@ func (h *UsageHandler) List(c *gin.Context) { records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params) } if err != nil { - response.InternalError(c, "Failed to list usage records: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -107,7 +107,7 @@ func (h *UsageHandler) GetByID(c *gin.Context) { record, err := h.usageService.GetByID(c.Request.Context(), usageID) if err != nil { - response.NotFound(c, "Usage record not found") + response.ErrorFrom(c, err) return } @@ -204,7 +204,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime) } if err != nil { - response.InternalError(c, "Failed to get usage statistics: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -259,7 +259,7 @@ func (h *UsageHandler) DashboardStats(c *gin.Context) { stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), user.ID) if err != nil { - response.InternalError(c, "Failed to get dashboard statistics") + response.ErrorFrom(c, err) return } @@ -286,7 +286,7 @@ func (h *UsageHandler) DashboardTrend(c *gin.Context) { trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity) if err != nil { - response.InternalError(c, "Failed to get usage trend") + response.ErrorFrom(c, err) return } @@ -317,7 +317,7 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) { stats, err := h.usageService.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime) if err != nil { - response.InternalError(c, "Failed to get model statistics") + response.ErrorFrom(c, err) return } @@ -362,7 +362,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) { // Verify ownership of all requested API keys userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000}) if err != nil { - response.InternalError(c, "Failed to verify API key ownership") + response.ErrorFrom(c, err) return } @@ -386,7 +386,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) { stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs) if err != nil { - response.InternalError(c, "Failed to get API key usage stats") + response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 435149b8..4c5498f0 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -49,7 +49,7 @@ func (h *UserHandler) GetProfile(c *gin.Context) { userData, err := h.userService.GetByID(c.Request.Context(), user.ID) if err != nil { - response.InternalError(c, "Failed to get user profile: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -86,7 +86,7 @@ func (h *UserHandler) ChangePassword(c *gin.Context) { } err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq) if err != nil { - response.BadRequest(c, "Failed to change password: "+err.Error()) + response.ErrorFrom(c, err) return } @@ -120,7 +120,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { } updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), user.ID, svcReq) if err != nil { - response.BadRequest(c, "Failed to update profile: "+err.Error()) + response.ErrorFrom(c, err) return } diff --git a/backend/internal/infrastructure/errors/errors.go b/backend/internal/infrastructure/errors/errors.go new file mode 100644 index 00000000..64a98cc2 --- /dev/null +++ b/backend/internal/infrastructure/errors/errors.go @@ -0,0 +1,157 @@ +package errors + +import ( + "errors" + "fmt" + "net/http" +) + +const ( + UnknownCode = http.StatusInternalServerError + UnknownReason = "" +) + +type Status struct { + Code int32 `json:"code"` + Reason string `json:"reason,omitempty"` + Message string `json:"message"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// ApplicationError is the standard error type used to control HTTP responses. +// +// Code is expected to be an HTTP status code (e.g. 400/401/403/404/409/500). +type ApplicationError struct { + Status + cause error +} + +// Error is kept for backwards compatibility within this package. +type Error = ApplicationError + +func (e *ApplicationError) Error() string { + if e == nil { + return "" + } + if e.cause == nil { + return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v", e.Code, e.Reason, e.Message, e.Metadata) + } + return fmt.Sprintf("error: code=%d reason=%q message=%q metadata=%v cause=%v", e.Code, e.Reason, e.Message, e.Metadata, e.cause) +} + +// Unwrap provides compatibility for Go 1.13 error chains. +func (e *ApplicationError) Unwrap() error { return e.cause } + +// Is matches each error in the chain with the target value. +func (e *ApplicationError) Is(err error) bool { + if se := new(ApplicationError); errors.As(err, &se) { + return se.Code == e.Code && se.Reason == e.Reason + } + return false +} + +// WithCause attaches the underlying cause of the error. +func (e *ApplicationError) WithCause(cause error) *ApplicationError { + err := Clone(e) + err.cause = cause + return err +} + +// WithMetadata deep-copies the given metadata map. +func (e *ApplicationError) WithMetadata(md map[string]string) *ApplicationError { + err := Clone(e) + if md == nil { + err.Metadata = nil + return err + } + err.Metadata = make(map[string]string, len(md)) + for k, v := range md { + err.Metadata[k] = v + } + return err +} + +// New returns an error object for the code, message. +func New(code int, reason, message string) *ApplicationError { + return &ApplicationError{ + Status: Status{ + Code: int32(code), + Message: message, + Reason: reason, + }, + } +} + +// Newf New(code fmt.Sprintf(format, a...)) +func Newf(code int, reason, format string, a ...any) *ApplicationError { + return New(code, reason, fmt.Sprintf(format, a...)) +} + +// Errorf returns an error object for the code, message and error info. +func Errorf(code int, reason, format string, a ...any) error { + return New(code, reason, fmt.Sprintf(format, a...)) +} + +// Code returns the http code for an error. +// It supports wrapped errors. +func Code(err error) int { + if err == nil { + return http.StatusOK + } + return int(FromError(err).Code) +} + +// Reason returns the reason for a particular error. +// It supports wrapped errors. +func Reason(err error) string { + if err == nil { + return UnknownReason + } + return FromError(err).Reason +} + +// Message returns the message for a particular error. +// It supports wrapped errors. +func Message(err error) string { + if err == nil { + return "" + } + return FromError(err).Message +} + +// Clone deep clone error to a new error. +func Clone(err *ApplicationError) *ApplicationError { + if err == nil { + return nil + } + var metadata map[string]string + if err.Metadata != nil { + metadata = make(map[string]string, len(err.Metadata)) + for k, v := range err.Metadata { + metadata[k] = v + } + } + return &ApplicationError{ + cause: err.cause, + Status: Status{ + Code: err.Code, + Reason: err.Reason, + Message: err.Message, + Metadata: metadata, + }, + } +} + +// FromError tries to convert an error to *ApplicationError. +// It supports wrapped errors. +func FromError(err error) *ApplicationError { + if err == nil { + return nil + } + if se := new(ApplicationError); errors.As(err, &se) { + return se + } + + // Fall back to a generic internal error. + return New(UnknownCode, UnknownReason, err.Error()).WithCause(err) +} diff --git a/backend/internal/infrastructure/errors/errors_test.go b/backend/internal/infrastructure/errors/errors_test.go new file mode 100644 index 00000000..8170ca26 --- /dev/null +++ b/backend/internal/infrastructure/errors/errors_test.go @@ -0,0 +1,168 @@ +//go:build unit + +package errors + +import ( + stderrors "errors" + "fmt" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestApplicationError_Basics(t *testing.T) { + tests := []struct { + name string + err *ApplicationError + want Status + wantIs bool + target error + wrapped error + }{ + { + name: "new", + err: New(400, "BAD_REQUEST", "invalid input"), + want: Status{ + Code: 400, + Reason: "BAD_REQUEST", + Message: "invalid input", + }, + }, + { + name: "is_matches_code_and_reason", + err: New(401, "UNAUTHORIZED", "nope"), + want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"}, + target: New(401, "UNAUTHORIZED", "ignored message"), + wantIs: true, + }, + { + name: "is_does_not_match_reason", + err: New(401, "UNAUTHORIZED", "nope"), + want: Status{Code: 401, Reason: "UNAUTHORIZED", Message: "nope"}, + target: New(401, "DIFFERENT", "ignored message"), + wantIs: false, + }, + { + name: "from_error_unwraps_wrapped_application_error", + err: New(404, "NOT_FOUND", "missing"), + wrapped: fmt.Errorf("wrap: %w", New(404, "NOT_FOUND", "missing")), + want: Status{ + Code: 404, + Reason: "NOT_FOUND", + Message: "missing", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.err != nil { + require.Equal(t, tt.want, tt.err.Status) + } + + if tt.target != nil { + require.Equal(t, tt.wantIs, stderrors.Is(tt.err, tt.target)) + } + + if tt.wrapped != nil { + got := FromError(tt.wrapped) + require.Equal(t, tt.want, got.Status) + } + }) + } +} + +func TestApplicationError_WithMetadataDeepCopy(t *testing.T) { + tests := []struct { + name string + md map[string]string + }{ + {name: "non_nil", md: map[string]string{"a": "1"}}, + {name: "nil", md: nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + appErr := BadRequest("BAD_REQUEST", "invalid input").WithMetadata(tt.md) + + if tt.md == nil { + require.Nil(t, appErr.Metadata) + return + } + + tt.md["a"] = "changed" + require.Equal(t, "1", appErr.Metadata["a"]) + }) + } +} + +func TestFromError_Generic(t *testing.T) { + tests := []struct { + name string + err error + wantCode int32 + wantReason string + wantMsg string + }{ + { + name: "plain_error", + err: stderrors.New("boom"), + wantCode: UnknownCode, + wantReason: UnknownReason, + wantMsg: "boom", + }, + { + name: "wrapped_plain_error", + err: fmt.Errorf("wrap: %w", io.EOF), + wantCode: UnknownCode, + wantReason: UnknownReason, + wantMsg: "wrap: EOF", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := FromError(tt.err) + require.Equal(t, tt.wantCode, got.Code) + require.Equal(t, tt.wantReason, got.Reason) + require.Equal(t, tt.wantMsg, got.Message) + require.Equal(t, tt.err, got.Unwrap()) + }) + } +} + +func TestToHTTP(t *testing.T) { + tests := []struct { + name string + err error + wantStatusCode int + wantBody Status + }{ + { + name: "nil_error", + err: nil, + wantStatusCode: http.StatusOK, + wantBody: Status{Code: int32(http.StatusOK)}, + }, + { + name: "application_error", + err: Forbidden("FORBIDDEN", "no access"), + wantStatusCode: http.StatusForbidden, + wantBody: Status{ + Code: int32(http.StatusForbidden), + Reason: "FORBIDDEN", + Message: "no access", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + code, body := ToHTTP(tt.err) + require.Equal(t, tt.wantStatusCode, code) + require.Equal(t, tt.wantBody, body) + }) + } +} diff --git a/backend/internal/infrastructure/errors/http.go b/backend/internal/infrastructure/errors/http.go new file mode 100644 index 00000000..7b5560e3 --- /dev/null +++ b/backend/internal/infrastructure/errors/http.go @@ -0,0 +1,21 @@ +package errors + +import "net/http" + +// ToHTTP converts an error into an HTTP status code and a JSON-serializable body. +// +// The returned body matches the project's Status shape: +// { code, reason, message, metadata }. +func ToHTTP(err error) (statusCode int, body Status) { + if err == nil { + return http.StatusOK, Status{Code: int32(http.StatusOK)} + } + + appErr := FromError(err) + if appErr == nil { + return http.StatusOK, Status{Code: int32(http.StatusOK)} + } + + cloned := Clone(appErr) + return int(cloned.Code), cloned.Status +} diff --git a/backend/internal/infrastructure/errors/types.go b/backend/internal/infrastructure/errors/types.go new file mode 100644 index 00000000..dd98f6f5 --- /dev/null +++ b/backend/internal/infrastructure/errors/types.go @@ -0,0 +1,114 @@ +// nolint:mnd +package errors + +import "net/http" + +// BadRequest new BadRequest error that is mapped to a 400 response. +func BadRequest(reason, message string) *ApplicationError { + return New(http.StatusBadRequest, reason, message) +} + +// IsBadRequest determines if err is an error which indicates a BadRequest error. +// It supports wrapped errors. +func IsBadRequest(err error) bool { + return Code(err) == http.StatusBadRequest +} + +// TooManyRequests new TooManyRequests error that is mapped to a 429 response. +func TooManyRequests(reason, message string) *ApplicationError { + return New(http.StatusTooManyRequests, reason, message) +} + +// IsTooManyRequests determines if err is an error which indicates a TooManyRequests error. +// It supports wrapped errors. +func IsTooManyRequests(err error) bool { + return Code(err) == http.StatusTooManyRequests +} + +// Unauthorized new Unauthorized error that is mapped to a 401 response. +func Unauthorized(reason, message string) *ApplicationError { + return New(http.StatusUnauthorized, reason, message) +} + +// IsUnauthorized determines if err is an error which indicates an Unauthorized error. +// It supports wrapped errors. +func IsUnauthorized(err error) bool { + return Code(err) == http.StatusUnauthorized +} + +// Forbidden new Forbidden error that is mapped to a 403 response. +func Forbidden(reason, message string) *ApplicationError { + return New(http.StatusForbidden, reason, message) +} + +// IsForbidden determines if err is an error which indicates a Forbidden error. +// It supports wrapped errors. +func IsForbidden(err error) bool { + return Code(err) == http.StatusForbidden +} + +// NotFound new NotFound error that is mapped to a 404 response. +func NotFound(reason, message string) *ApplicationError { + return New(http.StatusNotFound, reason, message) +} + +// IsNotFound determines if err is an error which indicates an NotFound error. +// It supports wrapped errors. +func IsNotFound(err error) bool { + return Code(err) == http.StatusNotFound +} + +// Conflict new Conflict error that is mapped to a 409 response. +func Conflict(reason, message string) *ApplicationError { + return New(http.StatusConflict, reason, message) +} + +// IsConflict determines if err is an error which indicates a Conflict error. +// It supports wrapped errors. +func IsConflict(err error) bool { + return Code(err) == http.StatusConflict +} + +// InternalServer new InternalServer error that is mapped to a 500 response. +func InternalServer(reason, message string) *ApplicationError { + return New(http.StatusInternalServerError, reason, message) +} + +// IsInternalServer determines if err is an error which indicates an Internal error. +// It supports wrapped errors. +func IsInternalServer(err error) bool { + return Code(err) == http.StatusInternalServerError +} + +// ServiceUnavailable new ServiceUnavailable error that is mapped to an HTTP 503 response. +func ServiceUnavailable(reason, message string) *ApplicationError { + return New(http.StatusServiceUnavailable, reason, message) +} + +// IsServiceUnavailable determines if err is an error which indicates an Unavailable error. +// It supports wrapped errors. +func IsServiceUnavailable(err error) bool { + return Code(err) == http.StatusServiceUnavailable +} + +// GatewayTimeout new GatewayTimeout error that is mapped to an HTTP 504 response. +func GatewayTimeout(reason, message string) *ApplicationError { + return New(http.StatusGatewayTimeout, reason, message) +} + +// IsGatewayTimeout determines if err is an error which indicates a GatewayTimeout error. +// It supports wrapped errors. +func IsGatewayTimeout(err error) bool { + return Code(err) == http.StatusGatewayTimeout +} + +// ClientClosed new ClientClosed error that is mapped to an HTTP 499 response. +func ClientClosed(reason, message string) *ApplicationError { + return New(499, reason, message) +} + +// IsClientClosed determines if err is an error which indicates a IsClientClosed error. +// It supports wrapped errors. +func IsClientClosed(err error) bool { + return Code(err) == 499 +} diff --git a/backend/internal/middleware/admin_auth.go b/backend/internal/middleware/admin_auth.go index 91c15875..8d1e819b 100644 --- a/backend/internal/middleware/admin_auth.go +++ b/backend/internal/middleware/admin_auth.go @@ -3,9 +3,11 @@ package middleware import ( "context" "crypto/subtle" + "errors" + "strings" + "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/service" - "strings" "github.com/gin-gonic/gin" ) @@ -96,7 +98,7 @@ func validateJWTForAdmin( // 验证 JWT token claims, err := authService.ValidateToken(token) if err != nil { - if err == service.ErrTokenExpired { + if errors.Is(err, service.ErrTokenExpired) { AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired") return false } diff --git a/backend/internal/middleware/jwt_auth.go b/backend/internal/middleware/jwt_auth.go index d5843dc4..cd8dd7f6 100644 --- a/backend/internal/middleware/jwt_auth.go +++ b/backend/internal/middleware/jwt_auth.go @@ -2,9 +2,11 @@ package middleware import ( "context" + "errors" + "strings" + "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/service" - "strings" "github.com/gin-gonic/gin" ) @@ -37,7 +39,7 @@ func JWTAuth(authService *service.AuthService, userRepo interface { // 验证token claims, err := authService.ValidateToken(tokenString) if err != nil { - if err == service.ErrTokenExpired { + if errors.Is(err, service.ErrTokenExpired) { AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired") return } diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go index 6bbfee4c..e26d2531 100644 --- a/backend/internal/pkg/response/response.go +++ b/backend/internal/pkg/response/response.go @@ -4,14 +4,17 @@ import ( "math" "net/http" + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/gin-gonic/gin" ) // Response 标准API响应格式 type Response struct { - Code int `json:"code"` - Message string `json:"message"` - Data any `json:"data,omitempty"` + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + Data any `json:"data,omitempty"` } // PaginatedData 分页数据格式(匹配前端期望) @@ -44,11 +47,36 @@ func Created(c *gin.Context, data any) { // Error 返回错误响应 func Error(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, Response{ - Code: statusCode, - Message: message, + Code: statusCode, + Message: message, + Reason: "", + Metadata: nil, }) } +// ErrorWithDetails returns an error response compatible with the existing envelope while +// optionally providing structured error fields (reason/metadata). +func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) { + c.JSON(statusCode, Response{ + Code: statusCode, + Message: message, + Reason: reason, + Metadata: metadata, + }) +} + +// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response. +// It returns true if an error was written. +func ErrorFrom(c *gin.Context, err error) bool { + if err == nil { + return false + } + + statusCode, status := infraerrors.ToHTTP(err) + ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata) + return true +} + // BadRequest 返回400错误 func BadRequest(c *gin.Context, message string) { Error(c, http.StatusBadRequest, message) diff --git a/backend/internal/pkg/response/response_test.go b/backend/internal/pkg/response/response_test.go new file mode 100644 index 00000000..af6c2875 --- /dev/null +++ b/backend/internal/pkg/response/response_test.go @@ -0,0 +1,171 @@ +//go:build unit + +package response + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestErrorWithDetails(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + statusCode int + message string + reason string + metadata map[string]string + want Response + }{ + { + name: "plain_error", + statusCode: http.StatusBadRequest, + message: "invalid request", + want: Response{ + Code: http.StatusBadRequest, + Message: "invalid request", + }, + }, + { + name: "structured_error", + statusCode: http.StatusForbidden, + message: "no access", + reason: "FORBIDDEN", + metadata: map[string]string{"k": "v"}, + want: Response{ + Code: http.StatusForbidden, + Message: "no access", + Reason: "FORBIDDEN", + Metadata: map[string]string{"k": "v"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata) + + require.Equal(t, tt.statusCode, w.Code) + + var got Response + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) + require.Equal(t, tt.want, got) + }) + } +} + +func TestErrorFrom(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + err error + wantWritten bool + wantHTTPCode int + wantBody Response + }{ + { + name: "nil_error", + err: nil, + wantWritten: false, + }, + { + name: "application_error", + err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}), + wantWritten: true, + wantHTTPCode: http.StatusForbidden, + wantBody: Response{ + Code: http.StatusForbidden, + Message: "no access", + Reason: "FORBIDDEN", + Metadata: map[string]string{"scope": "admin"}, + }, + }, + { + name: "bad_request_error", + err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"), + wantWritten: true, + wantHTTPCode: http.StatusBadRequest, + wantBody: Response{ + Code: http.StatusBadRequest, + Message: "invalid request", + Reason: "INVALID_REQUEST", + }, + }, + { + name: "unauthorized_error", + err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"), + wantWritten: true, + wantHTTPCode: http.StatusUnauthorized, + wantBody: Response{ + Code: http.StatusUnauthorized, + Message: "unauthorized", + Reason: "UNAUTHORIZED", + }, + }, + { + name: "not_found_error", + err: infraerrors.NotFound("NOT_FOUND", "not found"), + wantWritten: true, + wantHTTPCode: http.StatusNotFound, + wantBody: Response{ + Code: http.StatusNotFound, + Message: "not found", + Reason: "NOT_FOUND", + }, + }, + { + name: "conflict_error", + err: infraerrors.Conflict("CONFLICT", "conflict"), + wantWritten: true, + wantHTTPCode: http.StatusConflict, + wantBody: Response{ + Code: http.StatusConflict, + Message: "conflict", + Reason: "CONFLICT", + }, + }, + { + name: "unknown_error_defaults_to_500", + err: errors.New("boom"), + wantWritten: true, + wantHTTPCode: http.StatusInternalServerError, + wantBody: Response{ + Code: http.StatusInternalServerError, + Message: "boom", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + written := ErrorFrom(c, tt.err) + require.Equal(t, tt.wantWritten, written) + + if !tt.wantWritten { + require.Equal(t, 200, w.Code) + require.Empty(t, w.Body.String()) + return + } + + require.Equal(t, tt.wantHTTPCode, w.Code) + var got Response + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) + require.Equal(t, tt.wantBody, got) + }) + } +} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 0b9c6bf8..d07bc741 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -13,23 +13,23 @@ import ( "gorm.io/gorm/clause" ) -type AccountRepository struct { +type accountRepository struct { db *gorm.DB } -func NewAccountRepository(db *gorm.DB) *AccountRepository { - return &AccountRepository{db: db} +func NewAccountRepository(db *gorm.DB) service.AccountRepository { + return &accountRepository{db: db} } -func (r *AccountRepository) Create(ctx context.Context, account *model.Account) error { +func (r *accountRepository) Create(ctx context.Context, account *model.Account) error { return r.db.WithContext(ctx).Create(account).Error } -func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) { +func (r *accountRepository) GetByID(ctx context.Context, id int64) (*model.Account, error) { var account model.Account err := r.db.WithContext(ctx).Preload("Proxy").Preload("AccountGroups.Group").First(&account, id).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil) } // 填充 GroupIDs 和 Groups 虚拟字段 account.GroupIDs = make([]int64, 0, len(account.AccountGroups)) @@ -43,7 +43,7 @@ func (r *AccountRepository) GetByID(ctx context.Context, id int64) (*model.Accou return &account, nil } -func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) { +func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error) { if crsAccountID == "" { return nil, nil } @@ -59,11 +59,11 @@ func (r *AccountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID return &account, nil } -func (r *AccountRepository) Update(ctx context.Context, account *model.Account) error { +func (r *accountRepository) Update(ctx context.Context, account *model.Account) error { return r.db.WithContext(ctx).Save(account).Error } -func (r *AccountRepository) Delete(ctx context.Context, id int64) error { +func (r *accountRepository) Delete(ctx context.Context, id int64) error { // 先删除账号与分组的绑定关系 if err := r.db.WithContext(ctx).Where("account_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil { return err @@ -72,12 +72,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error } -func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { +func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", "", "") } // ListWithFilters lists accounts with optional filtering by platform, type, status, and search query -func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) { +func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) { var accounts []model.Account var total int64 @@ -131,7 +131,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params paginati }, nil } -func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) { +func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) { var accounts []model.Account err := r.db.WithContext(ctx). Joins("JOIN account_groups ON account_groups.account_id = accounts.id"). @@ -142,7 +142,7 @@ func (r *AccountRepository) ListByGroup(ctx context.Context, groupID int64) ([]m return accounts, err } -func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, error) { +func (r *accountRepository) ListActive(ctx context.Context) ([]model.Account, error) { var accounts []model.Account err := r.db.WithContext(ctx). Where("status = ?", model.StatusActive). @@ -152,12 +152,12 @@ func (r *AccountRepository) ListActive(ctx context.Context) ([]model.Account, er return accounts, err } -func (r *AccountRepository) UpdateLastUsed(ctx context.Context, id int64) error { +func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error { now := time.Now() return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).Update("last_used_at", now).Error } -func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg string) error { +func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error { return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). Updates(map[string]any{ "status": model.StatusError, @@ -165,7 +165,7 @@ func (r *AccountRepository) SetError(ctx context.Context, id int64, errorMsg str }).Error } -func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { +func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { ag := &model.AccountGroup{ AccountID: accountID, GroupID: groupID, @@ -174,12 +174,12 @@ func (r *AccountRepository) AddToGroup(ctx context.Context, accountID, groupID i return r.db.WithContext(ctx).Create(ag).Error } -func (r *AccountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error { +func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error { return r.db.WithContext(ctx).Where("account_id = ? AND group_id = ?", accountID, groupID). Delete(&model.AccountGroup{}).Error } -func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) { +func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]model.Group, error) { var groups []model.Group err := r.db.WithContext(ctx). Joins("JOIN account_groups ON account_groups.group_id = groups.id"). @@ -188,7 +188,7 @@ func (r *AccountRepository) GetGroups(ctx context.Context, accountID int64) ([]m return groups, err } -func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) { +func (r *accountRepository) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) { var accounts []model.Account err := r.db.WithContext(ctx). Where("platform = ? AND status = ?", platform, model.StatusActive). @@ -198,7 +198,7 @@ func (r *AccountRepository) ListByPlatform(ctx context.Context, platform string) return accounts, err } -func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { +func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { // 删除现有绑定 if err := r.db.WithContext(ctx).Where("account_id = ?", accountID).Delete(&model.AccountGroup{}).Error; err != nil { return err @@ -221,7 +221,7 @@ func (r *AccountRepository) BindGroups(ctx context.Context, accountID int64, gro } // ListSchedulable 获取所有可调度的账号 -func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) { +func (r *accountRepository) ListSchedulable(ctx context.Context) ([]model.Account, error) { var accounts []model.Account now := time.Now() err := r.db.WithContext(ctx). @@ -235,7 +235,7 @@ func (r *AccountRepository) ListSchedulable(ctx context.Context) ([]model.Accoun } // ListSchedulableByGroupID 按组获取可调度的账号 -func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) { +func (r *accountRepository) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) { var accounts []model.Account now := time.Now() err := r.db.WithContext(ctx). @@ -251,7 +251,7 @@ func (r *AccountRepository) ListSchedulableByGroupID(ctx context.Context, groupI } // ListSchedulableByPlatform 按平台获取可调度的账号 -func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) { +func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error) { var accounts []model.Account now := time.Now() err := r.db.WithContext(ctx). @@ -266,7 +266,7 @@ func (r *AccountRepository) ListSchedulableByPlatform(ctx context.Context, platf } // ListSchedulableByGroupIDAndPlatform 按组和平台获取可调度的账号 -func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) { +func (r *accountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error) { var accounts []model.Account now := time.Now() err := r.db.WithContext(ctx). @@ -283,7 +283,7 @@ func (r *AccountRepository) ListSchedulableByGroupIDAndPlatform(ctx context.Cont } // SetRateLimited 标记账号为限流状态(429) -func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { +func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { now := time.Now() return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). Updates(map[string]any{ @@ -293,13 +293,13 @@ func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetA } // SetOverloaded 标记账号为过载状态(529) -func (r *AccountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error { +func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error { return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). Update("overload_until", until).Error } // ClearRateLimit 清除账号的限流状态 -func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error { +func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). Updates(map[string]any{ "rate_limited_at": nil, @@ -309,7 +309,7 @@ func (r *AccountRepository) ClearRateLimit(ctx context.Context, id int64) error } // UpdateSessionWindow 更新账号的5小时时间窗口信息 -func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { +func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { updates := map[string]any{ "session_window_status": status, } @@ -323,14 +323,14 @@ func (r *AccountRepository) UpdateSessionWindow(ctx context.Context, id int64, s } // SetSchedulable 设置账号的调度开关 -func (r *AccountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { +func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). Update("schedulable", schedulable).Error } // UpdateExtra updates specific fields in account's Extra JSONB field // It merges the updates into existing Extra data without overwriting other fields -func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { +func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { if len(updates) == 0 { return nil } @@ -358,7 +358,7 @@ func (r *AccountRepository) UpdateExtra(ctx context.Context, id int64, updates m // BulkUpdate updates multiple accounts with the provided fields. // It merges credentials/extra JSONB fields instead of overwriting them. -func (r *AccountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { +func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { if len(ids) == 0 { return 0, nil } diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index 76353d72..2f52ac5c 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -18,13 +18,13 @@ type AccountRepoSuite struct { suite.Suite ctx context.Context db *gorm.DB - repo *AccountRepository + repo *accountRepository } func (s *AccountRepoSuite) SetupTest() { s.ctx = context.Background() s.db = testTx(s.T()) - s.repo = NewAccountRepository(s.db) + s.repo = NewAccountRepository(s.db).(*accountRepository) } func TestAccountRepoSuite(t *testing.T) { @@ -167,7 +167,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { s.Run(tt.name, func() { // 每个 case 重新获取隔离资源 db := testTx(s.T()) - repo := NewAccountRepository(db) + repo := NewAccountRepository(db).(*accountRepository) ctx := context.Background() tt.setup(db) diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 73ab0b3a..e89fee75 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -2,51 +2,55 @@ package repository import ( "context" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "gorm.io/gorm" ) -type ApiKeyRepository struct { +type apiKeyRepository struct { db *gorm.DB } -func NewApiKeyRepository(db *gorm.DB) *ApiKeyRepository { - return &ApiKeyRepository{db: db} +func NewApiKeyRepository(db *gorm.DB) service.ApiKeyRepository { + return &apiKeyRepository{db: db} } -func (r *ApiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error { - return r.db.WithContext(ctx).Create(key).Error +func (r *apiKeyRepository) Create(ctx context.Context, key *model.ApiKey) error { + err := r.db.WithContext(ctx).Create(key).Error + return translatePersistenceError(err, nil, service.ErrApiKeyExists) } -func (r *ApiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { +func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { var key model.ApiKey err := r.db.WithContext(ctx).Preload("User").Preload("Group").First(&key, id).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil) } return &key, nil } -func (r *ApiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) { +func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) { var apiKey model.ApiKey err := r.db.WithContext(ctx).Preload("User").Preload("Group").Where("key = ?", key).First(&apiKey).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrApiKeyNotFound, nil) } return &apiKey, nil } -func (r *ApiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error { +func (r *apiKeyRepository) Update(ctx context.Context, key *model.ApiKey) error { return r.db.WithContext(ctx).Model(key).Select("name", "group_id", "status", "updated_at").Updates(key).Error } -func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error { +func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error } -func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { +func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { var keys []model.ApiKey var total int64 @@ -73,19 +77,19 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param }, nil } -func (r *ApiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { +func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { var count int64 err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("user_id = ?", userID).Count(&count).Error return count, err } -func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) { +func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) { var count int64 err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("key = ?", key).Count(&count).Error return count > 0, err } -func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { +func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { var keys []model.ApiKey var total int64 @@ -113,7 +117,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par } // SearchApiKeys searches API keys by user ID and/or keyword (name) -func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) { +func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) { var keys []model.ApiKey db := r.db.WithContext(ctx).Model(&model.ApiKey{}) @@ -135,7 +139,7 @@ func (r *ApiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyw } // ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil -func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { +func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { result := r.db.WithContext(ctx).Model(&model.ApiKey{}). Where("group_id = ?", groupID). Update("group_id", nil) @@ -143,7 +147,7 @@ func (r *ApiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID in } // CountByGroupID 获取分组的 API Key 数量 -func (r *ApiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { +func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { var count int64 err := r.db.WithContext(ctx).Model(&model.ApiKey{}).Where("group_id = ?", groupID).Count(&count).Error return count, err diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 00b332f9..7a599ede 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -16,13 +16,13 @@ type ApiKeyRepoSuite struct { suite.Suite ctx context.Context db *gorm.DB - repo *ApiKeyRepository + repo *apiKeyRepository } func (s *ApiKeyRepoSuite) SetupTest() { s.ctx = context.Background() s.db = testTx(s.T()) - s.repo = NewApiKeyRepository(s.db) + s.repo = NewApiKeyRepository(s.db).(*apiKeyRepository) } func TestApiKeyRepoSuite(t *testing.T) { diff --git a/backend/internal/repository/error_translate.go b/backend/internal/repository/error_translate.go new file mode 100644 index 00000000..c70af510 --- /dev/null +++ b/backend/internal/repository/error_translate.go @@ -0,0 +1,40 @@ +package repository + +import ( + "errors" + "strings" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + "gorm.io/gorm" +) + +func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error { + if err == nil { + return nil + } + + if notFound != nil && errors.Is(err, gorm.ErrRecordNotFound) { + return notFound.WithCause(err) + } + + if conflict != nil && isUniqueConstraintViolation(err) { + return conflict.WithCause(err) + } + + return err +} + +func isUniqueConstraintViolation(err error) bool { + if err == nil { + return false + } + + if errors.Is(err, gorm.ErrDuplicatedKey) { + return true + } + + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "duplicate key") || + strings.Contains(msg, "unique constraint") || + strings.Contains(msg, "duplicate entry") +} diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 27647329..a2cb8e14 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -2,47 +2,52 @@ package repository import ( "context" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "gorm.io/gorm" + "gorm.io/gorm/clause" ) -type GroupRepository struct { +type groupRepository struct { db *gorm.DB } -func NewGroupRepository(db *gorm.DB) *GroupRepository { - return &GroupRepository{db: db} +func NewGroupRepository(db *gorm.DB) service.GroupRepository { + return &groupRepository{db: db} } -func (r *GroupRepository) Create(ctx context.Context, group *model.Group) error { - return r.db.WithContext(ctx).Create(group).Error +func (r *groupRepository) Create(ctx context.Context, group *model.Group) error { + err := r.db.WithContext(ctx).Create(group).Error + return translatePersistenceError(err, nil, service.ErrGroupExists) } -func (r *GroupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) { +func (r *groupRepository) GetByID(ctx context.Context, id int64) (*model.Group, error) { var group model.Group err := r.db.WithContext(ctx).First(&group, id).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil) } return &group, nil } -func (r *GroupRepository) Update(ctx context.Context, group *model.Group) error { +func (r *groupRepository) Update(ctx context.Context, group *model.Group) error { return r.db.WithContext(ctx).Save(group).Error } -func (r *GroupRepository) Delete(ctx context.Context, id int64) error { +func (r *groupRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error } -func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { +func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", nil) } // ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive -func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) { +func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) { var groups []model.Group var total int64 @@ -86,7 +91,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination }, nil } -func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error) { +func (r *groupRepository) ListActive(ctx context.Context) ([]model.Group, error) { var groups []model.Group err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Order("id ASC").Find(&groups).Error if err != nil { @@ -100,7 +105,7 @@ func (r *GroupRepository) ListActive(ctx context.Context) ([]model.Group, error) return groups, nil } -func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) { +func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) { var groups []model.Group err := r.db.WithContext(ctx).Where("status = ? AND platform = ?", model.StatusActive, platform).Order("id ASC").Find(&groups).Error if err != nil { @@ -114,25 +119,80 @@ func (r *GroupRepository) ListActiveByPlatform(ctx context.Context, platform str return groups, nil } -func (r *GroupRepository) ExistsByName(ctx context.Context, name string) (bool, error) { +func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) { var count int64 err := r.db.WithContext(ctx).Model(&model.Group{}).Where("name = ?", name).Count(&count).Error return count > 0, err } -func (r *GroupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { +func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { var count int64 err := r.db.WithContext(ctx).Model(&model.AccountGroup{}).Where("group_id = ?", groupID).Count(&count).Error return count, err } // DeleteAccountGroupsByGroupID 删除分组与账号的关联关系 -func (r *GroupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { +func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.AccountGroup{}) return result.RowsAffected, result.Error } -// DB 返回底层数据库连接,用于事务处理 -func (r *GroupRepository) DB() *gorm.DB { - return r.db +func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, error) { + group, err := r.GetByID(ctx, id) + if err != nil { + return nil, err + } + + var affectedUserIDs []int64 + if group.IsSubscriptionType() { + var subscriptions []model.UserSubscription + if err := r.db.WithContext(ctx). + Model(&model.UserSubscription{}). + Where("group_id = ?", id). + Select("user_id"). + Find(&subscriptions).Error; err != nil { + return nil, err + } + for _, sub := range subscriptions { + affectedUserIDs = append(affectedUserIDs, sub.UserID) + } + } + + err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // 1. 删除订阅类型分组的订阅记录 + if group.IsSubscriptionType() { + if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil { + return err + } + } + + // 2. 将 api_keys 中绑定该分组的 group_id 设为 nil + if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil { + return err + } + + // 3. 从 users.allowed_groups 数组中移除该分组 ID + if err := tx.Model(&model.User{}). + Where("? = ANY(allowed_groups)", id). + Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil { + return err + } + + // 4. 删除 account_groups 中间表的数据 + if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil { + return err + } + + // 5. 删除分组本身(带锁,避免并发写) + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Delete(&model.Group{}, id).Error; err != nil { + return err + } + + return nil + }) + if err != nil { + return nil, err + } + + return affectedUserIDs, nil } diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index e4464657..85fd27b2 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -16,13 +16,13 @@ type GroupRepoSuite struct { suite.Suite ctx context.Context db *gorm.DB - repo *GroupRepository + repo *groupRepository } func (s *GroupRepoSuite) SetupTest() { s.ctx = context.Background() s.db = testTx(s.T()) - s.repo = NewGroupRepository(s.db) + s.repo = NewGroupRepository(s.db).(*groupRepository) } func TestGroupRepoSuite(t *testing.T) { @@ -234,11 +234,3 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { count, _ := s.repo.GetAccountCount(s.ctx, g.ID) s.Require().Zero(count) } - -// --- DB --- - -func (s *GroupRepoSuite) TestDB() { - db := s.repo.DB() - s.Require().NotNil(db, "DB should return non-nil") - s.Require().Equal(s.db, db, "DB should return the underlying gorm.DB") -} diff --git a/backend/internal/repository/proxy_repo.go b/backend/internal/repository/proxy_repo.go index b46b201c..590c6a61 100644 --- a/backend/internal/repository/proxy_repo.go +++ b/backend/internal/repository/proxy_repo.go @@ -2,47 +2,50 @@ package repository import ( "context" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "gorm.io/gorm" ) -type ProxyRepository struct { +type proxyRepository struct { db *gorm.DB } -func NewProxyRepository(db *gorm.DB) *ProxyRepository { - return &ProxyRepository{db: db} +func NewProxyRepository(db *gorm.DB) service.ProxyRepository { + return &proxyRepository{db: db} } -func (r *ProxyRepository) Create(ctx context.Context, proxy *model.Proxy) error { +func (r *proxyRepository) Create(ctx context.Context, proxy *model.Proxy) error { return r.db.WithContext(ctx).Create(proxy).Error } -func (r *ProxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { +func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { var proxy model.Proxy err := r.db.WithContext(ctx).First(&proxy, id).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrProxyNotFound, nil) } return &proxy, nil } -func (r *ProxyRepository) Update(ctx context.Context, proxy *model.Proxy) error { +func (r *proxyRepository) Update(ctx context.Context, proxy *model.Proxy) error { return r.db.WithContext(ctx).Save(proxy).Error } -func (r *ProxyRepository) Delete(ctx context.Context, id int64) error { +func (r *proxyRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error } -func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { +func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", "") } // ListWithFilters lists proxies with optional filtering by protocol, status, and search query -func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) { +func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) { var proxies []model.Proxy var total int64 @@ -81,14 +84,14 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination }, nil } -func (r *ProxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) { +func (r *proxyRepository) ListActive(ctx context.Context) ([]model.Proxy, error) { var proxies []model.Proxy err := r.db.WithContext(ctx).Where("status = ?", model.StatusActive).Find(&proxies).Error return proxies, err } // ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists -func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { +func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { var count int64 err := r.db.WithContext(ctx).Model(&model.Proxy{}). Where("host = ? AND port = ? AND username = ? AND password = ?", host, port, username, password). @@ -100,7 +103,7 @@ func (r *ProxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, } // CountAccountsByProxyID returns the number of accounts using a specific proxy -func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { +func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { var count int64 err := r.db.WithContext(ctx).Model(&model.Account{}). Where("proxy_id = ?", proxyID). @@ -109,7 +112,7 @@ func (r *ProxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID in } // GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies -func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) { +func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) { type result struct { ProxyID int64 `gorm:"column:proxy_id"` Count int64 `gorm:"column:count"` @@ -133,7 +136,7 @@ func (r *ProxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[i } // ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending -func (r *ProxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) { +func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) { var proxies []model.Proxy err := r.db.WithContext(ctx). Where("status = ?", model.StatusActive). diff --git a/backend/internal/repository/proxy_repo_integration_test.go b/backend/internal/repository/proxy_repo_integration_test.go index 67c1825f..9e773398 100644 --- a/backend/internal/repository/proxy_repo_integration_test.go +++ b/backend/internal/repository/proxy_repo_integration_test.go @@ -17,13 +17,13 @@ type ProxyRepoSuite struct { suite.Suite ctx context.Context db *gorm.DB - repo *ProxyRepository + repo *proxyRepository } func (s *ProxyRepoSuite) SetupTest() { s.ctx = context.Background() s.db = testTx(s.T()) - s.repo = NewProxyRepository(s.db) + s.repo = NewProxyRepository(s.db).(*proxyRepository) } func TestProxyRepoSuite(t *testing.T) { diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go index 98174cde..aa6e7010 100644 --- a/backend/internal/repository/redeem_code_repo.go +++ b/backend/internal/repository/redeem_code_repo.go @@ -2,57 +2,60 @@ package repository import ( "context" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "time" "gorm.io/gorm" ) -type RedeemCodeRepository struct { +type redeemCodeRepository struct { db *gorm.DB } -func NewRedeemCodeRepository(db *gorm.DB) *RedeemCodeRepository { - return &RedeemCodeRepository{db: db} +func NewRedeemCodeRepository(db *gorm.DB) service.RedeemCodeRepository { + return &redeemCodeRepository{db: db} } -func (r *RedeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error { +func (r *redeemCodeRepository) Create(ctx context.Context, code *model.RedeemCode) error { return r.db.WithContext(ctx).Create(code).Error } -func (r *RedeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error { +func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []model.RedeemCode) error { return r.db.WithContext(ctx).Create(&codes).Error } -func (r *RedeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) { +func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) { var code model.RedeemCode err := r.db.WithContext(ctx).First(&code, id).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil) } return &code, nil } -func (r *RedeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) { +func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) { var redeemCode model.RedeemCode err := r.db.WithContext(ctx).Where("code = ?", code).First(&redeemCode).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrRedeemCodeNotFound, nil) } return &redeemCode, nil } -func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error { +func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error } -func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) { +func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", "") } // ListWithFilters lists redeem codes with optional filtering by type, status, and search query -func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) { +func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) { var codes []model.RedeemCode var total int64 @@ -91,11 +94,11 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagin }, nil } -func (r *RedeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error { +func (r *redeemCodeRepository) Update(ctx context.Context, code *model.RedeemCode) error { return r.db.WithContext(ctx).Save(code).Error } -func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error { +func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error { now := time.Now() result := r.db.WithContext(ctx).Model(&model.RedeemCode{}). Where("id = ? AND status = ?", id, model.StatusUnused). @@ -108,13 +111,13 @@ func (r *RedeemCodeRepository) Use(ctx context.Context, id, userID int64) error return result.Error } if result.RowsAffected == 0 { - return gorm.ErrRecordNotFound // 兑换码不存在或已被使用 + return service.ErrRedeemCodeUsed.WithCause(gorm.ErrRecordNotFound) } return nil } // ListByUser returns all redeem codes used by a specific user -func (r *RedeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) { +func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) { var codes []model.RedeemCode if limit <= 0 { limit = 10 diff --git a/backend/internal/repository/redeem_code_repo_integration_test.go b/backend/internal/repository/redeem_code_repo_integration_test.go index f39d6a51..5151f7e2 100644 --- a/backend/internal/repository/redeem_code_repo_integration_test.go +++ b/backend/internal/repository/redeem_code_repo_integration_test.go @@ -9,6 +9,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" "gorm.io/gorm" ) @@ -17,13 +18,13 @@ type RedeemCodeRepoSuite struct { suite.Suite ctx context.Context db *gorm.DB - repo *RedeemCodeRepository + repo *redeemCodeRepository } func (s *RedeemCodeRepoSuite) SetupTest() { s.ctx = context.Background() s.db = testTx(s.T()) - s.repo = NewRedeemCodeRepository(s.db) + s.repo = NewRedeemCodeRepository(s.db).(*redeemCodeRepository) } func TestRedeemCodeRepoSuite(t *testing.T) { @@ -195,7 +196,7 @@ func (s *RedeemCodeRepoSuite) TestUse_Idempotency() { // Second use should fail err = s.repo.Use(s.ctx, code.ID, user.ID) s.Require().Error(err, "Use expected error on second call") - s.Require().ErrorIs(err, gorm.ErrRecordNotFound) + s.Require().ErrorIs(err, service.ErrRedeemCodeUsed) } func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { @@ -204,7 +205,7 @@ func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() { err := s.repo.Use(s.ctx, code.ID, user.ID) s.Require().Error(err, "expected error for already used code") - s.Require().ErrorIs(err, gorm.ErrRecordNotFound) + s.Require().ErrorIs(err, service.ErrRedeemCodeUsed) } // --- ListByUser --- @@ -298,7 +299,7 @@ func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use") err = s.repo.Use(s.ctx, codeB.ID, user.ID) s.Require().Error(err, "Use expected error on second call") - s.Require().ErrorIs(err, gorm.ErrRecordNotFound) + s.Require().ErrorIs(err, service.ErrRedeemCodeUsed) codeA, err := s.repo.GetByCode(s.ctx, "CODEA") s.Require().NoError(err, "GetByCode") diff --git a/backend/internal/repository/repository.go b/backend/internal/repository/repository.go index 82cc46f0..b76c0d82 100644 --- a/backend/internal/repository/repository.go +++ b/backend/internal/repository/repository.go @@ -1,14 +1,16 @@ package repository +import "github.com/Wei-Shaw/sub2api/internal/service" + // Repositories 所有仓库的集合 type Repositories struct { - User *UserRepository - ApiKey *ApiKeyRepository - Group *GroupRepository - Account *AccountRepository - Proxy *ProxyRepository - RedeemCode *RedeemCodeRepository - UsageLog *UsageLogRepository - Setting *SettingRepository - UserSubscription *UserSubscriptionRepository + User service.UserRepository + ApiKey service.ApiKeyRepository + Group service.GroupRepository + Account service.AccountRepository + Proxy service.ProxyRepository + RedeemCode service.RedeemCodeRepository + UsageLog service.UsageLogRepository + Setting service.SettingRepository + UserSubscription service.UserSubscriptionRepository } diff --git a/backend/internal/repository/setting_repo.go b/backend/internal/repository/setting_repo.go index 165386c1..43dd65d4 100644 --- a/backend/internal/repository/setting_repo.go +++ b/backend/internal/repository/setting_repo.go @@ -2,35 +2,38 @@ package repository import ( "context" - "github.com/Wei-Shaw/sub2api/internal/model" "time" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/Wei-Shaw/sub2api/internal/model" + "gorm.io/gorm" "gorm.io/gorm/clause" ) // SettingRepository 系统设置数据访问层 -type SettingRepository struct { +type settingRepository struct { db *gorm.DB } // NewSettingRepository 创建系统设置仓库实例 -func NewSettingRepository(db *gorm.DB) *SettingRepository { - return &SettingRepository{db: db} +func NewSettingRepository(db *gorm.DB) service.SettingRepository { + return &settingRepository{db: db} } // Get 根据Key获取设置值 -func (r *SettingRepository) Get(ctx context.Context, key string) (*model.Setting, error) { +func (r *settingRepository) Get(ctx context.Context, key string) (*model.Setting, error) { var setting model.Setting err := r.db.WithContext(ctx).Where("key = ?", key).First(&setting).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrSettingNotFound, nil) } return &setting, nil } // GetValue 获取设置值字符串 -func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, error) { +func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) { setting, err := r.Get(ctx, key) if err != nil { return "", err @@ -39,7 +42,7 @@ func (r *SettingRepository) GetValue(ctx context.Context, key string) (string, e } // Set 设置值(存在则更新,不存在则创建) -func (r *SettingRepository) Set(ctx context.Context, key, value string) error { +func (r *settingRepository) Set(ctx context.Context, key, value string) error { setting := &model.Setting{ Key: key, Value: value, @@ -53,7 +56,7 @@ func (r *SettingRepository) Set(ctx context.Context, key, value string) error { } // GetMultiple 批量获取设置 -func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { +func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { var settings []model.Setting err := r.db.WithContext(ctx).Where("key IN ?", keys).Find(&settings).Error if err != nil { @@ -68,7 +71,7 @@ func (r *SettingRepository) GetMultiple(ctx context.Context, keys []string) (map } // SetMultiple 批量设置值 -func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string]string) error { +func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error { return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { for key, value := range settings { setting := &model.Setting{ @@ -88,7 +91,7 @@ func (r *SettingRepository) SetMultiple(ctx context.Context, settings map[string } // GetAll 获取所有设置 -func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, error) { +func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) { var settings []model.Setting err := r.db.WithContext(ctx).Find(&settings).Error if err != nil { @@ -103,6 +106,6 @@ func (r *SettingRepository) GetAll(ctx context.Context) (map[string]string, erro } // Delete 删除设置 -func (r *SettingRepository) Delete(ctx context.Context, key string) error { +func (r *settingRepository) Delete(ctx context.Context, key string) error { return r.db.WithContext(ctx).Where("key = ?", key).Delete(&model.Setting{}).Error } diff --git a/backend/internal/repository/setting_repo_integration_test.go b/backend/internal/repository/setting_repo_integration_test.go index b42cacd7..e637942e 100644 --- a/backend/internal/repository/setting_repo_integration_test.go +++ b/backend/internal/repository/setting_repo_integration_test.go @@ -6,6 +6,7 @@ import ( "context" "testing" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" "gorm.io/gorm" ) @@ -14,13 +15,13 @@ type SettingRepoSuite struct { suite.Suite ctx context.Context db *gorm.DB - repo *SettingRepository + repo *settingRepository } func (s *SettingRepoSuite) SetupTest() { s.ctx = context.Background() s.db = testTx(s.T()) - s.repo = NewSettingRepository(s.db) + s.repo = NewSettingRepository(s.db).(*settingRepository) } func TestSettingRepoSuite(t *testing.T) { @@ -45,7 +46,7 @@ func (s *SettingRepoSuite) TestSet_Upsert() { func (s *SettingRepoSuite) TestGetValue_Missing() { _, err := s.repo.GetValue(s.ctx, "nonexistent") s.Require().Error(err, "expected error for missing key") - s.Require().ErrorIs(err, gorm.ErrRecordNotFound) + s.Require().ErrorIs(err, service.ErrSettingNotFound) } func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() { @@ -86,7 +87,7 @@ func (s *SettingRepoSuite) TestDelete() { s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete") _, err := s.repo.GetValue(s.ctx, "todelete") s.Require().Error(err, "expected missing key error after Delete") - s.Require().ErrorIs(err, gorm.ErrRecordNotFound) + s.Require().ErrorIs(err, service.ErrSettingNotFound) } func (s *SettingRepoSuite) TestDelete_Idempotent() { diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index eeee8679..038719ae 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -2,25 +2,28 @@ package repository import ( "context" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" - "time" "gorm.io/gorm" ) -type UsageLogRepository struct { +type usageLogRepository struct { db *gorm.DB } -func NewUsageLogRepository(db *gorm.DB) *UsageLogRepository { - return &UsageLogRepository{db: db} +func NewUsageLogRepository(db *gorm.DB) service.UsageLogRepository { + return &usageLogRepository{db: db} } // getPerformanceStats 获取 RPM 和 TPM(近5分钟平均值,可选按用户过滤) -func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) { +func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int64) (rpm, tpm int64) { fiveMinutesAgo := time.Now().Add(-5 * time.Minute) var perfStats struct { RequestCount int64 `gorm:"column:request_count"` @@ -43,20 +46,20 @@ func (r *UsageLogRepository) getPerformanceStats(ctx context.Context, userID int return perfStats.RequestCount / 5, perfStats.TokenCount / 5 } -func (r *UsageLogRepository) Create(ctx context.Context, log *model.UsageLog) error { +func (r *usageLogRepository) Create(ctx context.Context, log *model.UsageLog) error { return r.db.WithContext(ctx).Create(log).Error } -func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) { +func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) { var log model.UsageLog err := r.db.WithContext(ctx).First(&log, id).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrUsageLogNotFound, nil) } return &log, nil } -func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { +func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog var total int64 @@ -83,7 +86,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param }, nil } -func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { +func (r *usageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog var total int64 @@ -120,7 +123,7 @@ type UserStats struct { CacheReadTokens int64 `json:"cache_read_tokens"` } -func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) { +func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) { var stats UserStats err := r.db.WithContext(ctx).Model(&model.UsageLog{}). Select(` @@ -139,7 +142,7 @@ func (r *UsageLogRepository) GetUserStats(ctx context.Context, userID int64, sta // DashboardStats 仪表盘统计 type DashboardStats = usagestats.DashboardStats -func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) { +func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) { var stats DashboardStats today := timezone.Today() @@ -260,7 +263,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS return &stats, nil } -func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { +func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog var total int64 @@ -287,7 +290,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, }, nil } -func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { +func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog err := r.db.WithContext(ctx). Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime). @@ -296,7 +299,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID return logs, nil, err } -func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { +func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog err := r.db.WithContext(ctx). Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime). @@ -305,7 +308,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe return logs, nil, err } -func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { +func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog err := r.db.WithContext(ctx). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). @@ -314,7 +317,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco return logs, nil, err } -func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { +func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog err := r.db.WithContext(ctx). Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime). @@ -323,12 +326,12 @@ func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelN return logs, nil, err } -func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error { +func (r *usageLogRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error } // GetAccountTodayStats 获取账号今日统计 -func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) { +func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) { today := timezone.Today() var stats struct { @@ -358,7 +361,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID } // GetAccountWindowStats 获取账号时间窗口内的统计 -func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { +func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { var stats struct { Requests int64 `gorm:"column:requests"` Tokens int64 `gorm:"column:tokens"` @@ -398,7 +401,7 @@ type UserUsageTrendPoint = usagestats.UserUsageTrendPoint type ApiKeyUsageTrendPoint = usagestats.ApiKeyUsageTrendPoint // GetApiKeyUsageTrend returns usage trend data grouped by API key and date -func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) { +func (r *usageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]ApiKeyUsageTrendPoint, error) { var results []ApiKeyUsageTrendPoint // Choose date format based on granularity @@ -442,7 +445,7 @@ func (r *UsageLogRepository) GetApiKeyUsageTrend(ctx context.Context, startTime, } // GetUserUsageTrend returns usage trend data grouped by user and date -func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) { +func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]UserUsageTrendPoint, error) { var results []UserUsageTrendPoint // Choose date format based on granularity @@ -491,7 +494,7 @@ func (r *UsageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, e type UserDashboardStats = usagestats.UserDashboardStats // GetUserDashboardStats 获取用户专属的仪表盘统计 -func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) { +func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) { var stats UserDashboardStats today := timezone.Today() @@ -578,7 +581,7 @@ func (r *UsageLogRepository) GetUserDashboardStats(ctx context.Context, userID i } // GetUserUsageTrendByUserID 获取指定用户的使用趋势 -func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) { +func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]TrendDataPoint, error) { var results []TrendDataPoint var dateFormat string @@ -612,7 +615,7 @@ func (r *UsageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, user } // GetUserModelStats 获取指定用户的模型统计 -func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) { +func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]ModelStat, error) { var results []ModelStat err := r.db.WithContext(ctx).Model(&model.UsageLog{}). @@ -641,7 +644,7 @@ func (r *UsageLogRepository) GetUserModelStats(ctx context.Context, userID int64 type UsageLogFilters = usagestats.UsageLogFilters // ListWithFilters lists usage logs with optional filters (for admin) -func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) { +func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog var total int64 @@ -692,7 +695,7 @@ type UsageStats = usagestats.UsageStats type BatchUserUsageStats = usagestats.BatchUserUsageStats // GetBatchUserUsageStats gets today and total actual_cost for multiple users -func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) { +func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) { if len(userIDs) == 0 { return make(map[int64]*BatchUserUsageStats), nil } @@ -752,7 +755,7 @@ func (r *UsageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs type BatchApiKeyUsageStats = usagestats.BatchApiKeyUsageStats // GetBatchApiKeyUsageStats gets today and total actual_cost for multiple API keys -func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) { +func (r *usageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchApiKeyUsageStats, error) { if len(apiKeyIDs) == 0 { return make(map[int64]*BatchApiKeyUsageStats), nil } @@ -809,7 +812,7 @@ func (r *UsageLogRepository) GetBatchApiKeyUsageStats(ctx context.Context, apiKe } // GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters -func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) { +func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]TrendDataPoint, error) { var results []TrendDataPoint var dateFormat string @@ -848,7 +851,7 @@ func (r *UsageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start } // GetModelStatsWithFilters returns model statistics with optional user/api_key filters -func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) { +func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]ModelStat, error) { var results []ModelStat db := r.db.WithContext(ctx).Model(&model.UsageLog{}). @@ -882,7 +885,7 @@ func (r *UsageLogRepository) GetModelStatsWithFilters(ctx context.Context, start } // GetGlobalStats gets usage statistics for all users within a time range -func (r *UsageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) { +func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) { var stats struct { TotalRequests int64 `gorm:"column:total_requests"` TotalInputTokens int64 `gorm:"column:total_input_tokens"` @@ -932,7 +935,7 @@ type AccountUsageSummary = usagestats.AccountUsageSummary type AccountUsageStatsResponse = usagestats.AccountUsageStatsResponse // GetAccountUsageStats returns comprehensive usage statistics for an account over a time range -func (r *UsageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) { +func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*AccountUsageStatsResponse, error) { daysCount := int(endTime.Sub(startTime).Hours()/24) + 1 if daysCount <= 0 { daysCount = 30 diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 76265d31..6423de71 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -19,13 +19,13 @@ type UsageLogRepoSuite struct { suite.Suite ctx context.Context db *gorm.DB - repo *UsageLogRepository + repo *usageLogRepository } func (s *UsageLogRepoSuite) SetupTest() { s.ctx = context.Background() s.db = testTx(s.T()) - s.repo = NewUsageLogRepository(s.db) + s.repo = NewUsageLogRepository(s.db).(*usageLogRepository) } func TestUsageLogRepoSuite(t *testing.T) { diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 46388e8e..c87e3838 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -2,56 +2,61 @@ package repository import ( "context" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "gorm.io/gorm" ) -type UserRepository struct { +type userRepository struct { db *gorm.DB } -func NewUserRepository(db *gorm.DB) *UserRepository { - return &UserRepository{db: db} +func NewUserRepository(db *gorm.DB) service.UserRepository { + return &userRepository{db: db} } -func (r *UserRepository) Create(ctx context.Context, user *model.User) error { - return r.db.WithContext(ctx).Create(user).Error +func (r *userRepository) Create(ctx context.Context, user *model.User) error { + err := r.db.WithContext(ctx).Create(user).Error + return translatePersistenceError(err, nil, service.ErrEmailExists) } -func (r *UserRepository) GetByID(ctx context.Context, id int64) (*model.User, error) { +func (r *userRepository) GetByID(ctx context.Context, id int64) (*model.User, error) { var user model.User err := r.db.WithContext(ctx).First(&user, id).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) } return &user, nil } -func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) { +func (r *userRepository) GetByEmail(ctx context.Context, email string) (*model.User, error) { var user model.User err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) } return &user, nil } -func (r *UserRepository) Update(ctx context.Context, user *model.User) error { - return r.db.WithContext(ctx).Save(user).Error +func (r *userRepository) Update(ctx context.Context, user *model.User) error { + err := r.db.WithContext(ctx).Save(user).Error + return translatePersistenceError(err, nil, service.ErrEmailExists) } -func (r *UserRepository) Delete(ctx context.Context, id int64) error { +func (r *userRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.User{}, id).Error } -func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) { +func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", "") } // ListWithFilters lists users with optional filtering by status, role, and search query -func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) { +func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) { var users []model.User var total int64 @@ -120,13 +125,13 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination. }, nil } -func (r *UserRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { +func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). Update("balance", gorm.Expr("balance + ?", amount)).Error } // DeductBalance 扣减用户余额,仅当余额充足时执行 -func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { +func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error { result := r.db.WithContext(ctx).Model(&model.User{}). Where("id = ? AND balance >= ?", id, amount). Update("balance", gorm.Expr("balance - ?", amount)) @@ -134,17 +139,17 @@ func (r *UserRepository) DeductBalance(ctx context.Context, id int64, amount flo return result.Error } if result.RowsAffected == 0 { - return gorm.ErrRecordNotFound // 余额不足或用户不存在 + return service.ErrInsufficientBalance } return nil } -func (r *UserRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error { +func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error { return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id). Update("concurrency", gorm.Expr("concurrency + ?", amount)).Error } -func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { +func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { var count int64 err := r.db.WithContext(ctx).Model(&model.User{}).Where("email = ?", email).Count(&count).Error return count > 0, err @@ -152,7 +157,7 @@ func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, // RemoveGroupFromAllowedGroups 从所有用户的 allowed_groups 数组中移除指定的分组ID // 使用 PostgreSQL 的 array_remove 函数 -func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { +func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { result := r.db.WithContext(ctx).Model(&model.User{}). Where("? = ANY(allowed_groups)", groupID). Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID)) @@ -160,14 +165,14 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group } // GetFirstAdmin 获取第一个管理员用户(用于 Admin API Key 认证) -func (r *UserRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) { +func (r *userRepository) GetFirstAdmin(ctx context.Context) (*model.User, error) { var user model.User err := r.db.WithContext(ctx). Where("role = ? AND status = ?", model.RoleAdmin, model.StatusActive). Order("id ASC"). First(&user).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) } return &user, nil } diff --git a/backend/internal/repository/user_repo_integration_test.go b/backend/internal/repository/user_repo_integration_test.go index 7efe2d5c..020e2e32 100644 --- a/backend/internal/repository/user_repo_integration_test.go +++ b/backend/internal/repository/user_repo_integration_test.go @@ -9,6 +9,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" "github.com/stretchr/testify/suite" "gorm.io/gorm" @@ -18,13 +19,13 @@ type UserRepoSuite struct { suite.Suite ctx context.Context db *gorm.DB - repo *UserRepository + repo *userRepository } func (s *UserRepoSuite) SetupTest() { s.ctx = context.Background() s.db = testTx(s.T()) - s.repo = NewUserRepository(s.db) + s.repo = NewUserRepository(s.db).(*userRepository) } func TestUserRepoSuite(t *testing.T) { @@ -247,7 +248,7 @@ func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() { err := s.repo.DeductBalance(s.ctx, user.ID, 999) s.Require().Error(err, "expected error for insufficient balance") - s.Require().ErrorIs(err, gorm.ErrRecordNotFound) + s.Require().ErrorIs(err, service.ErrInsufficientBalance) } func (s *UserRepoSuite) TestDeductBalance_ExactAmount() { @@ -432,7 +433,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() { err = s.repo.DeductBalance(s.ctx, user1.ID, 999) s.Require().Error(err, "DeductBalance expected error for insufficient balance") - s.Require().ErrorIs(err, gorm.ErrRecordNotFound, "DeductBalance unexpected error") + s.Require().ErrorIs(err, service.ErrInsufficientBalance, "DeductBalance unexpected error") s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency") got5, err := s.repo.GetByID(s.ctx, user1.ID) diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go index eaf9641a..7fea8fb0 100644 --- a/backend/internal/repository/user_subscription_repo.go +++ b/backend/internal/repository/user_subscription_repo.go @@ -6,27 +6,29 @@ import ( "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" "gorm.io/gorm" ) // UserSubscriptionRepository 用户订阅仓库 -type UserSubscriptionRepository struct { +type userSubscriptionRepository struct { db *gorm.DB } // NewUserSubscriptionRepository 创建用户订阅仓库 -func NewUserSubscriptionRepository(db *gorm.DB) *UserSubscriptionRepository { - return &UserSubscriptionRepository{db: db} +func NewUserSubscriptionRepository(db *gorm.DB) service.UserSubscriptionRepository { + return &userSubscriptionRepository{db: db} } // Create 创建订阅 -func (r *UserSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error { - return r.db.WithContext(ctx).Create(sub).Error +func (r *userSubscriptionRepository) Create(ctx context.Context, sub *model.UserSubscription) error { + err := r.db.WithContext(ctx).Create(sub).Error + return translatePersistenceError(err, nil, service.ErrSubscriptionAlreadyExists) } // GetByID 根据ID获取订阅 -func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) { +func (r *userSubscriptionRepository) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) { var sub model.UserSubscription err := r.db.WithContext(ctx). Preload("User"). @@ -34,26 +36,26 @@ func (r *UserSubscriptionRepository) GetByID(ctx context.Context, id int64) (*mo Preload("AssignedByUser"). First(&sub, id).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } return &sub, nil } // GetByUserIDAndGroupID 根据用户ID和分组ID获取订阅 -func (r *UserSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { +func (r *userSubscriptionRepository) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { var sub model.UserSubscription err := r.db.WithContext(ctx). Preload("Group"). Where("user_id = ? AND group_id = ?", userID, groupID). First(&sub).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } return &sub, nil } // GetActiveByUserIDAndGroupID 获取用户对特定分组的有效订阅 -func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { +func (r *userSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { var sub model.UserSubscription err := r.db.WithContext(ctx). Preload("Group"). @@ -61,24 +63,24 @@ func (r *UserSubscriptionRepository) GetActiveByUserIDAndGroupID(ctx context.Con userID, groupID, model.SubscriptionStatusActive, time.Now()). First(&sub).Error if err != nil { - return nil, err + return nil, translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } return &sub, nil } // Update 更新订阅 -func (r *UserSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error { +func (r *userSubscriptionRepository) Update(ctx context.Context, sub *model.UserSubscription) error { sub.UpdatedAt = time.Now() return r.db.WithContext(ctx).Save(sub).Error } // Delete 删除订阅 -func (r *UserSubscriptionRepository) Delete(ctx context.Context, id int64) error { +func (r *userSubscriptionRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.UserSubscription{}, id).Error } // ListByUserID 获取用户的所有订阅 -func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { +func (r *userSubscriptionRepository) ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { var subs []model.UserSubscription err := r.db.WithContext(ctx). Preload("Group"). @@ -89,7 +91,7 @@ func (r *UserSubscriptionRepository) ListByUserID(ctx context.Context, userID in } // ListActiveByUserID 获取用户的所有有效订阅 -func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { +func (r *userSubscriptionRepository) ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) { var subs []model.UserSubscription err := r.db.WithContext(ctx). Preload("Group"). @@ -101,7 +103,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use } // ListByGroupID 获取分组的所有订阅(分页) -func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) { +func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) { var subs []model.UserSubscription var total int64 @@ -136,7 +138,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID } // List 获取所有订阅(分页,支持筛选) -func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) { +func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) { var subs []model.UserSubscription var total int64 @@ -182,7 +184,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination } // IncrementUsage 增加使用量 -func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { +func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", id). Updates(map[string]any{ @@ -194,7 +196,7 @@ func (r *UserSubscriptionRepository) IncrementUsage(ctx context.Context, id int6 } // ResetDailyUsage 重置日使用量 -func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { +func (r *userSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", id). Updates(map[string]any{ @@ -205,7 +207,7 @@ func (r *UserSubscriptionRepository) ResetDailyUsage(ctx context.Context, id int } // ResetWeeklyUsage 重置周使用量 -func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { +func (r *userSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", id). Updates(map[string]any{ @@ -216,7 +218,7 @@ func (r *UserSubscriptionRepository) ResetWeeklyUsage(ctx context.Context, id in } // ResetMonthlyUsage 重置月使用量 -func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { +func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", id). Updates(map[string]any{ @@ -227,7 +229,7 @@ func (r *UserSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i } // ActivateWindows 激活所有窗口(首次使用时) -func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error { +func (r *userSubscriptionRepository) ActivateWindows(ctx context.Context, id int64, activateTime time.Time) error { return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", id). Updates(map[string]any{ @@ -239,7 +241,7 @@ func (r *UserSubscriptionRepository) ActivateWindows(ctx context.Context, id int } // UpdateStatus 更新订阅状态 -func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error { +func (r *userSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, status string) error { return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", id). Updates(map[string]any{ @@ -249,7 +251,7 @@ func (r *UserSubscriptionRepository) UpdateStatus(ctx context.Context, id int64, } // ExtendExpiry 延长订阅过期时间 -func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error { +func (r *userSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, newExpiresAt time.Time) error { return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", id). Updates(map[string]any{ @@ -259,7 +261,7 @@ func (r *UserSubscriptionRepository) ExtendExpiry(ctx context.Context, id int64, } // UpdateNotes 更新订阅备注 -func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error { +func (r *userSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, notes string) error { return r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("id = ?", id). Updates(map[string]any{ @@ -269,7 +271,7 @@ func (r *UserSubscriptionRepository) UpdateNotes(ctx context.Context, id int64, } // ListExpired 获取所有已过期但状态仍为active的订阅 -func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) { +func (r *userSubscriptionRepository) ListExpired(ctx context.Context) ([]model.UserSubscription, error) { var subs []model.UserSubscription err := r.db.WithContext(ctx). Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). @@ -278,7 +280,7 @@ func (r *UserSubscriptionRepository) ListExpired(ctx context.Context) ([]model.U } // BatchUpdateExpiredStatus 批量更新过期订阅状态 -func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { +func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { result := r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("status = ? AND expires_at <= ?", model.SubscriptionStatusActive, time.Now()). Updates(map[string]any{ @@ -289,7 +291,7 @@ func (r *UserSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Contex } // ExistsByUserIDAndGroupID 检查用户是否已有该分组的订阅 -func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { +func (r *userSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { var count int64 err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("user_id = ? AND group_id = ?", userID, groupID). @@ -298,7 +300,7 @@ func (r *UserSubscriptionRepository) ExistsByUserIDAndGroupID(ctx context.Contex } // CountByGroupID 获取分组的订阅数量 -func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { +func (r *userSubscriptionRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { var count int64 err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("group_id = ?", groupID). @@ -307,7 +309,7 @@ func (r *UserSubscriptionRepository) CountByGroupID(ctx context.Context, groupID } // CountActiveByGroupID 获取分组的有效订阅数量 -func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) { +func (r *userSubscriptionRepository) CountActiveByGroupID(ctx context.Context, groupID int64) (int64, error) { var count int64 err := r.db.WithContext(ctx).Model(&model.UserSubscription{}). Where("group_id = ? AND status = ? AND expires_at > ?", @@ -317,7 +319,7 @@ func (r *UserSubscriptionRepository) CountActiveByGroupID(ctx context.Context, g } // DeleteByGroupID 删除分组相关的所有订阅记录 -func (r *UserSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) { +func (r *userSubscriptionRepository) DeleteByGroupID(ctx context.Context, groupID int64) (int64, error) { result := r.db.WithContext(ctx).Where("group_id = ?", groupID).Delete(&model.UserSubscription{}) return result.RowsAffected, result.Error } diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go index 9cecf4e8..e6c1c850 100644 --- a/backend/internal/repository/user_subscription_repo_integration_test.go +++ b/backend/internal/repository/user_subscription_repo_integration_test.go @@ -17,13 +17,13 @@ type UserSubscriptionRepoSuite struct { suite.Suite ctx context.Context db *gorm.DB - repo *UserSubscriptionRepository + repo *userSubscriptionRepository } func (s *UserSubscriptionRepoSuite) SetupTest() { s.ctx = context.Background() s.db = testTx(s.T()) - s.repo = NewUserSubscriptionRepository(s.db) + s.repo = NewUserSubscriptionRepository(s.db).(*userSubscriptionRepository) } func TestUserSubscriptionRepoSuite(t *testing.T) { diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index b4aff0c0..6dff6407 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -1,7 +1,6 @@ package repository import ( - "github.com/Wei-Shaw/sub2api/internal/service" "github.com/google/wire" ) @@ -37,15 +36,4 @@ var ProviderSet = wire.NewSet( NewClaudeOAuthClient, NewHTTPUpstream, NewOpenAIOAuthClient, - - // Bind concrete repositories to service port interfaces - wire.Bind(new(service.UserRepository), new(*UserRepository)), - wire.Bind(new(service.ApiKeyRepository), new(*ApiKeyRepository)), - wire.Bind(new(service.GroupRepository), new(*GroupRepository)), - wire.Bind(new(service.AccountRepository), new(*AccountRepository)), - wire.Bind(new(service.ProxyRepository), new(*ProxyRepository)), - wire.Bind(new(service.RedeemCodeRepository), new(*RedeemCodeRepository)), - wire.Bind(new(service.UsageLogRepository), new(*UsageLogRepository)), - wire.Bind(new(service.SettingRepository), new(*SettingRepository)), - wire.Bind(new(service.UserSubscriptionRepository), new(*UserSubscriptionRepository)), ) diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 705f8a6d..e968ea57 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -2,17 +2,16 @@ package service import ( "context" - "errors" "fmt" "time" + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "gorm.io/gorm" ) var ( - ErrAccountNotFound = errors.New("account not found") + ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found") ) type AccountRepository interface { @@ -106,9 +105,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( for _, groupID := range req.GroupIDs { _, err := s.groupRepo.GetByID(ctx, groupID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, fmt.Errorf("group %d not found", groupID) - } return nil, fmt.Errorf("get group: %w", err) } } @@ -145,9 +141,6 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) ( func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrAccountNotFound - } return nil, fmt.Errorf("get account: %w", err) } return account, nil @@ -184,9 +177,6 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrAccountNotFound - } return nil, fmt.Errorf("get account: %w", err) } @@ -229,9 +219,6 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount for _, groupID := range *req.GroupIDs { _, err := s.groupRepo.GetByID(ctx, groupID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, fmt.Errorf("group %d not found", groupID) - } return nil, fmt.Errorf("get group: %w", err) } } @@ -249,9 +236,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error { // 检查账号是否存在 _, err := s.accountRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrAccountNotFound - } return fmt.Errorf("get account: %w", err) } @@ -266,9 +250,6 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error { func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrAccountNotFound - } return fmt.Errorf("get account: %w", err) } @@ -294,9 +275,6 @@ func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error { func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return "", ErrAccountNotFound - } return "", fmt.Errorf("get account: %w", err) } @@ -307,9 +285,6 @@ func (s *AccountService) GetCredential(ctx context.Context, id int64, key string func (s *AccountService) TestCredentials(ctx context.Context, id int64) error { account, err := s.accountRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrAccountNotFound - } return fmt.Errorf("get account: %w", err) } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 3000d87a..964711b7 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -9,7 +9,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "gorm.io/gorm" ) // AdminService interface defines admin management operations @@ -550,61 +549,7 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd } func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { - // 先获取分组信息,检查是否存在 - group, err := s.groupRepo.GetByID(ctx, id) - if err != nil { - return fmt.Errorf("group not found: %w", err) - } - - // 订阅类型分组:先获取受影响的用户ID列表(用于事务后失效缓存) - var affectedUserIDs []int64 - if group.IsSubscriptionType() && s.billingCacheService != nil { - var subscriptions []model.UserSubscription - if err := s.groupRepo.DB().WithContext(ctx). - Where("group_id = ?", id). - Select("user_id"). - Find(&subscriptions).Error; err == nil { - for _, sub := range subscriptions { - affectedUserIDs = append(affectedUserIDs, sub.UserID) - } - } - } - - // 使用事务处理所有级联删除 - db := s.groupRepo.DB() - err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - // 1. 如果是订阅类型分组,删除 user_subscriptions 中的相关记录 - if group.IsSubscriptionType() { - if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil { - return fmt.Errorf("delete user subscriptions: %w", err) - } - } - - // 2. 将 api_keys 中绑定该分组的 group_id 设为 nil(任何类型的分组都需要) - if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil { - return fmt.Errorf("clear api key group_id: %w", err) - } - - // 3. 从 users.allowed_groups 数组中移除该分组 ID - if err := tx.Model(&model.User{}). - Where("? = ANY(allowed_groups)", id). - Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil { - return fmt.Errorf("remove from allowed_groups: %w", err) - } - - // 4. 删除 account_groups 中间表的数据 - if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil { - return fmt.Errorf("delete account groups: %w", err) - } - - // 5. 删除分组本身 - if err := tx.Delete(&model.Group{}, id).Error; err != nil { - return fmt.Errorf("delete group: %w", err) - } - - return nil - }) - + affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id) if err != nil { return err } diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index f1d7c9f3..788d226e 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -9,20 +9,20 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/redis/go-redis/v9" - "gorm.io/gorm" ) var ( - ErrApiKeyNotFound = errors.New("api key not found") - ErrGroupNotAllowed = errors.New("user is not allowed to bind this group") - ErrApiKeyExists = errors.New("api key already exists") - ErrApiKeyTooShort = errors.New("api key must be at least 16 characters") - ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens") - ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later") + ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found") + ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group") + ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists") + ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters") + ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") + ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") ) const ( @@ -183,9 +183,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK // 验证用户存在 user, err := s.userRepo.GetByID(ctx, userID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrUserNotFound - } return nil, fmt.Errorf("get user: %w", err) } @@ -193,9 +190,6 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK if req.GroupID != nil { group, err := s.groupRepo.GetByID(ctx, *req.GroupID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, errors.New("group not found") - } return nil, fmt.Errorf("get group: %w", err) } @@ -269,9 +263,6 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) { apiKey, err := s.apiKeyRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrApiKeyNotFound - } return nil, fmt.Errorf("get api key: %w", err) } return apiKey, nil @@ -285,9 +276,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey // 这里可以添加Redis缓存逻辑,暂时直接查询数据库 apiKey, err := s.apiKeyRepo.GetByKey(ctx, key) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrApiKeyNotFound - } return nil, fmt.Errorf("get api key: %w", err) } @@ -304,9 +292,6 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) { apiKey, err := s.apiKeyRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrApiKeyNotFound - } return nil, fmt.Errorf("get api key: %w", err) } @@ -329,9 +314,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req group, err := s.groupRepo.GetByID(ctx, *req.GroupID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, errors.New("group not found") - } return nil, fmt.Errorf("get group: %w", err) } @@ -361,9 +343,6 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error { apiKey, err := s.apiKeyRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrApiKeyNotFound - } return fmt.Errorf("get api key: %w", err) } @@ -394,15 +373,12 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.Api // 检查API Key状态 if !apiKey.IsActive() { - return nil, nil, errors.New("api key is not active") + return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active") } // 获取用户信息 user, err := s.userRepo.GetByID(ctx, apiKey.UserID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, nil, ErrUserNotFound - } return nil, nil, fmt.Errorf("get user: %w", err) } @@ -436,9 +412,6 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([ // 获取用户信息 user, err := s.userRepo.GetByID(ctx, userID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrUserNotFound - } return nil, fmt.Errorf("get user: %w", err) } @@ -450,7 +423,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([ // 获取用户的所有有效订阅 activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + if err != nil { return nil, fmt.Errorf("list active subscriptions: %w", err) } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 8c6d78d5..e6d29e09 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -8,22 +8,22 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/model" "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" - "gorm.io/gorm" ) var ( - ErrInvalidCredentials = errors.New("invalid email or password") - ErrUserNotActive = errors.New("user is not active") - ErrEmailExists = errors.New("email already exists") - ErrInvalidToken = errors.New("invalid token") - ErrTokenExpired = errors.New("token has expired") - ErrEmailVerifyRequired = errors.New("email verification is required") - ErrRegDisabled = errors.New("registration is currently disabled") - ErrServiceUnavailable = errors.New("service temporarily unavailable") + ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") + ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") + ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") + ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") + ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") + ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") + ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") + ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") ) // JWTClaims JWT载荷数据 @@ -255,7 +255,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string // 查找用户 user, err := s.userRepo.GetByEmail(ctx, email) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, ErrUserNotFound) { return "", nil, ErrInvalidCredentials } // 记录数据库错误但不暴露给用户 @@ -357,7 +357,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) ( // 获取最新的用户信息 user, err := s.userRepo.GetByID(ctx, claims.UserID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, ErrUserNotFound) { return "", ErrInvalidToken } log.Printf("[Auth] Database error refreshing token: %v", err) diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 1a18ff12..70741d56 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -2,11 +2,11 @@ package service import ( "context" - "errors" "fmt" "log" "time" + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/model" ) @@ -14,7 +14,7 @@ import ( // 注:ErrInsufficientBalance在redeem_service.go中定义 // 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义 var ( - ErrSubscriptionInvalid = errors.New("subscription is invalid or expired") + ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") ) // subscriptionCacheData 订阅缓存数据结构(内部使用) diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 38da392e..27c68c52 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -4,21 +4,21 @@ import ( "context" "crypto/rand" "crypto/tls" - "errors" "fmt" "math/big" "net/smtp" "strconv" "time" + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/model" ) var ( - ErrEmailNotConfigured = errors.New("email service not configured") - ErrInvalidVerifyCode = errors.New("invalid or expired verification code") - ErrVerifyCodeTooFrequent = errors.New("please wait before requesting a new code") - ErrVerifyCodeMaxAttempts = errors.New("too many failed attempts, please request a new code") + ErrEmailNotConfigured = infraerrors.ServiceUnavailable("EMAIL_NOT_CONFIGURED", "email service not configured") + ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code") + ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code") + ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code") ) // EmailCache defines cache operations for email service diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index de351a0e..ea9bd24d 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -2,17 +2,16 @@ package service import ( "context" - "errors" "fmt" + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "gorm.io/gorm" ) var ( - ErrGroupNotFound = errors.New("group not found") - ErrGroupExists = errors.New("group name already exists") + ErrGroupNotFound = infraerrors.NotFound("GROUP_NOT_FOUND", "group not found") + ErrGroupExists = infraerrors.Conflict("GROUP_EXISTS", "group name already exists") ) type GroupRepository interface { @@ -20,6 +19,7 @@ type GroupRepository interface { GetByID(ctx context.Context, id int64) (*model.Group, error) Update(ctx context.Context, group *model.Group) error Delete(ctx context.Context, id int64) error + DeleteCascade(ctx context.Context, id int64) ([]int64, error) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) @@ -29,8 +29,6 @@ type GroupRepository interface { ExistsByName(ctx context.Context, name string) (bool, error) GetAccountCount(ctx context.Context, groupID int64) (int64, error) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) - - DB() *gorm.DB } // CreateGroupRequest 创建分组请求 @@ -93,9 +91,6 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrGroupNotFound - } return nil, fmt.Errorf("get group: %w", err) } return group, nil @@ -123,9 +118,6 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) { func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrGroupNotFound - } return nil, fmt.Errorf("get group: %w", err) } @@ -170,9 +162,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error { // 检查分组是否存在 _, err := s.groupRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrGroupNotFound - } return fmt.Errorf("get group: %w", err) } @@ -187,9 +176,6 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error { func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrGroupNotFound - } return nil, fmt.Errorf("get group: %w", err) } diff --git a/backend/internal/service/proxy_service.go b/backend/internal/service/proxy_service.go index 2c16a045..28ade11f 100644 --- a/backend/internal/service/proxy_service.go +++ b/backend/internal/service/proxy_service.go @@ -2,16 +2,15 @@ package service import ( "context" - "errors" "fmt" + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "gorm.io/gorm" ) var ( - ErrProxyNotFound = errors.New("proxy not found") + ErrProxyNotFound = infraerrors.NotFound("PROXY_NOT_FOUND", "proxy not found") ) type ProxyRepository interface { @@ -86,9 +85,6 @@ func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*mod func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) { proxy, err := s.proxyRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrProxyNotFound - } return nil, fmt.Errorf("get proxy: %w", err) } return proxy, nil @@ -116,9 +112,6 @@ func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) { func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) { proxy, err := s.proxyRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrProxyNotFound - } return nil, fmt.Errorf("get proxy: %w", err) } @@ -163,9 +156,6 @@ func (s *ProxyService) Delete(ctx context.Context, id int64) error { // 检查代理是否存在 _, err := s.proxyRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrProxyNotFound - } return fmt.Errorf("get proxy: %w", err) } @@ -180,9 +170,6 @@ func (s *ProxyService) Delete(ctx context.Context, id int64) error { func (s *ProxyService) TestConnection(ctx context.Context, id int64) error { proxy, err := s.proxyRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrProxyNotFound - } return fmt.Errorf("get proxy: %w", err) } @@ -197,9 +184,6 @@ func (s *ProxyService) TestConnection(ctx context.Context, id int64) error { func (s *ProxyService) GetURL(ctx context.Context, id int64) (string, error) { proxy, err := s.proxyRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return "", ErrProxyNotFound - } return "", fmt.Errorf("get proxy: %w", err) } diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index 578fb3dd..591d1555 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -9,19 +9,18 @@ import ( "strings" "time" + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/redis/go-redis/v9" - "gorm.io/gorm" ) var ( - ErrRedeemCodeNotFound = errors.New("redeem code not found") - ErrRedeemCodeUsed = errors.New("redeem code already used") - ErrRedeemCodeInvalid = errors.New("invalid redeem code") - ErrInsufficientBalance = errors.New("insufficient balance") - ErrRedeemRateLimited = errors.New("too many failed attempts, please try again later") - ErrRedeemCodeLocked = errors.New("redeem code is being processed, please try again") + ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found") + ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used") + ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance") + ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later") + ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again") ) const ( @@ -226,7 +225,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( // 查找兑换码 redeemCode, err := s.redeemRepo.GetByCode(ctx, code) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, ErrRedeemCodeNotFound) { s.incrementRedeemErrorCount(ctx, userID) return nil, ErrRedeemCodeNotFound } @@ -241,15 +240,12 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( // 验证兑换码类型的前置条件 if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil { - return nil, errors.New("invalid subscription redeem code: missing group_id") + return nil, infraerrors.BadRequest("REDEEM_CODE_INVALID", "invalid subscription redeem code: missing group_id") } // 获取用户信息 user, err := s.userRepo.GetByID(ctx, userID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrUserNotFound - } return nil, fmt.Errorf("get user: %w", err) } _ = user // 使用变量避免未使用错误 @@ -257,8 +253,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( // 【关键】先标记兑换码为已使用,确保并发安全 // 利用数据库乐观锁(WHERE status = 'unused')保证原子性 if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - // 兑换码已被其他请求使用 + if errors.Is(err, ErrRedeemCodeNotFound) || errors.Is(err, ErrRedeemCodeUsed) { return nil, ErrRedeemCodeUsed } return nil, fmt.Errorf("mark code as used: %w", err) @@ -328,9 +323,6 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) { code, err := s.redeemRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrRedeemCodeNotFound - } return nil, fmt.Errorf("get redeem code: %w", err) } return code, nil @@ -340,9 +332,6 @@ func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCod func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) { redeemCode, err := s.redeemRepo.GetByCode(ctx, code) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrRedeemCodeNotFound - } return nil, fmt.Errorf("get redeem code: %w", err) } return redeemCode, nil @@ -362,15 +351,12 @@ func (s *RedeemService) Delete(ctx context.Context, id int64) error { // 检查兑换码是否存在 code, err := s.redeemRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrRedeemCodeNotFound - } return fmt.Errorf("get redeem code: %w", err) } // 不允许删除已使用的兑换码 if code.IsUsed() { - return errors.New("cannot delete used redeem code") + return infraerrors.Conflict("REDEEM_CODE_DELETE_USED", "cannot delete used redeem code") } if err := s.redeemRepo.Delete(ctx, id); err != nil { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index fcbb4035..cb38203c 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -9,13 +9,13 @@ import ( "strconv" "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/model" - - "gorm.io/gorm" ) var ( - ErrRegistrationDisabled = errors.New("registration is currently disabled") + ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") + ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") ) type SettingRepository interface { @@ -187,7 +187,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 已有设置,不需要初始化 return nil } - if !errors.Is(err, gorm.ErrRecordNotFound) { + if !errors.Is(err, ErrSettingNotFound) { return fmt.Errorf("check existing settings: %w", err) } @@ -302,7 +302,7 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) { key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, ErrSettingNotFound) { return "", false, nil } return "", false, err @@ -326,7 +326,7 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) { key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, ErrSettingNotFound) { return "", nil // 未配置,返回空字符串 } return "", err // 数据库错误 diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 8d7a1b3b..f1ff6a2d 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -2,24 +2,24 @@ package service import ( "context" - "errors" "fmt" "log" "time" + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) var ( - ErrSubscriptionNotFound = errors.New("subscription not found") - ErrSubscriptionExpired = errors.New("subscription has expired") - ErrSubscriptionSuspended = errors.New("subscription is suspended") - ErrSubscriptionAlreadyExists = errors.New("subscription already exists for this user and group") - ErrGroupNotSubscriptionType = errors.New("group is not a subscription type") - ErrDailyLimitExceeded = errors.New("daily usage limit exceeded") - ErrWeeklyLimitExceeded = errors.New("weekly usage limit exceeded") - ErrMonthlyLimitExceeded = errors.New("monthly usage limit exceeded") + ErrSubscriptionNotFound = infraerrors.NotFound("SUBSCRIPTION_NOT_FOUND", "subscription not found") + ErrSubscriptionExpired = infraerrors.Forbidden("SUBSCRIPTION_EXPIRED", "subscription has expired") + ErrSubscriptionSuspended = infraerrors.Forbidden("SUBSCRIPTION_SUSPENDED", "subscription is suspended") + ErrSubscriptionAlreadyExists = infraerrors.Conflict("SUBSCRIPTION_ALREADY_EXISTS", "subscription already exists for this user and group") + ErrGroupNotSubscriptionType = infraerrors.BadRequest("GROUP_NOT_SUBSCRIPTION_TYPE", "group is not a subscription type") + ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded") + ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded") + ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded") ) // SubscriptionService 订阅服务 diff --git a/backend/internal/service/turnstile_service.go b/backend/internal/service/turnstile_service.go index 318c7d32..2a68c11b 100644 --- a/backend/internal/service/turnstile_service.go +++ b/backend/internal/service/turnstile_service.go @@ -2,14 +2,15 @@ package service import ( "context" - "errors" "fmt" "log" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" ) var ( - ErrTurnstileVerificationFailed = errors.New("turnstile verification failed") - ErrTurnstileNotConfigured = errors.New("turnstile not configured") + ErrTurnstileVerificationFailed = infraerrors.BadRequest("TURNSTILE_VERIFICATION_FAILED", "turnstile verification failed") + ErrTurnstileNotConfigured = infraerrors.ServiceUnavailable("TURNSTILE_NOT_CONFIGURED", "turnstile not configured") ) // TurnstileVerifier 验证 Turnstile token 的接口 diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index 33e41ea3..c574981d 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -2,18 +2,17 @@ package service import ( "context" - "errors" "fmt" "time" + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" - "gorm.io/gorm" ) var ( - ErrUsageLogNotFound = errors.New("usage log not found") + ErrUsageLogNotFound = infraerrors.NotFound("USAGE_LOG_NOT_FOUND", "usage log not found") ) // CreateUsageLogRequest 创建使用日志请求 @@ -71,9 +70,6 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* // 验证用户存在 _, err := s.userRepo.GetByID(ctx, req.UserID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrUserNotFound - } return nil, fmt.Errorf("get user: %w", err) } @@ -119,9 +115,6 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) { log, err := s.usageRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrUsageLogNotFound - } return nil, fmt.Errorf("get usage log: %w", err) } return log, nil diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 14830b57..5c314382 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -2,19 +2,18 @@ package service import ( "context" - "errors" "fmt" + infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "golang.org/x/crypto/bcrypt" - "gorm.io/gorm" ) var ( - ErrUserNotFound = errors.New("user not found") - ErrPasswordIncorrect = errors.New("current password is incorrect") - ErrInsufficientPerms = errors.New("insufficient permissions") + ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") + ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") + ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") ) type UserRepository interface { @@ -65,9 +64,6 @@ func NewUserService(userRepo UserRepository) *UserService { func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrUserNotFound - } return nil, fmt.Errorf("get user: %w", err) } return user, nil @@ -77,9 +73,6 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrUserNotFound - } return nil, fmt.Errorf("get user: %w", err) } @@ -119,9 +112,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrUserNotFound - } return fmt.Errorf("get user: %w", err) } @@ -149,9 +139,6 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) { user, err := s.userRepo.GetByID(ctx, id) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrUserNotFound - } return nil, fmt.Errorf("get user: %w", err) } return user, nil @@ -178,9 +165,6 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrUserNotFound - } return fmt.Errorf("get user: %w", err) }