diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 2fe3f468..cd44061b 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -73,6 +73,10 @@ func provideCleanup( name string fn func() error }{ + {"TokenRefreshService", func() error { + services.TokenRefresh.Stop() + return nil + }}, {"PricingService", func() error { services.Pricing.Stop() return nil diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 347e831f..f14fea15 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -106,6 +106,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream) concurrencyCache := repository.NewConcurrencyCache(client) concurrencyService := service.NewConcurrencyService(concurrencyCache) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler) @@ -138,6 +139,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { Concurrency: concurrencyService, Identity: identityService, Update: updateService, + TokenRefresh: tokenRefreshService, } repositories := &repository.Repositories{ User: userRepository, @@ -187,6 +189,10 @@ func provideCleanup( name string fn func() error }{ + {"TokenRefreshService", func() error { + services.TokenRefresh.Stop() + return nil + }}, {"PricingService", func() error { services.Pricing.Stop() return nil diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index d8836e92..34ecbfb5 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -8,15 +8,30 @@ import ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - JWT JWTConfig `mapstructure:"jwt"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Server ServerConfig `mapstructure:"server"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + JWT JWTConfig `mapstructure:"jwt"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" +} + +// TokenRefreshConfig OAuth token自动刷新配置 +type TokenRefreshConfig struct { + // 是否启用自动刷新 + Enabled bool `mapstructure:"enabled"` + // 检查间隔(分钟) + CheckIntervalMinutes int `mapstructure:"check_interval_minutes"` + // 提前刷新时间(小时),在token过期前多久开始刷新 + RefreshBeforeExpiryHours float64 `mapstructure:"refresh_before_expiry_hours"` + // 最大重试次数 + MaxRetries int `mapstructure:"max_retries"` + // 重试退避基础时间(秒) + RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` } type PricingConfig struct { @@ -192,6 +207,13 @@ func setDefaults() { // Gateway viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 + + // TokenRefresh + viper.SetDefault("token_refresh.enabled", true) + viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 + viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新 + viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 + viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 } func (c *Config) Validate() error { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 9861db60..0e208ac3 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -13,7 +13,6 @@ import ( "log" "net/http" "regexp" - "strconv" "strings" "time" @@ -34,7 +33,6 @@ const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" stickySessionTTL = time.Hour // 粘性会话TTL - tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token ) // allowedHeaders 白名单headers(参考CRS项目) @@ -358,37 +356,10 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) { accessToken := account.GetCredential("access_token") - expiresAtStr := account.GetCredential("expires_at") - - // 检查是否需要刷新 - needRefresh := false - if expiresAtStr != "" { - expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) - if err == nil && time.Now().Unix()+tokenRefreshBuffer > expiresAt { - needRefresh = true - } + if accessToken == "" { + return "", "", errors.New("access_token not found in credentials") } - - if needRefresh || accessToken == "" { - tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account) - if err != nil { - return "", "", fmt.Errorf("refresh token failed: %w", err) - } - - // 更新账号凭证 - account.Credentials["access_token"] = tokenInfo.AccessToken - account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10) - if tokenInfo.RefreshToken != "" { - account.Credentials["refresh_token"] = tokenInfo.RefreshToken - } - - if err := s.accountRepo.Update(ctx, account); err != nil { - log.Printf("Failed to update account credentials: %v", err) - } - - return tokenInfo.AccessToken, "oauth", nil - } - + // Token刷新由后台 TokenRefreshService 处理,此处只返回当前token return accessToken, "oauth", nil } @@ -442,25 +413,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m } defer resp.Body.Close() - // 处理401错误:刷新token重试 - if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" { - resp.Body.Close() - token, tokenType, err = s.forceRefreshToken(ctx, account) - if err != nil { - return nil, fmt.Errorf("token refresh failed: %w", err) - } - upstreamReq, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType) - if err != nil { - return nil, err - } - resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL) - if err != nil { - return nil, fmt.Errorf("retry request failed: %w", err) - } - defer resp.Body.Close() - } - - // 处理错误响应 + // 处理错误响应(包括401,由后台TokenRefreshService维护token有效性) if resp.StatusCode >= 400 { return s.handleErrorResponse(ctx, resp, c, account) } @@ -619,25 +572,6 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str return claude.DefaultBetaHeader } -func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.Account) (string, string, error) { - tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account) - if err != nil { - return "", "", err - } - - account.Credentials["access_token"] = tokenInfo.AccessToken - account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10) - if tokenInfo.RefreshToken != "" { - account.Credentials["refresh_token"] = tokenInfo.RefreshToken - } - - if err := s.accountRepo.Update(ctx, account); err != nil { - log.Printf("Failed to update account credentials: %v", err) - } - - return tokenInfo.AccessToken, "oauth", nil -} - func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) { body, _ := io.ReadAll(resp.Body) @@ -1053,26 +987,6 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } defer resp.Body.Close() - // 处理 401 错误:刷新 token 重试(仅 OAuth) - if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" { - resp.Body.Close() - token, tokenType, err = s.forceRefreshToken(ctx, account) - if err != nil { - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed") - return fmt.Errorf("token refresh failed: %w", err) - } - upstreamReq, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType) - if err != nil { - return err - } - resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL) - if err != nil { - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed") - return fmt.Errorf("retry request failed: %w", err) - } - defer resp.Body.Close() - } - // 读取响应体 respBody, err := io.ReadAll(resp.Body) if err != nil { diff --git a/backend/internal/service/service.go b/backend/internal/service/service.go index b5f37e1d..292a9a7b 100644 --- a/backend/internal/service/service.go +++ b/backend/internal/service/service.go @@ -27,4 +27,5 @@ type Services struct { Concurrency *ConcurrencyService Identity *IdentityService Update *UpdateService + TokenRefresh *TokenRefreshService } diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go new file mode 100644 index 00000000..7445634a --- /dev/null +++ b/backend/internal/service/token_refresh_service.go @@ -0,0 +1,185 @@ +package service + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "sub2api/internal/config" + "sub2api/internal/model" + "sub2api/internal/service/ports" +) + +// TokenRefreshService OAuth token自动刷新服务 +// 定期检查并刷新即将过期的token +type TokenRefreshService struct { + accountRepo ports.AccountRepository + refreshers []TokenRefresher + cfg *config.TokenRefreshConfig + + stopCh chan struct{} + wg sync.WaitGroup +} + +// NewTokenRefreshService 创建token刷新服务 +func NewTokenRefreshService( + accountRepo ports.AccountRepository, + oauthService *OAuthService, + cfg *config.Config, +) *TokenRefreshService { + s := &TokenRefreshService{ + accountRepo: accountRepo, + cfg: &cfg.TokenRefresh, + stopCh: make(chan struct{}), + } + + // 注册平台特定的刷新器 + s.refreshers = []TokenRefresher{ + NewClaudeTokenRefresher(oauthService), + // 未来可以添加其他平台的刷新器: + // NewOpenAITokenRefresher(...), + // NewGeminiTokenRefresher(...), + } + + return s +} + +// Start 启动后台刷新服务 +func (s *TokenRefreshService) Start() { + if !s.cfg.Enabled { + log.Println("[TokenRefresh] Service disabled by configuration") + return + } + + s.wg.Add(1) + go s.refreshLoop() + + log.Printf("[TokenRefresh] Service started (check every %d minutes, refresh %v hours before expiry)", + s.cfg.CheckIntervalMinutes, s.cfg.RefreshBeforeExpiryHours) +} + +// Stop 停止刷新服务 +func (s *TokenRefreshService) Stop() { + close(s.stopCh) + s.wg.Wait() + log.Println("[TokenRefresh] Service stopped") +} + +// refreshLoop 刷新循环 +func (s *TokenRefreshService) refreshLoop() { + defer s.wg.Done() + + // 计算检查间隔 + checkInterval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute + if checkInterval < time.Minute { + checkInterval = 5 * time.Minute + } + + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + // 启动时立即执行一次检查 + s.processRefresh() + + for { + select { + case <-ticker.C: + s.processRefresh() + case <-s.stopCh: + return + } + } +} + +// processRefresh 执行一次刷新检查 +func (s *TokenRefreshService) processRefresh() { + ctx := context.Background() + + // 计算刷新窗口 + refreshWindow := time.Duration(s.cfg.RefreshBeforeExpiryHours * float64(time.Hour)) + + // 获取所有active状态的账号 + accounts, err := s.listActiveAccounts(ctx) + if err != nil { + log.Printf("[TokenRefresh] Failed to list accounts: %v", err) + return + } + + refreshed, failed := 0, 0 + + for i := range accounts { + account := &accounts[i] + + // 遍历所有刷新器,找到能处理此账号的 + for _, refresher := range s.refreshers { + if !refresher.CanRefresh(account) { + continue + } + + // 检查是否需要刷新 + if !refresher.NeedsRefresh(account, refreshWindow) { + continue + } + + // 执行刷新 + if err := s.refreshWithRetry(ctx, account, refresher); err != nil { + log.Printf("[TokenRefresh] Account %d (%s) failed: %v", account.ID, account.Name, err) + failed++ + } else { + log.Printf("[TokenRefresh] Account %d (%s) refreshed successfully", account.ID, account.Name) + refreshed++ + } + + // 每个账号只由一个refresher处理 + break + } + } + + if refreshed > 0 || failed > 0 { + log.Printf("[TokenRefresh] Cycle complete: %d refreshed, %d failed", refreshed, failed) + } +} + +// listActiveAccounts 获取所有active状态的账号 +// 使用ListActive确保刷新所有活跃账号的token(包括临时禁用的) +func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]model.Account, error) { + return s.accountRepo.ListActive(ctx) +} + +// refreshWithRetry 带重试的刷新 +func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *model.Account, refresher TokenRefresher) error { + var lastErr error + + for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ { + newCredentials, err := refresher.Refresh(ctx, account) + if err == nil { + // 刷新成功,更新账号credentials + account.Credentials = model.JSONB(newCredentials) + if err := s.accountRepo.Update(ctx, account); err != nil { + return fmt.Errorf("failed to save credentials: %w", err) + } + return nil + } + + lastErr = err + log.Printf("[TokenRefresh] Account %d attempt %d/%d failed: %v", + account.ID, attempt, s.cfg.MaxRetries, err) + + // 如果还有重试机会,等待后重试 + if attempt < s.cfg.MaxRetries { + // 指数退避:2^(attempt-1) * baseSeconds + backoff := time.Duration(s.cfg.RetryBackoffSeconds) * time.Second * time.Duration(1<<(attempt-1)) + time.Sleep(backoff) + } + } + + // 所有重试都失败,标记账号为error状态 + errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr) + if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { + log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, err) + } + + return lastErr +} diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go new file mode 100644 index 00000000..b93f6858 --- /dev/null +++ b/backend/internal/service/token_refresher.go @@ -0,0 +1,90 @@ +package service + +import ( + "context" + "strconv" + "time" + + "sub2api/internal/model" +) + +// TokenRefresher 定义平台特定的token刷新策略接口 +// 通过此接口可以扩展支持不同平台(Anthropic/OpenAI/Gemini) +type TokenRefresher interface { + // CanRefresh 检查此刷新器是否能处理指定账号 + CanRefresh(account *model.Account) bool + + // NeedsRefresh 检查账号的token是否需要刷新 + NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool + + // Refresh 执行token刷新,返回更新后的credentials + // 注意:返回的map应该保留原有credentials中的所有字段,只更新token相关字段 + Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error) +} + +// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新 +type ClaudeTokenRefresher struct { + oauthService *OAuthService +} + +// NewClaudeTokenRefresher 创建Claude token刷新器 +func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher { + return &ClaudeTokenRefresher{ + oauthService: oauthService, + } +} + +// CanRefresh 检查是否能处理此账号 +// 只处理 anthropic 平台的 oauth 类型账号 +// setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新 +func (r *ClaudeTokenRefresher) CanRefresh(account *model.Account) bool { + return account.Platform == model.PlatformAnthropic && + account.Type == model.AccountTypeOAuth +} + +// NeedsRefresh 检查token是否需要刷新 +// 基于 expires_at 字段判断是否在刷新窗口内 +func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool { + expiresAtStr := account.GetCredential("expires_at") + if expiresAtStr == "" { + return false + } + + expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) + if err != nil { + return false + } + + expiryTime := time.Unix(expiresAt, 0) + return time.Until(expiryTime) < refreshWindow +} + +// Refresh 执行token刷新 +// 保留原有credentials中的所有字段,只更新token相关字段 +func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error) { + tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account) + if err != nil { + return nil, err + } + + // 保留现有credentials中的所有字段 + newCredentials := make(map[string]interface{}) + for k, v := range account.Credentials { + newCredentials[k] = v + } + + // 只更新token相关字段 + // 注意:expires_at 和 expires_in 必须存为字符串,因为 GetCredential 只返回 string 类型 + newCredentials["access_token"] = tokenInfo.AccessToken + newCredentials["token_type"] = tokenInfo.TokenType + newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10) + newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10) + if tokenInfo.RefreshToken != "" { + newCredentials["refresh_token"] = tokenInfo.RefreshToken + } + if tokenInfo.Scope != "" { + newCredentials["scope"] = tokenInfo.Scope + } + + return newCredentials, nil +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index eca145a5..4d293352 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -33,6 +33,17 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { return NewEmailQueueService(emailService, 3) } +// ProvideTokenRefreshService creates and starts TokenRefreshService +func ProvideTokenRefreshService( + accountRepo ports.AccountRepository, + oauthService *OAuthService, + cfg *config.Config, +) *TokenRefreshService { + svc := NewTokenRefreshService(accountRepo, oauthService, cfg) + svc.Start() + return svc +} + // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services @@ -61,6 +72,7 @@ var ProviderSet = wire.NewSet( NewConcurrencyService, NewIdentityService, ProvideUpdateService, + ProvideTokenRefreshService, // Provide the Services container struct wire.Struct(new(Services), "*"),