diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 6b0c6370..f4d809ba 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -65,7 +65,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) - userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) + userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, configConfig) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 2976cf13..2fba69cb 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -715,7 +715,7 @@ func setDefaults() { // Server viper.SetDefault("server.host", "0.0.0.0") viper.SetDefault("server.port", 8080) - viper.SetDefault("server.mode", "debug") + viper.SetDefault("server.mode", "release") viper.SetDefault("server.frontend_url", "") viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 @@ -751,7 +751,7 @@ func setDefaults() { viper.SetDefault("security.url_allowlist.crs_hosts", []string{}) viper.SetDefault("security.url_allowlist.allow_private_hosts", true) viper.SetDefault("security.url_allowlist.allow_insecure_http", true) - viper.SetDefault("security.response_headers.enabled", false) + viper.SetDefault("security.response_headers.enabled", true) viper.SetDefault("security.response_headers.additional_allowed", []string{}) viper.SetDefault("security.response_headers.force_remove", []string{}) viper.SetDefault("security.csp.enabled", true) @@ -789,9 +789,9 @@ func setDefaults() { viper.SetDefault("database.user", "postgres") viper.SetDefault("database.password", "postgres") viper.SetDefault("database.dbname", "sub2api") - viper.SetDefault("database.sslmode", "disable") - viper.SetDefault("database.max_open_conns", 50) - viper.SetDefault("database.max_idle_conns", 10) + viper.SetDefault("database.sslmode", "prefer") + viper.SetDefault("database.max_open_conns", 256) + viper.SetDefault("database.max_idle_conns", 128) viper.SetDefault("database.conn_max_lifetime_minutes", 30) viper.SetDefault("database.conn_max_idle_time_minutes", 5) diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 4bde837f..a645d343 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -87,8 +87,34 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { if !cfg.Security.URLAllowlist.AllowPrivateHosts { t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true") } - if cfg.Security.ResponseHeaders.Enabled { - t.Fatalf("ResponseHeaders.Enabled = true, want false") + if !cfg.Security.ResponseHeaders.Enabled { + t.Fatalf("ResponseHeaders.Enabled = false, want true") + } +} + +func TestLoadDefaultServerMode(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Server.Mode != "release" { + t.Fatalf("Server.Mode = %q, want %q", cfg.Server.Mode, "release") + } +} + +func TestLoadDefaultDatabaseSSLMode(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Database.SSLMode != "prefer" { + t.Fatalf("Database.SSLMode = %q, want %q", cfg.Database.SSLMode, "prefer") } } diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 6d42f726..e8968c6d 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -3,6 +3,7 @@ package admin import ( "errors" + "fmt" "strconv" "strings" "sync" @@ -738,57 +739,40 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) { } ctx := c.Request.Context() - success := 0 - failed := 0 - results := []gin.H{} + // 阶段一:预验证所有账号存在,收集 credentials + type accountUpdate struct { + ID int64 + Credentials map[string]any + } + updates := make([]accountUpdate, 0, len(req.AccountIDs)) for _, accountID := range req.AccountIDs { - // Get account account, err := h.adminService.GetAccount(ctx, accountID) if err != nil { - failed++ - results = append(results, gin.H{ - "account_id": accountID, - "success": false, - "error": "Account not found", - }) - continue + response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID)) + return } - - // Update credentials field if account.Credentials == nil { account.Credentials = make(map[string]any) } - account.Credentials[req.Field] = req.Value + updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials}) + } - // Update account + // 阶段二:依次更新,任何失败立即返回(避免部分成功部分失败) + for _, u := range updates { updateInput := &service.UpdateAccountInput{ - Credentials: account.Credentials, + Credentials: u.Credentials, } - - _, err = h.adminService.UpdateAccount(ctx, accountID, updateInput) - if err != nil { - failed++ - results = append(results, gin.H{ - "account_id": accountID, - "success": false, - "error": err.Error(), - }) - continue + if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil { + response.Error(c, 500, fmt.Sprintf("Failed to update account %d: %v", u.ID, err)) + return } - - success++ - results = append(results, gin.H{ - "account_id": accountID, - "success": true, - }) } response.Success(c, gin.H{ - "success": success, - "failed": failed, - "results": results, + "success": len(updates), + "failed": 0, }) } diff --git a/backend/internal/handler/admin/batch_update_credentials_test.go b/backend/internal/handler/admin/batch_update_credentials_test.go new file mode 100644 index 00000000..4c47fadb --- /dev/null +++ b/backend/internal/handler/admin/batch_update_credentials_test.go @@ -0,0 +1,200 @@ +//go:build unit + +package admin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。 +type failingAdminService struct { + *stubAdminService + failOnAccountID int64 + updateCallCount atomic.Int64 +} + +func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) { + f.updateCallCount.Add(1) + if id == f.failOnAccountID { + return nil, errors.New("database error") + } + return f.stubAdminService.UpdateAccount(ctx, id, input) +} + +func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) { + gin.SetMode(gin.TestMode) + router := gin.New() + handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials) + return router, handler +} + +func TestBatchUpdateCredentials_AllSuccess(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "account_uuid", + Value: "test-uuid", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200") + require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount") +} + +func TestBatchUpdateCredentials_FailFast(t *testing.T) { + // 让第 2 个账号(ID=2)更新时失败 + svc := &failingAdminService{ + stubAdminService: newStubAdminService(), + failOnAccountID: 2, + } + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "org_uuid", + Value: "test-org", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusInternalServerError, w.Code, "ID=2 失败时应返回 500") + // 验证 fail-fast:ID=1 更新成功,ID=2 失败,ID=3 不应被调用 + require.Equal(t, int64(2), svc.updateCallCount.Load(), + "fail-fast: 应只调用 2 次 UpdateAccount(ID=1 成功、ID=2 失败后停止)") +} + +func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) { + // GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub + svc := &getAccountFailingService{ + stubAdminService: newStubAdminService(), + failOnAccountID: 1, + } + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "account_uuid", + Value: "test", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404") +} + +// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。 +type getAccountFailingService struct { + *stubAdminService + failOnAccountID int64 +} + +func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) { + if id == f.failOnAccountID { + return nil, errors.New("not found") + } + return f.stubAdminService.GetAccount(ctx, id) +} + +func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // intercept_warmup_requests 传入非 bool 类型(string),应返回 400 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "intercept_warmup_requests", + "value": "not-a-bool", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code, + "intercept_warmup_requests 传入非 bool 值应返回 400") +} + +func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "intercept_warmup_requests", + "value": true, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, + "intercept_warmup_requests 传入合法 bool 值应返回 200") +} + +func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // account_uuid 传入非 string 类型(number),应返回 400 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "account_uuid", + "value": 12345, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code, + "account_uuid 传入非 string 值应返回 400") +} + +func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // account_uuid 传入 null(设置为空),应正常通过 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "account_uuid", + "value": nil, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, + "account_uuid 传入 null 应返回 200") +} diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 18365186..fab66c04 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { return } - stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs) + stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get user usage stats") return @@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) { return } - stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs) + stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get API key usage stats") return diff --git a/backend/internal/handler/admin/search_truncate_test.go b/backend/internal/handler/admin/search_truncate_test.go new file mode 100644 index 00000000..ffd60e2a --- /dev/null +++ b/backend/internal/handler/admin/search_truncate_test.go @@ -0,0 +1,97 @@ +//go:build unit + +package admin + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑 +func truncateSearchByRune(search string, maxRunes int) string { + if runes := []rune(search); len(runes) > maxRunes { + return string(runes[:maxRunes]) + } + return search +} + +func TestTruncateSearchByRune(t *testing.T) { + tests := []struct { + name string + input string + maxRunes int + wantLen int // 期望的 rune 长度 + }{ + { + name: "纯中文超长", + input: string(make([]rune, 150)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "纯 ASCII 超长", + input: string(make([]byte, 150)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "空字符串", + input: "", + maxRunes: 100, + wantLen: 0, + }, + { + name: "恰好 100 个字符", + input: string(make([]rune, 100)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "不足 100 字符不截断", + input: "hello世界", + maxRunes: 100, + wantLen: 7, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := truncateSearchByRune(tc.input, tc.maxRunes) + require.Equal(t, tc.wantLen, len([]rune(result))) + }) + } +} + +func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) { + // 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8 + input := "" + for i := 0; i < 101; i++ { + input += "中" + } + result := truncateSearchByRune(input, 100) + + require.Equal(t, 100, len([]rune(result))) + // 验证截断结果是有效的 UTF-8(每个中文字符 3 字节) + require.Equal(t, 300, len(result)) +} + +func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) { + // 50 个 ASCII + 51 个中文 = 101 个 rune + input := "" + for i := 0; i < 50; i++ { + input += "a" + } + for i := 0; i < 51; i++ { + input += "中" + } + result := truncateSearchByRune(input, 100) + + runes := []rune(result) + require.Equal(t, 100, len(runes)) + // 前 50 个应该是 'a',后 50 个应该是 '中' + require.Equal(t, 'a', runes[0]) + require.Equal(t, 'a', runes[49]) + require.Equal(t, '中', runes[50]) + require.Equal(t, '中', runes[99]) +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 1c772e7d..0427e77e 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -70,8 +70,8 @@ func (h *UserHandler) List(c *gin.Context) { search := c.Query("search") // 标准化和验证 search 参数 search = strings.TrimSpace(search) - if len(search) > 100 { - search = search[:100] + if runes := []rune(search); len(runes) > 100 { + search = string(runes[:100]) } filters := service.UserListFilters{ diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index beaddbca..e4f96710 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -210,7 +210,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + log.Printf("[Gateway] SelectAccount failed: %v", err) + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } if lastFailoverErr != nil { @@ -258,12 +259,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if err == nil && canWait { accountWaitCounted = true } - // Ensure the wait counter is decremented if we exit before acquiring the slot. - defer func() { + releaseWait := func() { if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false } - }() + } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -275,14 +276,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ) if err != nil { log.Printf("Account concurrency acquire failed: %v", err) + releaseWait() h.handleConcurrencyError(c, err, "account", streamStarted) return } // Slot acquired: no longer waiting in queue. - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } + releaseWait() if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } @@ -367,7 +366,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) if err != nil { if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + log.Printf("[Gateway] SelectAccount failed: %v", err) + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } if lastFailoverErr != nil { @@ -415,11 +415,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if err == nil && canWait { accountWaitCounted = true } - defer func() { + releaseWait := func() { if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false } - }() + } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -431,13 +432,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ) if err != nil { log.Printf("Account concurrency acquire failed: %v", err) + releaseWait() h.handleConcurrencyError(c, err, "account", streamStarted) return } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } + // Slot acquired: no longer waiting in queue. + releaseWait() if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } @@ -930,7 +930,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 选择支持该模型的账号 account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model) if err != nil { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) + log.Printf("[Gateway] SelectAccountForModel failed: %v", err) + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable") return } setOpsSelectedAccount(c, account.ID) @@ -1143,7 +1144,8 @@ func billingErrorDetails(err error) (status int, code, message string) { } msg := pkgerrors.Message(err) if msg == "" { - msg = err.Error() + log.Printf("[Gateway] billing error details: %v", err) + msg = "Billing error" } return http.StatusForbidden, "billing_error", msg } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 13b3703e..8f512d8e 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -216,7 +216,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if err != nil { log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + log.Printf("[OpenAI Gateway] SelectAccount failed: %v", err) + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } if lastFailoverErr != nil { @@ -249,11 +250,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if err == nil && canWait { accountWaitCounted = true } - defer func() { + releaseWait := func() { if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false } - }() + } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -265,13 +267,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ) if err != nil { log.Printf("Account concurrency acquire failed: %v", err) + releaseWait() h.handleConcurrencyError(c, err, "account", streamStarted) return } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } + // Slot acquired: no longer waiting in queue. + releaseWait() if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index 129dbfa6..b8182dad 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -392,7 +392,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) { return } - stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs) + stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{}) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index eb16f09d..1f58eb8e 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -1,6 +1,7 @@ package antigravity import ( + "crypto/rand" "encoding/json" "fmt" "log" @@ -341,12 +342,16 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string { return builder.String() } -// generateRandomID 生成随机 ID +// generateRandomID 生成密码学安全的随机 ID func generateRandomID() string { const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" result := make([]byte, 12) - for i := range result { - result[i] = chars[i%len(chars)] + randBytes := make([]byte, 12) + if _, err := rand.Read(randBytes); err != nil { + panic("crypto/rand unavailable: " + err.Error()) + } + for i, b := range randBytes { + result[i] = chars[int(b)%len(chars)] } return string(result) } diff --git a/backend/internal/pkg/antigravity/response_transformer_test.go b/backend/internal/pkg/antigravity/response_transformer_test.go new file mode 100644 index 00000000..9731d906 --- /dev/null +++ b/backend/internal/pkg/antigravity/response_transformer_test.go @@ -0,0 +1,36 @@ +//go:build unit + +package antigravity + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGenerateRandomID_Uniqueness(t *testing.T) { + seen := make(map[string]struct{}, 100) + for i := 0; i < 100; i++ { + id := generateRandomID() + require.Len(t, id, 12, "ID 长度应为 12") + _, dup := seen[id] + require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id) + seen[id] = struct{}{} + } +} + +func TestGenerateRandomID_Charset(t *testing.T) { + const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + validSet := make(map[byte]struct{}, len(validChars)) + for i := 0; i < len(validChars); i++ { + validSet[validChars[i]] = struct{}{} + } + + for i := 0; i < 50; i++ { + id := generateRandomID() + for j := 0; j < len(id); j++ { + _, ok := validSet[id[j]] + require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id) + } + } +} diff --git a/backend/internal/pkg/ip/ip.go b/backend/internal/pkg/ip/ip.go index 97109c0c..6ab2ff72 100644 --- a/backend/internal/pkg/ip/ip.go +++ b/backend/internal/pkg/ip/ip.go @@ -54,29 +54,34 @@ func normalizeIP(ip string) string { return ip } -// isPrivateIP 检查 IP 是否为私有地址。 -func isPrivateIP(ipStr string) bool { - ip := net.ParseIP(ipStr) - if ip == nil { - return false - } +// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析 +var privateNets []*net.IPNet - // 私有 IP 范围 - privateBlocks := []string{ +func init() { + for _, cidr := range []string{ "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "127.0.0.0/8", "::1/128", "fc00::/7", - } - - for _, block := range privateBlocks { - _, cidr, err := net.ParseCIDR(block) + } { + _, block, err := net.ParseCIDR(cidr) if err != nil { - continue + panic("invalid CIDR: " + cidr) } - if cidr.Contains(ip) { + privateNets = append(privateNets, block) + } +} + +// isPrivateIP 检查 IP 是否为私有地址。 +func isPrivateIP(ipStr string) bool { + ip := net.ParseIP(ipStr) + if ip == nil { + return false + } + for _, block := range privateNets { + if block.Contains(ip) { return true } } diff --git a/backend/internal/pkg/ip/ip_test.go b/backend/internal/pkg/ip/ip_test.go new file mode 100644 index 00000000..c3c90c74 --- /dev/null +++ b/backend/internal/pkg/ip/ip_test.go @@ -0,0 +1,51 @@ +//go:build unit + +package ip + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + name string + ip string + expected bool + }{ + // 私有 IPv4 + {"10.x 私有地址", "10.0.0.1", true}, + {"10.x 私有地址段末", "10.255.255.255", true}, + {"172.16.x 私有地址", "172.16.0.1", true}, + {"172.31.x 私有地址", "172.31.255.255", true}, + {"192.168.x 私有地址", "192.168.1.1", true}, + {"127.0.0.1 本地回环", "127.0.0.1", true}, + {"127.x 回环段", "127.255.255.255", true}, + + // 公网 IPv4 + {"8.8.8.8 公网 DNS", "8.8.8.8", false}, + {"1.1.1.1 公网", "1.1.1.1", false}, + {"172.15.255.255 非私有", "172.15.255.255", false}, + {"172.32.0.0 非私有", "172.32.0.0", false}, + {"11.0.0.1 公网", "11.0.0.1", false}, + + // IPv6 + {"::1 IPv6 回环", "::1", true}, + {"fc00:: IPv6 私有", "fc00::1", true}, + {"fd00:: IPv6 私有", "fd00::1", true}, + {"2001:db8::1 IPv6 公网", "2001:db8::1", false}, + + // 无效输入 + {"空字符串", "", false}, + {"非法字符串", "not-an-ip", false}, + {"不完整 IP", "192.168", false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := isPrivateIP(tc.ip) + require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip) + }) + } +} diff --git a/backend/internal/pkg/tlsfingerprint/dialer.go b/backend/internal/pkg/tlsfingerprint/dialer.go index 42510986..992f8b0a 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer.go +++ b/backend/internal/pkg/tlsfingerprint/dialer.go @@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st return nil, fmt.Errorf("apply TLS preset: %w", err) } - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err) _ = conn.Close() return nil, fmt.Errorf("TLS handshake failed: %w", err) diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index c0cfd256..c86968b7 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -375,36 +375,19 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) return keys, nil } -// IncrementQuotaUsed atomically increments the quota_used field and returns the new value +// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值 func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { - // Use raw SQL for atomic increment to avoid race conditions - // First get current value - m, err := r.activeQuery(). - Where(apikey.IDEQ(id)). - Select(apikey.FieldQuotaUsed). - Only(ctx) + updated, err := r.client.APIKey.UpdateOneID(id). + Where(apikey.DeletedAtIsNil()). + AddQuotaUsed(amount). + Save(ctx) if err != nil { if dbent.IsNotFound(err) { return 0, service.ErrAPIKeyNotFound } return 0, err } - - newValue := m.QuotaUsed + amount - - // Update with new value - affected, err := r.client.APIKey.Update(). - Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). - SetQuotaUsed(newValue). - Save(ctx) - if err != nil { - return 0, err - } - if affected == 0 { - return 0, service.ErrAPIKeyNotFound - } - - return newValue, nil + return updated.QuotaUsed, nil } func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 879a0576..303d7126 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -4,11 +4,14 @@ package repository import ( "context" + "sync" "testing" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group s.Require().NoError(s.repo.Create(s.ctx, k), "create api key") return k } + +// --- IncrementQuotaUsed --- + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() { + user := s.mustCreateUser("incr-basic@test.com") + key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil) + + newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5) + s.Require().NoError(err, "IncrementQuotaUsed") + s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5") + + newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5) + s.Require().NoError(err, "IncrementQuotaUsed second") + s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0") +} + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() { + _, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0) + s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound") +} + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() { + user := s.mustCreateUser("incr-deleted@test.com") + key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil) + + s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete") + + _, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0) + s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound") +} + +// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。 +// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。 +func TestIncrementQuotaUsed_Concurrent(t *testing.T) { + client := testEntClient(t) + repo := NewAPIKeyRepository(client).(*apiKeyRepository) + ctx := context.Background() + + // 创建测试用户和 API Key + u, err := client.User.Create(). + SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com"). + SetPasswordHash("hash"). + SetStatus(service.StatusActive). + SetRole(service.RoleUser). + Save(ctx) + require.NoError(t, err, "create user") + + k := &service.APIKey{ + UserID: u.ID, + Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano), + Name: "Concurrent", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, k), "create api key") + t.Cleanup(func() { + _ = client.APIKey.DeleteOneID(k.ID).Exec(ctx) + _ = client.User.DeleteOneID(u.ID).Exec(ctx) + }) + + // 10 个 goroutine 各递增 1.0,总计应为 10.0 + const goroutines = 10 + const increment = 1.0 + var wg sync.WaitGroup + errs := make([]error, goroutines) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + _, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment) + }(i) + } + wg.Wait() + + for i, e := range errs { + require.NoError(t, e, "goroutine %d failed", i) + } + + // 验证最终结果 + got, err := repo.GetByID(ctx, k.ID) + require.NoError(t, err, "GetByID") + require.Equal(t, float64(goroutines)*increment, got.QuotaUsed, + "并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed) +} diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index ac5803a1..50ea0da9 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log" + "math/rand" "strconv" "time" @@ -16,8 +17,15 @@ const ( billingBalanceKeyPrefix = "billing:balance:" billingSubKeyPrefix = "billing:sub:" billingCacheTTL = 5 * time.Minute + billingCacheJitter = 30 * time.Second ) +// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩 +func jitteredTTL() time.Duration { + jitter := time.Duration(rand.Int63n(int64(2*billingCacheJitter))) - billingCacheJitter + return billingCacheTTL + jitter +} + // billingBalanceKey generates the Redis key for user balance cache. func billingBalanceKey(userID int64) string { return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) @@ -82,14 +90,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6 func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error { key := billingBalanceKey(userID) - return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err() + return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err() } func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { key := billingBalanceKey(userID) - _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result() + _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err) + return err } return nil } @@ -163,16 +172,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID pipe := c.rdb.Pipeline() pipe.HSet(ctx, key, fields) - pipe.Expire(ctx, key, billingCacheTTL) + pipe.Expire(ctx, key, jitteredTTL()) _, err := pipe.Exec(ctx) return err } func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { key := billingSubKey(userID, groupID) - _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result() + _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err) + return err } return nil } diff --git a/backend/internal/repository/billing_cache_integration_test.go b/backend/internal/repository/billing_cache_integration_test.go index 2f7c69a7..4b7377b1 100644 --- a/backend/internal/repository/billing_cache_integration_test.go +++ b/backend/internal/repository/billing_cache_integration_test.go @@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() { } } +// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复: +// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。 +func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() { + tests := []struct { + name string + fn func(ctx context.Context, cache service.BillingCache) + expectErr bool + }{ + { + name: "key_not_exists_returns_nil", + fn: func(ctx context.Context, cache service.BillingCache) { + // key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误 + err := cache.DeductUserBalance(ctx, 99999, 1.0) + require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil") + }, + }, + { + name: "existing_key_deducts_successfully", + fn: func(ctx context.Context, cache service.BillingCache) { + require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0)) + err := cache.DeductUserBalance(ctx, 200, 10.0) + require.NoError(s.T(), err, "DeductUserBalance should succeed") + + bal, err := cache.GetUserBalance(ctx, 200) + require.NoError(s.T(), err) + require.Equal(s.T(), 40.0, bal, "余额应为 40.0") + }, + }, + { + name: "cancelled_context_propagates_error", + fn: func(ctx context.Context, cache service.BillingCache) { + require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0)) + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() // 立即取消 + + err := cache.DeductUserBalance(cancelCtx, 201, 10.0) + require.Error(s.T(), err, "cancelled context should propagate error") + }, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + tt.fn(ctx, cache) + }) + } +} + +// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复: +// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。 +func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() { + s.Run("key_not_exists_returns_nil", func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + + err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0) + require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil") + }) + + s.Run("cancelled_context_propagates_error", func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + + data := &service.SubscriptionCacheData{ + Status: "active", + ExpiresAt: time.Now().Add(1 * time.Hour), + Version: 1, + } + require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data)) + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + + err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0) + require.Error(s.T(), err, "cancelled context should propagate error") + }) +} + func TestBillingCacheSuite(t *testing.T) { suite.Run(t, new(BillingCacheSuite)) } diff --git a/backend/internal/repository/billing_cache_test.go b/backend/internal/repository/billing_cache_test.go index 7d3fd19d..2de1da87 100644 --- a/backend/internal/repository/billing_cache_test.go +++ b/backend/internal/repository/billing_cache_test.go @@ -5,6 +5,7 @@ package repository import ( "math" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) { }) } } + +func TestJitteredTTL(t *testing.T) { + const ( + minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s + maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s + ) + + for i := 0; i < 200; i++ { + ttl := jitteredTTL() + require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl) + require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl) + } +} + +func TestJitteredTTL_HasVariation(t *testing.T) { + // 多次调用应该产生不同的值(验证抖动存在) + seen := make(map[time.Duration]struct{}, 50) + for i := 0; i < 50; i++ { + seen[jitteredTTL()] = struct{}{} + } + // 50 次调用中应该至少有 2 个不同的值 + require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值") +} diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index d8cec491..412a8164 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -183,7 +183,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination q = q.Where(group.IsExclusiveEQ(*isExclusive)) } - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } diff --git a/backend/internal/repository/promo_code_repo.go b/backend/internal/repository/promo_code_repo.go index 98b422e0..95ce687a 100644 --- a/backend/internal/repository/promo_code_repo.go +++ b/backend/internal/repository/promo_code_repo.go @@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina q = q.Where(promocode.CodeContainsFold(search)) } - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } @@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo q := r.client.PromoCodeUsage.Query(). Where(promocodeusage.PromoCodeIDEQ(promoCodeID)) - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 2db1764f..d51669aa 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -24,6 +24,22 @@ import ( const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at" +// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL +var dateFormatWhitelist = map[string]string{ + "hour": "YYYY-MM-DD HH24:00", + "day": "YYYY-MM-DD", + "week": "IYYY-IW", + "month": "YYYY-MM", +} + +// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值 +func safeDateFormat(granularity string) string { + if f, ok := dateFormatWhitelist[granularity]; ok { + return f + } + return "YYYY-MM-DD" +} + type usageLogRepository struct { client *dbent.Client sql sqlExecutor @@ -564,7 +580,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, } func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime) return logs, nil, err } @@ -810,19 +826,19 @@ func resolveUsageStatsTimezone() string { } func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) return logs, nil, err } func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime) return logs, nil, err } func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime) return logs, nil, err } @@ -908,10 +924,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint // GetAPIKeyUsageTrend returns usage trend data grouped by API key and date func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` WITH top_keys AS ( @@ -966,10 +979,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, // GetUserUsageTrend returns usage trend data grouped by user and date func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` WITH top_users AS ( @@ -1228,10 +1238,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey // GetUserUsageTrendByUserID 获取指定用户的使用趋势 func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` SELECT @@ -1369,13 +1376,22 @@ type UsageStats = usagestats.UsageStats // BatchUserUsageStats represents usage stats for a single user type BatchUserUsageStats = usagestats.BatchUserUsageStats -// GetBatchUserUsageStats gets today and total actual_cost for multiple users -func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) { +// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range. +// If startTime is zero, defaults to 30 days ago. +func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) { result := make(map[int64]*BatchUserUsageStats) if len(userIDs) == 0 { return result, nil } + // 默认最近 30 天 + if startTime.IsZero() { + startTime = time.Now().AddDate(0, 0, -30) + } + if endTime.IsZero() { + endTime = time.Now() + } + for _, id := range userIDs { result[id] = &BatchUserUsageStats{UserID: id} } @@ -1383,10 +1399,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs query := ` SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost FROM usage_logs - WHERE user_id = ANY($1) + WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3 GROUP BY user_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs)) + rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime) if err != nil { return nil, err } @@ -1443,13 +1459,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs // BatchAPIKeyUsageStats represents usage stats for a single API key type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats -// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys -func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) { +// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range. +// If startTime is zero, defaults to 30 days ago. +func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) { result := make(map[int64]*BatchAPIKeyUsageStats) if len(apiKeyIDs) == 0 { return result, nil } + // 默认最近 30 天 + if startTime.IsZero() { + startTime = time.Now().AddDate(0, 0, -30) + } + if endTime.IsZero() { + endTime = time.Now() + } + for _, id := range apiKeyIDs { result[id] = &BatchAPIKeyUsageStats{APIKeyID: id} } @@ -1457,10 +1482,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe query := ` SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost FROM usage_logs - WHERE api_key_id = ANY($1) + WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3 GROUP BY api_key_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs)) + rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime) if err != nil { return nil, err } @@ -1516,10 +1541,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe // GetUsageTrendWithFilters returns usage trend data with optional filters func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` SELECT diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index eb220f22..8cb3aab1 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -648,7 +648,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now()) - stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}) + stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}, time.Time{}, time.Time{}) s.Require().NoError(err, "GetBatchUserUsageStats") s.Require().Len(stats, 2) s.Require().NotNil(stats[user1.ID]) @@ -656,7 +656,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { } func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() { - stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}) + stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{}) s.Require().NoError(err) s.Require().Empty(stats) } @@ -672,13 +672,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() { s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) - stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}) + stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}, time.Time{}, time.Time{}) s.Require().NoError(err, "GetBatchAPIKeyUsageStats") s.Require().Len(stats, 2) } func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() { - stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}) + stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{}) s.Require().NoError(err) s.Require().Empty(stats) } diff --git a/backend/internal/repository/usage_log_repo_unit_test.go b/backend/internal/repository/usage_log_repo_unit_test.go new file mode 100644 index 00000000..d0e14ffd --- /dev/null +++ b/backend/internal/repository/usage_log_repo_unit_test.go @@ -0,0 +1,41 @@ +//go:build unit + +package repository + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSafeDateFormat(t *testing.T) { + tests := []struct { + name string + granularity string + expected string + }{ + // 合法值 + {"hour", "hour", "YYYY-MM-DD HH24:00"}, + {"day", "day", "YYYY-MM-DD"}, + {"week", "week", "IYYY-IW"}, + {"month", "month", "YYYY-MM"}, + + // 非法值回退到默认 + {"空字符串", "", "YYYY-MM-DD"}, + {"未知粒度 year", "year", "YYYY-MM-DD"}, + {"未知粒度 minute", "minute", "YYYY-MM-DD"}, + + // 恶意字符串 + {"SQL 注入尝试", "'; DROP TABLE users; --", "YYYY-MM-DD"}, + {"带引号", "day'", "YYYY-MM-DD"}, + {"带括号", "day)", "YYYY-MM-DD"}, + {"Unicode", "日", "YYYY-MM-DD"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := safeDateFormat(tc.granularity) + require.Equal(t, tc.expected, got, "safeDateFormat(%q)", tc.granularity) + }) + } +} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 5e0aa97d..c34b9d15 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -592,7 +592,7 @@ func newContractDeps(t *testing.T) *contractDeps { RunMode: config.RunModeStandard, } - userService := service.NewUserService(userRepo, nil) + userService := service.NewUserService(userRepo, nil, nil) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() @@ -1598,11 +1598,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { +func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { +func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index 3ec6154d..7b6d4ce8 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -39,7 +39,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { return &clone, nil }, } - userService := service.NewUserService(userRepo, nil) + userService := service.NewUserService(userRepo, nil, nil) router := gin.New() router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil))) diff --git a/backend/internal/server/middleware/cors.go b/backend/internal/server/middleware/cors.go index 7d82f183..b54a0b0e 100644 --- a/backend/internal/server/middleware/cors.go +++ b/backend/internal/server/middleware/cors.go @@ -72,6 +72,7 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") + c.Writer.Header().Set("Access-Control-Max-Age", "86400") // 处理预检请求 if c.Request.Method == http.MethodOptions { diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 304c5781..7698223e 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -36,8 +36,8 @@ type UsageLogRepository interface { GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) - GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) - GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) + GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) + GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) // User dashboard stats GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index aa24c60a..49197df8 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1687,7 +1687,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co Usage: ClaudeUsage{}, Model: originalModel, Stream: false, - Duration: time.Since(time.Now()), + Duration: time.Since(startTime), FirstTokenMs: nil, }, nil default: diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index cd11923e..32704a94 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end return trend, nil } -func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { - stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs) +func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { + stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime) if err != nil { return nil, fmt.Errorf("get batch user usage stats: %w", err) } return stats, nil } -func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { - stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs) +func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { + stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime) if err != nil { return nil, fmt.Errorf("get batch api key usage stats: %w", err) } diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index 5594e53f..f21a2855 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -316,8 +316,8 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star } // GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys. -func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { - stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs) +func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { + stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime) if err != nil { return nil, fmt.Errorf("get batch api key usage stats: %w", err) } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 1bfb392e..510e734e 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -3,6 +3,8 @@ package service import ( "context" "fmt" + "log" + "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -62,13 +64,15 @@ type ChangePasswordRequest struct { type UserService struct { userRepo UserRepository authCacheInvalidator APIKeyAuthCacheInvalidator + billingCache BillingCache } // NewUserService 创建用户服务实例 -func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService { +func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService { return &UserService{ userRepo: userRepo, authCacheInvalidator: authCacheInvalidator, + billingCache: billingCache, } } @@ -183,6 +187,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) } + if s.billingCache != nil { + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil { + log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) + } + }() + } return nil } diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go new file mode 100644 index 00000000..0f355d70 --- /dev/null +++ b/backend/internal/service/user_service_test.go @@ -0,0 +1,186 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +// --- mock: UserRepository --- + +type mockUserRepo struct { + updateBalanceErr error + updateBalanceFn func(ctx context.Context, id int64, amount float64) error +} + +func (m *mockUserRepo) Create(context.Context, *User) error { return nil } +func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) Update(context.Context, *User) error { return nil } +func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } +func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (m *mockUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error { + if m.updateBalanceFn != nil { + return m.updateBalanceFn(ctx, id, amount) + } + return m.updateBalanceErr +} +func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error { return nil } +func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } +func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} +func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } + +// --- mock: APIKeyAuthCacheInvalidator --- + +type mockAuthCacheInvalidator struct { + invalidatedUserIDs []int64 + mu sync.Mutex +} + +func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {} +func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {} +func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) { + m.mu.Lock() + defer m.mu.Unlock() + m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID) +} + +// --- mock: BillingCache --- + +type mockBillingCache struct { + invalidateErr error + invalidateCallCount atomic.Int64 + invalidatedUserIDs []int64 + mu sync.Mutex +} + +func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil } +func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error { return nil } +func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error { return nil } +func (m *mockBillingCache) InvalidateUserBalance(_ context.Context, userID int64) error { + m.invalidateCallCount.Add(1) + m.mu.Lock() + defer m.mu.Unlock() + m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID) + return m.invalidateErr +} +func (m *mockBillingCache) GetSubscriptionCache(context.Context, int64, int64) (*SubscriptionCacheData, error) { + return nil, nil +} +func (m *mockBillingCache) SetSubscriptionCache(context.Context, int64, int64, *SubscriptionCacheData) error { + return nil +} +func (m *mockBillingCache) UpdateSubscriptionUsage(context.Context, int64, int64, float64) error { + return nil +} +func (m *mockBillingCache) InvalidateSubscriptionCache(context.Context, int64, int64) error { + return nil +} + +// --- 测试 --- + +func TestUpdateBalance_Success(t *testing.T) { + repo := &mockUserRepo{} + cache := &mockBillingCache{} + svc := NewUserService(repo, nil, cache) + + err := svc.UpdateBalance(context.Background(), 42, 100.0) + require.NoError(t, err) + + // 等待异步 goroutine 完成 + require.Eventually(t, func() bool { + return cache.invalidateCallCount.Load() == 1 + }, 2*time.Second, 10*time.Millisecond, "应异步调用 InvalidateUserBalance") + + cache.mu.Lock() + defer cache.mu.Unlock() + require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存") +} + +func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) { + repo := &mockUserRepo{} + svc := NewUserService(repo, nil, nil) // billingCache = nil + + err := svc.UpdateBalance(context.Background(), 1, 50.0) + require.NoError(t, err, "billingCache 为 nil 时不应 panic") +} + +func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) { + repo := &mockUserRepo{} + cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")} + svc := NewUserService(repo, nil, cache) + + err := svc.UpdateBalance(context.Background(), 99, 200.0) + require.NoError(t, err, "缓存失效失败不应影响主流程返回值") + + // 等待异步 goroutine 完成(即使失败也应调用) + require.Eventually(t, func() bool { + return cache.invalidateCallCount.Load() == 1 + }, 2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance") +} + +func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) { + repo := &mockUserRepo{updateBalanceErr: errors.New("database error")} + cache := &mockBillingCache{} + svc := NewUserService(repo, nil, cache) + + err := svc.UpdateBalance(context.Background(), 1, 100.0) + require.Error(t, err, "repo 失败时应返回错误") + require.Contains(t, err.Error(), "update balance") + + // repo 失败时不应触发缓存失效 + time.Sleep(100 * time.Millisecond) + require.Equal(t, int64(0), cache.invalidateCallCount.Load(), + "repo 失败时不应调用 InvalidateUserBalance") +} + +func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) { + repo := &mockUserRepo{} + auth := &mockAuthCacheInvalidator{} + cache := &mockBillingCache{} + svc := NewUserService(repo, auth, cache) + + err := svc.UpdateBalance(context.Background(), 77, 300.0) + require.NoError(t, err) + + // 验证 auth cache 同步失效 + auth.mu.Lock() + require.Equal(t, []int64{77}, auth.invalidatedUserIDs) + auth.mu.Unlock() + + // 验证 billing cache 异步失效 + require.Eventually(t, func() bool { + return cache.invalidateCallCount.Load() == 1 + }, 2*time.Second, 10*time.Millisecond) +} + +func TestNewUserService_FieldsAssignment(t *testing.T) { + repo := &mockUserRepo{} + auth := &mockAuthCacheInvalidator{} + cache := &mockBillingCache{} + + svc := NewUserService(repo, auth, cache) + require.NotNil(t, svc) + require.Equal(t, repo, svc.userRepo) + require.Equal(t, auth, svc.authCacheInvalidator) + require.Equal(t, cache, svc.billingCache) +} diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index dbf0703d..8b15b54f 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -112,9 +112,9 @@ security: # 白名单禁用时是否允许 http:// URL(默认: false,要求 https) allow_insecure_http: true response_headers: - # Enable configurable response header filtering (disable to use default allowlist) - # 启用可配置的响应头过滤(禁用则使用默认白名单) - enabled: false + # Enable configurable response header filtering (default: true) + # 启用可配置的响应头过滤(默认启用,过滤上游敏感响应头) + enabled: true # Extra allowed response headers from upstream # 额外允许的上游响应头 additional_allowed: [] @@ -390,15 +390,16 @@ database: # Database name # 数据库名称 dbname: "sub2api" - # SSL mode: disable, require, verify-ca, verify-full - # SSL 模式:disable(禁用), require(要求), verify-ca(验证CA), verify-full(完全验证) - sslmode: "disable" - # Max open connections + # SSL mode: disable, prefer, require, verify-ca, verify-full + # SSL 模式:disable(禁用), prefer(优先加密,默认), require(要求), verify-ca(验证CA), verify-full(完全验证) + # 默认值为 "prefer",数据库支持 SSL 时自动使用加密连接,不支持时回退明文 + sslmode: "prefer" + # Max open connections (高并发场景建议 256+,需配合 PostgreSQL max_connections 调整) # 最大打开连接数 - max_open_conns: 50 - # Max idle connections + max_open_conns: 256 + # Max idle connections (建议为 max_open_conns 的 50%,减少频繁建连开销) # 最大空闲连接数 - max_idle_conns: 10 + max_idle_conns: 128 # Connection max lifetime (minutes) # 连接最大存活时间(分钟) conn_max_lifetime_minutes: 30 @@ -426,9 +427,9 @@ redis: # Connection pool size (max concurrent connections) # 连接池大小(最大并发连接数) pool_size: 1024 - # Minimum number of idle connections + # Minimum number of idle connections (高并发场景建议 128+,保持足够热连接) # 最小空闲连接数 - min_idle_conns: 10 + min_idle_conns: 128 # Enable TLS/SSL connection # 是否启用 TLS/SSL 连接 enable_tls: false