fix: 修复Oauth账号自动刷新token失败的bug
This commit is contained in:
@@ -73,6 +73,10 @@ func provideCleanup(
|
|||||||
name string
|
name string
|
||||||
fn func() error
|
fn func() error
|
||||||
}{
|
}{
|
||||||
|
{"TokenRefreshService", func() error {
|
||||||
|
services.TokenRefresh.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"PricingService", func() error {
|
{"PricingService", func() error {
|
||||||
services.Pricing.Stop()
|
services.Pricing.Stop()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
|
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
|
||||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||||
|
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, configConfig)
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||||
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
|
||||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler)
|
||||||
@@ -138,6 +139,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
Concurrency: concurrencyService,
|
Concurrency: concurrencyService,
|
||||||
Identity: identityService,
|
Identity: identityService,
|
||||||
Update: updateService,
|
Update: updateService,
|
||||||
|
TokenRefresh: tokenRefreshService,
|
||||||
}
|
}
|
||||||
repositories := &repository.Repositories{
|
repositories := &repository.Repositories{
|
||||||
User: userRepository,
|
User: userRepository,
|
||||||
@@ -187,6 +189,10 @@ func provideCleanup(
|
|||||||
name string
|
name string
|
||||||
fn func() error
|
fn func() error
|
||||||
}{
|
}{
|
||||||
|
{"TokenRefreshService", func() error {
|
||||||
|
services.TokenRefresh.Stop()
|
||||||
|
return nil
|
||||||
|
}},
|
||||||
{"PricingService", func() error {
|
{"PricingService", func() error {
|
||||||
services.Pricing.Stop()
|
services.Pricing.Stop()
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -8,15 +8,30 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `mapstructure:"server"`
|
Server ServerConfig `mapstructure:"server"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
Redis RedisConfig `mapstructure:"redis"`
|
Redis RedisConfig `mapstructure:"redis"`
|
||||||
JWT JWTConfig `mapstructure:"jwt"`
|
JWT JWTConfig `mapstructure:"jwt"`
|
||||||
Default DefaultConfig `mapstructure:"default"`
|
Default DefaultConfig `mapstructure:"default"`
|
||||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||||
Pricing PricingConfig `mapstructure:"pricing"`
|
Pricing PricingConfig `mapstructure:"pricing"`
|
||||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||||
|
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenRefreshConfig OAuth token自动刷新配置
|
||||||
|
type TokenRefreshConfig struct {
|
||||||
|
// 是否启用自动刷新
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
// 检查间隔(分钟)
|
||||||
|
CheckIntervalMinutes int `mapstructure:"check_interval_minutes"`
|
||||||
|
// 提前刷新时间(小时),在token过期前多久开始刷新
|
||||||
|
RefreshBeforeExpiryHours float64 `mapstructure:"refresh_before_expiry_hours"`
|
||||||
|
// 最大重试次数
|
||||||
|
MaxRetries int `mapstructure:"max_retries"`
|
||||||
|
// 重试退避基础时间(秒)
|
||||||
|
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PricingConfig struct {
|
type PricingConfig struct {
|
||||||
@@ -192,6 +207,13 @@ func setDefaults() {
|
|||||||
|
|
||||||
// Gateway
|
// Gateway
|
||||||
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
|
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久
|
||||||
|
|
||||||
|
// TokenRefresh
|
||||||
|
viper.SetDefault("token_refresh.enabled", true)
|
||||||
|
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||||
|
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
|
||||||
|
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||||
|
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) Validate() error {
|
func (c *Config) Validate() error {
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -34,7 +33,6 @@ const (
|
|||||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||||
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// allowedHeaders 白名单headers(参考CRS项目)
|
// allowedHeaders 白名单headers(参考CRS项目)
|
||||||
@@ -358,37 +356,10 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco
|
|||||||
|
|
||||||
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) {
|
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||||
accessToken := account.GetCredential("access_token")
|
accessToken := account.GetCredential("access_token")
|
||||||
expiresAtStr := account.GetCredential("expires_at")
|
if accessToken == "" {
|
||||||
|
return "", "", errors.New("access_token not found in credentials")
|
||||||
// 检查是否需要刷新
|
|
||||||
needRefresh := false
|
|
||||||
if expiresAtStr != "" {
|
|
||||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
|
||||||
if err == nil && time.Now().Unix()+tokenRefreshBuffer > expiresAt {
|
|
||||||
needRefresh = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
// Token刷新由后台 TokenRefreshService 处理,此处只返回当前token
|
||||||
if needRefresh || accessToken == "" {
|
|
||||||
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", fmt.Errorf("refresh token failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新账号凭证
|
|
||||||
account.Credentials["access_token"] = tokenInfo.AccessToken
|
|
||||||
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
|
||||||
if tokenInfo.RefreshToken != "" {
|
|
||||||
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
|
||||||
log.Printf("Failed to update account credentials: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenInfo.AccessToken, "oauth", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return accessToken, "oauth", nil
|
return accessToken, "oauth", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -442,25 +413,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// 处理401错误:刷新token重试
|
// 处理错误响应(包括401,由后台TokenRefreshService维护token有效性)
|
||||||
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
|
|
||||||
resp.Body.Close()
|
|
||||||
token, tokenType, err = s.forceRefreshToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("token refresh failed: %w", err)
|
|
||||||
}
|
|
||||||
upstreamReq, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("retry request failed: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 处理错误响应
|
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
return s.handleErrorResponse(ctx, resp, c, account)
|
return s.handleErrorResponse(ctx, resp, c, account)
|
||||||
}
|
}
|
||||||
@@ -619,25 +572,6 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
|||||||
return claude.DefaultBetaHeader
|
return claude.DefaultBetaHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.Account) (string, string, error) {
|
|
||||||
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Credentials["access_token"] = tokenInfo.AccessToken
|
|
||||||
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
|
||||||
if tokenInfo.RefreshToken != "" {
|
|
||||||
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
|
||||||
log.Printf("Failed to update account credentials: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenInfo.AccessToken, "oauth", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
|
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
@@ -1053,26 +987,6 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// 处理 401 错误:刷新 token 重试(仅 OAuth)
|
|
||||||
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
|
|
||||||
resp.Body.Close()
|
|
||||||
token, tokenType, err = s.forceRefreshToken(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed")
|
|
||||||
return fmt.Errorf("token refresh failed: %w", err)
|
|
||||||
}
|
|
||||||
upstreamReq, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL)
|
|
||||||
if err != nil {
|
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed")
|
|
||||||
return fmt.Errorf("retry request failed: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 读取响应体
|
// 读取响应体
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -27,4 +27,5 @@ type Services struct {
|
|||||||
Concurrency *ConcurrencyService
|
Concurrency *ConcurrencyService
|
||||||
Identity *IdentityService
|
Identity *IdentityService
|
||||||
Update *UpdateService
|
Update *UpdateService
|
||||||
|
TokenRefresh *TokenRefreshService
|
||||||
}
|
}
|
||||||
|
|||||||
185
backend/internal/service/token_refresh_service.go
Normal file
185
backend/internal/service/token_refresh_service.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/config"
|
||||||
|
"sub2api/internal/model"
|
||||||
|
"sub2api/internal/service/ports"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenRefreshService OAuth token自动刷新服务
|
||||||
|
// 定期检查并刷新即将过期的token
|
||||||
|
type TokenRefreshService struct {
|
||||||
|
accountRepo ports.AccountRepository
|
||||||
|
refreshers []TokenRefresher
|
||||||
|
cfg *config.TokenRefreshConfig
|
||||||
|
|
||||||
|
stopCh chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTokenRefreshService 创建token刷新服务
|
||||||
|
func NewTokenRefreshService(
|
||||||
|
accountRepo ports.AccountRepository,
|
||||||
|
oauthService *OAuthService,
|
||||||
|
cfg *config.Config,
|
||||||
|
) *TokenRefreshService {
|
||||||
|
s := &TokenRefreshService{
|
||||||
|
accountRepo: accountRepo,
|
||||||
|
cfg: &cfg.TokenRefresh,
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注册平台特定的刷新器
|
||||||
|
s.refreshers = []TokenRefresher{
|
||||||
|
NewClaudeTokenRefresher(oauthService),
|
||||||
|
// 未来可以添加其他平台的刷新器:
|
||||||
|
// NewOpenAITokenRefresher(...),
|
||||||
|
// NewGeminiTokenRefresher(...),
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start 启动后台刷新服务
|
||||||
|
func (s *TokenRefreshService) Start() {
|
||||||
|
if !s.cfg.Enabled {
|
||||||
|
log.Println("[TokenRefresh] Service disabled by configuration")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.wg.Add(1)
|
||||||
|
go s.refreshLoop()
|
||||||
|
|
||||||
|
log.Printf("[TokenRefresh] Service started (check every %d minutes, refresh %v hours before expiry)",
|
||||||
|
s.cfg.CheckIntervalMinutes, s.cfg.RefreshBeforeExpiryHours)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop 停止刷新服务
|
||||||
|
func (s *TokenRefreshService) Stop() {
|
||||||
|
close(s.stopCh)
|
||||||
|
s.wg.Wait()
|
||||||
|
log.Println("[TokenRefresh] Service stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// refreshLoop 刷新循环
|
||||||
|
func (s *TokenRefreshService) refreshLoop() {
|
||||||
|
defer s.wg.Done()
|
||||||
|
|
||||||
|
// 计算检查间隔
|
||||||
|
checkInterval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute
|
||||||
|
if checkInterval < time.Minute {
|
||||||
|
checkInterval = 5 * time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(checkInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
// 启动时立即执行一次检查
|
||||||
|
s.processRefresh()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
s.processRefresh()
|
||||||
|
case <-s.stopCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processRefresh 执行一次刷新检查
|
||||||
|
func (s *TokenRefreshService) processRefresh() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// 计算刷新窗口
|
||||||
|
refreshWindow := time.Duration(s.cfg.RefreshBeforeExpiryHours * float64(time.Hour))
|
||||||
|
|
||||||
|
// 获取所有active状态的账号
|
||||||
|
accounts, err := s.listActiveAccounts(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[TokenRefresh] Failed to list accounts: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshed, failed := 0, 0
|
||||||
|
|
||||||
|
for i := range accounts {
|
||||||
|
account := &accounts[i]
|
||||||
|
|
||||||
|
// 遍历所有刷新器,找到能处理此账号的
|
||||||
|
for _, refresher := range s.refreshers {
|
||||||
|
if !refresher.CanRefresh(account) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否需要刷新
|
||||||
|
if !refresher.NeedsRefresh(account, refreshWindow) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行刷新
|
||||||
|
if err := s.refreshWithRetry(ctx, account, refresher); err != nil {
|
||||||
|
log.Printf("[TokenRefresh] Account %d (%s) failed: %v", account.ID, account.Name, err)
|
||||||
|
failed++
|
||||||
|
} else {
|
||||||
|
log.Printf("[TokenRefresh] Account %d (%s) refreshed successfully", account.ID, account.Name)
|
||||||
|
refreshed++
|
||||||
|
}
|
||||||
|
|
||||||
|
// 每个账号只由一个refresher处理
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if refreshed > 0 || failed > 0 {
|
||||||
|
log.Printf("[TokenRefresh] Cycle complete: %d refreshed, %d failed", refreshed, failed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// listActiveAccounts 获取所有active状态的账号
|
||||||
|
// 使用ListActive确保刷新所有活跃账号的token(包括临时禁用的)
|
||||||
|
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]model.Account, error) {
|
||||||
|
return s.accountRepo.ListActive(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// refreshWithRetry 带重试的刷新
|
||||||
|
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *model.Account, refresher TokenRefresher) error {
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
||||||
|
newCredentials, err := refresher.Refresh(ctx, account)
|
||||||
|
if err == nil {
|
||||||
|
// 刷新成功,更新账号credentials
|
||||||
|
account.Credentials = model.JSONB(newCredentials)
|
||||||
|
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||||
|
return fmt.Errorf("failed to save credentials: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lastErr = err
|
||||||
|
log.Printf("[TokenRefresh] Account %d attempt %d/%d failed: %v",
|
||||||
|
account.ID, attempt, s.cfg.MaxRetries, err)
|
||||||
|
|
||||||
|
// 如果还有重试机会,等待后重试
|
||||||
|
if attempt < s.cfg.MaxRetries {
|
||||||
|
// 指数退避:2^(attempt-1) * baseSeconds
|
||||||
|
backoff := time.Duration(s.cfg.RetryBackoffSeconds) * time.Second * time.Duration(1<<(attempt-1))
|
||||||
|
time.Sleep(backoff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 所有重试都失败,标记账号为error状态
|
||||||
|
errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr)
|
||||||
|
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||||
|
log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
90
backend/internal/service/token_refresher.go
Normal file
90
backend/internal/service/token_refresher.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenRefresher 定义平台特定的token刷新策略接口
|
||||||
|
// 通过此接口可以扩展支持不同平台(Anthropic/OpenAI/Gemini)
|
||||||
|
type TokenRefresher interface {
|
||||||
|
// CanRefresh 检查此刷新器是否能处理指定账号
|
||||||
|
CanRefresh(account *model.Account) bool
|
||||||
|
|
||||||
|
// NeedsRefresh 检查账号的token是否需要刷新
|
||||||
|
NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool
|
||||||
|
|
||||||
|
// Refresh 执行token刷新,返回更新后的credentials
|
||||||
|
// 注意:返回的map应该保留原有credentials中的所有字段,只更新token相关字段
|
||||||
|
Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
|
||||||
|
type ClaudeTokenRefresher struct {
|
||||||
|
oauthService *OAuthService
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClaudeTokenRefresher 创建Claude token刷新器
|
||||||
|
func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher {
|
||||||
|
return &ClaudeTokenRefresher{
|
||||||
|
oauthService: oauthService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanRefresh 检查是否能处理此账号
|
||||||
|
// 只处理 anthropic 平台的 oauth 类型账号
|
||||||
|
// setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新
|
||||||
|
func (r *ClaudeTokenRefresher) CanRefresh(account *model.Account) bool {
|
||||||
|
return account.Platform == model.PlatformAnthropic &&
|
||||||
|
account.Type == model.AccountTypeOAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedsRefresh 检查token是否需要刷新
|
||||||
|
// 基于 expires_at 字段判断是否在刷新窗口内
|
||||||
|
func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
|
||||||
|
expiresAtStr := account.GetCredential("expires_at")
|
||||||
|
if expiresAtStr == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
expiryTime := time.Unix(expiresAt, 0)
|
||||||
|
return time.Until(expiryTime) < refreshWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh 执行token刷新
|
||||||
|
// 保留原有credentials中的所有字段,只更新token相关字段
|
||||||
|
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]interface{}, error) {
|
||||||
|
tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保留现有credentials中的所有字段
|
||||||
|
newCredentials := make(map[string]interface{})
|
||||||
|
for k, v := range account.Credentials {
|
||||||
|
newCredentials[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// 只更新token相关字段
|
||||||
|
// 注意:expires_at 和 expires_in 必须存为字符串,因为 GetCredential 只返回 string 类型
|
||||||
|
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||||
|
newCredentials["token_type"] = tokenInfo.TokenType
|
||||||
|
newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
|
||||||
|
newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
|
||||||
|
if tokenInfo.RefreshToken != "" {
|
||||||
|
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||||
|
}
|
||||||
|
if tokenInfo.Scope != "" {
|
||||||
|
newCredentials["scope"] = tokenInfo.Scope
|
||||||
|
}
|
||||||
|
|
||||||
|
return newCredentials, nil
|
||||||
|
}
|
||||||
@@ -33,6 +33,17 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
|||||||
return NewEmailQueueService(emailService, 3)
|
return NewEmailQueueService(emailService, 3)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideTokenRefreshService creates and starts TokenRefreshService
|
||||||
|
func ProvideTokenRefreshService(
|
||||||
|
accountRepo ports.AccountRepository,
|
||||||
|
oauthService *OAuthService,
|
||||||
|
cfg *config.Config,
|
||||||
|
) *TokenRefreshService {
|
||||||
|
svc := NewTokenRefreshService(accountRepo, oauthService, cfg)
|
||||||
|
svc.Start()
|
||||||
|
return svc
|
||||||
|
}
|
||||||
|
|
||||||
// ProviderSet is the Wire provider set for all services
|
// ProviderSet is the Wire provider set for all services
|
||||||
var ProviderSet = wire.NewSet(
|
var ProviderSet = wire.NewSet(
|
||||||
// Core services
|
// Core services
|
||||||
@@ -61,6 +72,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewConcurrencyService,
|
NewConcurrencyService,
|
||||||
NewIdentityService,
|
NewIdentityService,
|
||||||
ProvideUpdateService,
|
ProvideUpdateService,
|
||||||
|
ProvideTokenRefreshService,
|
||||||
|
|
||||||
// Provide the Services container struct
|
// Provide the Services container struct
|
||||||
wire.Struct(new(Services), "*"),
|
wire.Struct(new(Services), "*"),
|
||||||
|
|||||||
Reference in New Issue
Block a user