fix: 修复Oauth账号自动刷新token失败的bug

This commit is contained in:
shaw
2025-12-20 13:01:58 +08:00
parent bb500b7b2a
commit adebd941e1
8 changed files with 333 additions and 99 deletions

View File

@@ -73,6 +73,10 @@ func provideCleanup(
name string
fn func() error
}{
{"TokenRefreshService", func() error {
services.TokenRefresh.Stop()
return nil
}},
{"PricingService", func() error {
services.Pricing.Stop()
return nil

View File

@@ -106,6 +106,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
@@ -138,6 +139,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
Concurrency: concurrencyService,
Identity: identityService,
Update: updateService,
TokenRefresh: tokenRefreshService,
}
repositories := &repository.Repositories{
User: userRepository,
@@ -187,6 +189,10 @@ func provideCleanup(
name string
fn func() error
}{
{"TokenRefreshService", func() error {
services.TokenRefresh.Stop()
return nil
}},
{"PricingService", func() error {
services.Pricing.Stop()
return nil

View File

@@ -8,15 +8,30 @@ import (
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
JWT JWTConfig `mapstructure:"jwt"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
Gateway GatewayConfig `mapstructure:"gateway"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
Enabled bool `mapstructure:"enabled"`
// 检查间隔(分钟)
CheckIntervalMinutes int `mapstructure:"check_interval_minutes"`
// 提前刷新时间小时在token过期前多久开始刷新
RefreshBeforeExpiryHours float64 `mapstructure:"refresh_before_expiry_hours"`
// 最大重试次数
MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
}
type PricingConfig struct {
@@ -192,6 +207,13 @@ func setDefaults() {
// Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
}
func (c *Config) Validate() error {

View File

@@ -13,7 +13,6 @@ import (
"log"
"net/http"
"regexp"
"strconv"
"strings"
"time"
@@ -34,7 +33,6 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
)
// allowedHeaders 白名单headers参考CRS项目
@@ -358,37 +356,10 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) {
accessToken := account.GetCredential("access_token")
expiresAtStr := account.GetCredential("expires_at")
// 检查是否需要刷新
needRefresh := false
if expiresAtStr != "" {
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err == nil && time.Now().Unix()+tokenRefreshBuffer > expiresAt {
needRefresh = true
}
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
}
if needRefresh || accessToken == "" {
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
return "", "", fmt.Errorf("refresh token failed: %w", err)
}
// 更新账号凭证
account.Credentials["access_token"] = tokenInfo.AccessToken
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
}
if err := s.accountRepo.Update(ctx, account); err != nil {
log.Printf("Failed to update account credentials: %v", err)
}
return tokenInfo.AccessToken, "oauth", nil
}
// Token刷新由后台 TokenRefreshService 处理此处只返回当前token
return accessToken, "oauth", nil
}
@@ -442,25 +413,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
}
defer resp.Body.Close()
// 处理401错误刷新token重试
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
resp.Body.Close()
token, tokenType, err = s.forceRefreshToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("token refresh failed: %w", err)
}
upstreamReq, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
if err != nil {
return nil, err
}
resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL)
if err != nil {
return nil, fmt.Errorf("retry request failed: %w", err)
}
defer resp.Body.Close()
}
// 处理错误响应
// 处理错误响应包括401由后台TokenRefreshService维护token有效性)
if resp.StatusCode >= 400 {
return s.handleErrorResponse(ctx, resp, c, account)
}
@@ -619,25 +572,6 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
return claude.DefaultBetaHeader
}
func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.Account) (string, string, error) {
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
return "", "", err
}
account.Credentials["access_token"] = tokenInfo.AccessToken
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
}
if err := s.accountRepo.Update(ctx, account); err != nil {
log.Printf("Failed to update account credentials: %v", err)
}
return tokenInfo.AccessToken, "oauth", nil
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
@@ -1053,26 +987,6 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
defer resp.Body.Close()
// 处理 401 错误:刷新 token 重试(仅 OAuth
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
resp.Body.Close()
token, tokenType, err = s.forceRefreshToken(ctx, account)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed")
return fmt.Errorf("token refresh failed: %w", err)
}
upstreamReq, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
if err != nil {
return err
}
resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed")
return fmt.Errorf("retry request failed: %w", err)
}
defer resp.Body.Close()
}
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
if err != nil {

View File

@@ -27,4 +27,5 @@ type Services struct {
Concurrency *ConcurrencyService
Identity *IdentityService
Update *UpdateService
TokenRefresh *TokenRefreshService
}

View File

@@ -0,0 +1,185 @@
package service
import (
"context"
"fmt"
"log"
"sync"
"time"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/service/ports"
)
// TokenRefreshService OAuth token自动刷新服务
// 定期检查并刷新即将过期的token
type TokenRefreshService struct {
accountRepo ports.AccountRepository
refreshers []TokenRefresher
cfg *config.TokenRefreshConfig
stopCh chan struct{}
wg sync.WaitGroup
}
// NewTokenRefreshService 创建token刷新服务
func NewTokenRefreshService(
accountRepo ports.AccountRepository,
oauthService *OAuthService,
cfg *config.Config,
) *TokenRefreshService {
s := &TokenRefreshService{
accountRepo: accountRepo,
cfg: &cfg.TokenRefresh,
stopCh: make(chan struct{}),
}
// 注册平台特定的刷新器
s.refreshers = []TokenRefresher{
NewClaudeTokenRefresher(oauthService),
// 未来可以添加其他平台的刷新器:
// NewOpenAITokenRefresher(...),
// NewGeminiTokenRefresher(...),
}
return s
}
// Start 启动后台刷新服务
func (s *TokenRefreshService) Start() {
if !s.cfg.Enabled {
log.Println("[TokenRefresh] Service disabled by configuration")
return
}
s.wg.Add(1)
go s.refreshLoop()
log.Printf("[TokenRefresh] Service started (check every %d minutes, refresh %v hours before expiry)",
s.cfg.CheckIntervalMinutes, s.cfg.RefreshBeforeExpiryHours)
}
// Stop 停止刷新服务
func (s *TokenRefreshService) Stop() {
close(s.stopCh)
s.wg.Wait()
log.Println("[TokenRefresh] Service stopped")
}
// refreshLoop 刷新循环
func (s *TokenRefreshService) refreshLoop() {
defer s.wg.Done()
// 计算检查间隔
checkInterval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute
if checkInterval < time.Minute {
checkInterval = 5 * time.Minute
}
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
// 启动时立即执行一次检查
s.processRefresh()
for {
select {
case <-ticker.C:
s.processRefresh()
case <-s.stopCh:
return
}
}
}
// processRefresh 执行一次刷新检查
func (s *TokenRefreshService) processRefresh() {
ctx := context.Background()
// 计算刷新窗口
refreshWindow := time.Duration(s.cfg.RefreshBeforeExpiryHours * float64(time.Hour))
// 获取所有active状态的账号
accounts, err := s.listActiveAccounts(ctx)
if err != nil {
log.Printf("[TokenRefresh] Failed to list accounts: %v", err)
return
}
refreshed, failed := 0, 0
for i := range accounts {
account := &accounts[i]
// 遍历所有刷新器,找到能处理此账号的
for _, refresher := range s.refreshers {
if !refresher.CanRefresh(account) {
continue
}
// 检查是否需要刷新
if !refresher.NeedsRefresh(account, refreshWindow) {
continue
}
// 执行刷新
if err := s.refreshWithRetry(ctx, account, refresher); err != nil {
log.Printf("[TokenRefresh] Account %d (%s) failed: %v", account.ID, account.Name, err)
failed++
} else {
log.Printf("[TokenRefresh] Account %d (%s) refreshed successfully", account.ID, account.Name)
refreshed++
}
// 每个账号只由一个refresher处理
break
}
}
if refreshed > 0 || failed > 0 {
log.Printf("[TokenRefresh] Cycle complete: %d refreshed, %d failed", refreshed, failed)
}
}
// listActiveAccounts 获取所有active状态的账号
// 使用ListActive确保刷新所有活跃账号的token包括临时禁用的
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]model.Account, error) {
return s.accountRepo.ListActive(ctx)
}
// refreshWithRetry 带重试的刷新
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *model.Account, refresher TokenRefresher) error {
var lastErr error
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
newCredentials, err := refresher.Refresh(ctx, account)
if err == nil {
// 刷新成功更新账号credentials
account.Credentials = model.JSONB(newCredentials)
if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
}
return nil
}
lastErr = err
log.Printf("[TokenRefresh] Account %d attempt %d/%d failed: %v",
account.ID, attempt, s.cfg.MaxRetries, err)
// 如果还有重试机会,等待后重试
if attempt < s.cfg.MaxRetries {
// 指数退避2^(attempt-1) * baseSeconds
backoff := time.Duration(s.cfg.RetryBackoffSeconds) * time.Second * time.Duration(1<<(attempt-1))
time.Sleep(backoff)
}
}
// 所有重试都失败标记账号为error状态
errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr)
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, err)
}
return lastErr
}

View File

@@ -0,0 +1,90 @@
package service
import (
"context"
"strconv"
"time"
"sub2api/internal/model"
)
// TokenRefresher 定义平台特定的token刷新策略接口
// 通过此接口可以扩展支持不同平台Anthropic/OpenAI/Gemini
type TokenRefresher interface {
// CanRefresh 检查此刷新器是否能处理指定账号
CanRefresh(account *model.Account) bool
// NeedsRefresh 检查账号的token是否需要刷新
NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool
// Refresh 执行token刷新返回更新后的credentials
// 注意返回的map应该保留原有credentials中的所有字段只更新token相关字段
Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error)
}
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
type ClaudeTokenRefresher struct {
oauthService *OAuthService
}
// NewClaudeTokenRefresher 创建Claude token刷新器
func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher {
return &ClaudeTokenRefresher{
oauthService: oauthService,
}
}
// CanRefresh 检查是否能处理此账号
// 只处理 anthropic 平台的 oauth 类型账号
// setup-token 虽然也是OAuth但有效期1年不需要频繁刷新
func (r *ClaudeTokenRefresher) CanRefresh(account *model.Account) bool {
return account.Platform == model.PlatformAnthropic &&
account.Type == model.AccountTypeOAuth
}
// NeedsRefresh 检查token是否需要刷新
// 基于 expires_at 字段判断是否在刷新窗口内
func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
expiresAtStr := account.GetCredential("expires_at")
if expiresAtStr == "" {
return false
}
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err != nil {
return false
}
expiryTime := time.Unix(expiresAt, 0)
return time.Until(expiryTime) < refreshWindow
}
// Refresh 执行token刷新
// 保留原有credentials中的所有字段只更新token相关字段
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error) {
tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
return nil, err
}
// 保留现有credentials中的所有字段
newCredentials := make(map[string]interface{})
for k, v := range account.Credentials {
newCredentials[k] = v
}
// 只更新token相关字段
// 注意expires_at 和 expires_in 必须存为字符串,因为 GetCredential 只返回 string 类型
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.Scope != "" {
newCredentials["scope"] = tokenInfo.Scope
}
return newCredentials, nil
}

View File

@@ -33,6 +33,17 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
return NewEmailQueueService(emailService, 3)
}
// ProvideTokenRefreshService creates and starts TokenRefreshService
func ProvideTokenRefreshService(
accountRepo ports.AccountRepository,
oauthService *OAuthService,
cfg *config.Config,
) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, cfg)
svc.Start()
return svc
}
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
@@ -61,6 +72,7 @@ var ProviderSet = wire.NewSet(
NewConcurrencyService,
NewIdentityService,
ProvideUpdateService,
ProvideTokenRefreshService,
// Provide the Services container struct
wire.Struct(new(Services), "*"),