diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 78b71431..af1e7d91 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -3,7 +3,7 @@ package admin import ( "strconv" "strings" - "time" + "sync" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" @@ -14,6 +14,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" + "golang.org/x/sync/errgroup" ) // OAuthHandler handles OAuth-related operations for accounts @@ -1000,47 +1001,33 @@ func (h *AccountHandler) RefreshTier(c *gin.Context) { return } - account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + ctx := c.Request.Context() + account, err := h.adminService.GetAccount(ctx, accountID) if err != nil { response.NotFound(c, "Account not found") return } - if account.Credentials == nil || account.Credentials["oauth_type"] != "google_one" { - response.BadRequest(c, "Account is not a google_one OAuth account") + if account.Platform != service.PlatformGemini || account.Type != service.AccountTypeOAuth { + response.BadRequest(c, "Only Gemini OAuth accounts support tier refresh") return } - accessToken, ok := account.Credentials["access_token"].(string) - if !ok || accessToken == "" { - response.BadRequest(c, "Missing access_token in credentials") + oauthType, _ := account.Credentials["oauth_type"].(string) + if oauthType != "google_one" { + response.BadRequest(c, "Only google_one OAuth accounts support tier refresh") return } - var proxyURL string - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - tierID, storageInfo, err := h.geminiOAuthService.FetchGoogleOneTier(c.Request.Context(), accessToken, proxyURL) + tierID, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(ctx, account) if err != nil { response.ErrorFrom(c, err) return } - if account.Extra == nil { - account.Extra = make(map[string]any) - } - if storageInfo != nil { - account.Extra["drive_storage_limit"] = storageInfo.Limit - account.Extra["drive_storage_usage"] = storageInfo.Usage - account.Extra["drive_tier_updated_at"] = timezone.Now().Format(time.RFC3339) - } - account.Credentials["tier_id"] = tierID - - _, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ - Credentials: account.Credentials, - Extra: account.Extra, + _, updateErr := h.adminService.UpdateAccount(ctx, accountID, &service.UpdateAccountInput{ + Credentials: creds, + Extra: extra, }) if updateErr != nil { response.ErrorFrom(c, updateErr) @@ -1049,9 +1036,10 @@ func (h *AccountHandler) RefreshTier(c *gin.Context) { response.Success(c, gin.H{ "tier_id": tierID, - "drive_storage_limit": account.Extra["drive_storage_limit"], - "drive_storage_usage": account.Extra["drive_storage_usage"], - "updated_at": account.Extra["drive_tier_updated_at"], + "storage_info": extra, + "drive_storage_limit": extra["drive_storage_limit"], + "drive_storage_usage": extra["drive_storage_usage"], + "updated_at": extra["drive_tier_updated_at"], }) } @@ -1069,7 +1057,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { } ctx := c.Request.Context() - var accounts []service.Account + accounts := make([]*service.Account, 0) if len(req.AccountIDs) == 0 { allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "") @@ -1077,84 +1065,87 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { response.ErrorFrom(c, err) return } - for _, acc := range allAccounts { - if acc.Credentials != nil && acc.Credentials["oauth_type"] == "google_one" { + for i := range allAccounts { + acc := &allAccounts[i] + oauthType, _ := acc.Credentials["oauth_type"].(string) + if oauthType == "google_one" { accounts = append(accounts, acc) } } } else { - for _, id := range req.AccountIDs { - acc, err := h.adminService.GetAccount(ctx, id) - if err != nil { + fetched, err := h.adminService.GetAccountsByIDs(ctx, req.AccountIDs) + if err != nil { + response.ErrorFrom(c, err) + return + } + + for _, acc := range fetched { + if acc == nil { continue } - if acc.Credentials != nil && acc.Credentials["oauth_type"] == "google_one" { - accounts = append(accounts, *acc) + if acc.Platform != service.PlatformGemini || acc.Type != service.AccountTypeOAuth { + continue } + oauthType, _ := acc.Credentials["oauth_type"].(string) + if oauthType != "google_one" { + continue + } + accounts = append(accounts, acc) } } - total := len(accounts) - success := 0 - failed := 0 - errors := []gin.H{} + const maxConcurrency = 10 + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(maxConcurrency) + + var mu sync.Mutex + results := gin.H{ + "total": len(accounts), + "success": 0, + "failed": 0, + "errors": []gin.H{}, + } for _, account := range accounts { - accessToken, ok := account.Credentials["access_token"].(string) - if !ok || accessToken == "" { - failed++ - errors = append(errors, gin.H{ - "account_id": account.ID, - "error": "missing access_token", + acc := account // 闭包捕获 + g.Go(func() error { + _, extra, creds, err := h.geminiOAuthService.RefreshAccountGoogleOneTier(gctx, acc) + if err != nil { + mu.Lock() + results["failed"] = results["failed"].(int) + 1 + results["errors"] = append(results["errors"].([]gin.H), gin.H{ + "account_id": acc.ID, + "error": err.Error(), + }) + mu.Unlock() + return nil + } + + _, updateErr := h.adminService.UpdateAccount(gctx, acc.ID, &service.UpdateAccountInput{ + Credentials: creds, + Extra: extra, }) - continue - } - var proxyURL string - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } + mu.Lock() + if updateErr != nil { + results["failed"] = results["failed"].(int) + 1 + results["errors"] = append(results["errors"].([]gin.H), gin.H{ + "account_id": acc.ID, + "error": updateErr.Error(), + }) + } else { + results["success"] = results["success"].(int) + 1 + } + mu.Unlock() - tierID, storageInfo, err := h.geminiOAuthService.FetchGoogleOneTier(ctx, accessToken, proxyURL) - if err != nil { - failed++ - errors = append(errors, gin.H{ - "account_id": account.ID, - "error": err.Error(), - }) - continue - } - - if account.Extra == nil { - account.Extra = make(map[string]any) - } - if storageInfo != nil { - account.Extra["drive_storage_limit"] = storageInfo.Limit - account.Extra["drive_storage_usage"] = storageInfo.Usage - account.Extra["drive_tier_updated_at"] = timezone.Now().Format(time.RFC3339) - } - account.Credentials["tier_id"] = tierID - - _, updateErr := h.adminService.UpdateAccount(ctx, account.ID, &service.UpdateAccountInput{ - Credentials: account.Credentials, - Extra: account.Extra, + return nil }) - if updateErr != nil { - failed++ - errors = append(errors, gin.H{ - "account_id": account.ID, - "error": updateErr.Error(), - }) - continue - } - - success++ } - response.Success(c, gin.H{ - "total": total, - "success": success, - "failed": failed, - "errors": errors, - }) + if err := g.Wait(); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, results) } diff --git a/backend/internal/pkg/geminicli/drive_client.go b/backend/internal/pkg/geminicli/drive_client.go index 79d6835f..77e2c476 100644 --- a/backend/internal/pkg/geminicli/drive_client.go +++ b/backend/internal/pkg/geminicli/drive_client.go @@ -5,7 +5,9 @@ import ( "encoding/json" "fmt" "io" + "math/rand" "net/http" + "strconv" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" @@ -49,13 +51,38 @@ func (c *driveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL return nil, fmt.Errorf("failed to create HTTP client: %w", err) } - // Retry logic with exponential backoff for rate limits + sleepWithContext := func(d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } + } + + // Retry logic with exponential backoff (+ jitter) for rate limits and transient failures var resp *http.Response maxRetries := 3 + rng := rand.New(rand.NewSource(time.Now().UnixNano())) for attempt := 0; attempt < maxRetries; attempt++ { + if ctx.Err() != nil { + return nil, fmt.Errorf("request cancelled: %w", ctx.Err()) + } + resp, err = client.Do(req) if err != nil { - return nil, fmt.Errorf("failed to execute request: %w", err) + // Network error retry + if attempt < maxRetries-1 { + backoff := time.Duration(1< 100*1024*1024*1024*1024 { // > 100TB + if storageBytes > StorageTierUnlimited { return TierGoogleOneUnlimited } - - // AI Premium (2TB+) - if storageBytes >= 2*1024*1024*1024*1024 { // >= 2TB + if storageBytes >= StorageTierAIPremium { return TierAIPremium } - - // Google One Standard (200GB) - if storageBytes >= 200*1024*1024*1024 { // >= 200GB + if storageBytes >= StorageTierStandard { return TierGoogleOneStandard } - - // Google One Basic (100GB) - if storageBytes >= 100*1024*1024*1024 { // >= 100GB + if storageBytes >= StorageTierBasic { return TierGoogleOneBasic } - - // Free (15GB) - if storageBytes >= 15*1024*1024*1024 { // >= 15GB + if storageBytes >= StorageTierFree { return TierFree } - return TierGoogleOneUnknown } @@ -270,6 +271,60 @@ func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken return tierID, storageInfo, nil } +// RefreshAccountGoogleOneTier 刷新单个账号的 Google One Tier +func (s *GeminiOAuthService) RefreshAccountGoogleOneTier( + ctx context.Context, + account *Account, +) (tierID string, extra map[string]any, credentials map[string]any, err error) { + if account == nil { + return "", nil, nil, fmt.Errorf("account is nil") + } + + // 验证账号类型 + oauthType, ok := account.Credentials["oauth_type"].(string) + if !ok || oauthType != "google_one" { + return "", nil, nil, fmt.Errorf("not a google_one OAuth account") + } + + // 获取 access_token + accessToken, ok := account.Credentials["access_token"].(string) + if !ok || accessToken == "" { + return "", nil, nil, fmt.Errorf("missing access_token") + } + + // 获取 proxy URL + var proxyURL string + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 调用 Drive API + tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, accessToken, proxyURL) + if err != nil { + return "", nil, nil, err + } + + // 构建 extra 数据(保留原有 extra 字段) + extra = make(map[string]any) + for k, v := range account.Extra { + extra[k] = v + } + if storageInfo != nil { + extra["drive_storage_limit"] = storageInfo.Limit + extra["drive_storage_usage"] = storageInfo.Usage + extra["drive_tier_updated_at"] = time.Now().Format(time.RFC3339) + } + + // 构建 credentials 数据 + credentials = make(map[string]any) + for k, v := range account.Credentials { + credentials[k] = v + } + credentials["tier_id"] = tierID + + return tierID, extra, credentials, nil +} + func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { session, ok := s.sessionStore.Get(input.SessionID) if !ok { diff --git a/backend/internal/service/gemini_oauth_service_test.go b/backend/internal/service/gemini_oauth_service_test.go new file mode 100644 index 00000000..393812c2 --- /dev/null +++ b/backend/internal/service/gemini_oauth_service_test.go @@ -0,0 +1,52 @@ +package service + +import "testing" + +func TestInferGoogleOneTier(t *testing.T) { + tests := []struct { + name string + storageBytes int64 + expectedTier string + }{ + {"Negative storage", -1, TierGoogleOneUnknown}, + {"Zero storage", 0, TierGoogleOneUnknown}, + + // Free tier boundary (15GB) + {"Below free tier", 10 * GB, TierGoogleOneUnknown}, + {"Just below free tier", StorageTierFree - 1, TierGoogleOneUnknown}, + {"Free tier (15GB)", StorageTierFree, TierFree}, + + // Basic tier boundary (100GB) + {"Between free and basic", 50 * GB, TierFree}, + {"Just below basic tier", StorageTierBasic - 1, TierFree}, + {"Basic tier (100GB)", StorageTierBasic, TierGoogleOneBasic}, + + // Standard tier boundary (200GB) + {"Between basic and standard", 150 * GB, TierGoogleOneBasic}, + {"Just below standard tier", StorageTierStandard - 1, TierGoogleOneBasic}, + {"Standard tier (200GB)", StorageTierStandard, TierGoogleOneStandard}, + + // AI Premium tier boundary (2TB) + {"Between standard and premium", 1 * TB, TierGoogleOneStandard}, + {"Just below AI Premium tier", StorageTierAIPremium - 1, TierGoogleOneStandard}, + {"AI Premium tier (2TB)", StorageTierAIPremium, TierAIPremium}, + + // Unlimited tier boundary (> 100TB) + {"Between premium and unlimited", 50 * TB, TierAIPremium}, + {"At unlimited threshold (100TB)", StorageTierUnlimited, TierAIPremium}, + {"Unlimited tier (100TB+)", StorageTierUnlimited + 1, TierGoogleOneUnlimited}, + {"Unlimited tier (101TB+)", 101 * TB, TierGoogleOneUnlimited}, + {"Very large storage", 1000 * TB, TierGoogleOneUnlimited}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := inferGoogleOneTier(tt.storageBytes) + if result != tt.expectedTier { + t.Errorf("inferGoogleOneTier(%d) = %s, want %s", + tt.storageBytes, result, tt.expectedTier) + } + }) + } +} + diff --git a/backend/migrations/017_add_gemini_tier_id.sql b/backend/migrations/024_add_gemini_tier_id.sql similarity index 94% rename from backend/migrations/017_add_gemini_tier_id.sql rename to backend/migrations/024_add_gemini_tier_id.sql index 0388a412..d9ac7afe 100644 --- a/backend/migrations/017_add_gemini_tier_id.sql +++ b/backend/migrations/024_add_gemini_tier_id.sql @@ -26,5 +26,5 @@ UPDATE accounts SET credentials = credentials - 'tier_id' WHERE platform = 'gemini' AND type = 'oauth' - AND credentials->>'oauth_type' = 'code_assist'; + AND credentials ? 'tier_id'; -- +goose StatementEnd