fix: 修复Oauth账号自动刷新token失败的bug

This commit is contained in:
shaw
2025-12-20 13:01:58 +08:00
parent bb500b7b2a
commit adebd941e1
8 changed files with 333 additions and 99 deletions

View File

@@ -13,7 +13,6 @@ import (
"log"
"net/http"
"regexp"
"strconv"
"strings"
"time"
@@ -34,7 +33,6 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
tokenRefreshBuffer = 5 * 60 // 提前5分钟刷新token
)
// allowedHeaders 白名单headers参考CRS项目
@@ -358,37 +356,10 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) {
accessToken := account.GetCredential("access_token")
expiresAtStr := account.GetCredential("expires_at")
// 检查是否需要刷新
needRefresh := false
if expiresAtStr != "" {
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err == nil && time.Now().Unix()+tokenRefreshBuffer > expiresAt {
needRefresh = true
}
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
}
if needRefresh || accessToken == "" {
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
return "", "", fmt.Errorf("refresh token failed: %w", err)
}
// 更新账号凭证
account.Credentials["access_token"] = tokenInfo.AccessToken
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
}
if err := s.accountRepo.Update(ctx, account); err != nil {
log.Printf("Failed to update account credentials: %v", err)
}
return tokenInfo.AccessToken, "oauth", nil
}
// Token刷新由后台 TokenRefreshService 处理此处只返回当前token
return accessToken, "oauth", nil
}
@@ -442,25 +413,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
}
defer resp.Body.Close()
// 处理401错误刷新token重试
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
resp.Body.Close()
token, tokenType, err = s.forceRefreshToken(ctx, account)
if err != nil {
return nil, fmt.Errorf("token refresh failed: %w", err)
}
upstreamReq, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
if err != nil {
return nil, err
}
resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL)
if err != nil {
return nil, fmt.Errorf("retry request failed: %w", err)
}
defer resp.Body.Close()
}
// 处理错误响应
// 处理错误响应包括401由后台TokenRefreshService维护token有效性)
if resp.StatusCode >= 400 {
return s.handleErrorResponse(ctx, resp, c, account)
}
@@ -619,25 +572,6 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
return claude.DefaultBetaHeader
}
func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.Account) (string, string, error) {
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
return "", "", err
}
account.Credentials["access_token"] = tokenInfo.AccessToken
account.Credentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
if tokenInfo.RefreshToken != "" {
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
}
if err := s.accountRepo.Update(ctx, account); err != nil {
log.Printf("Failed to update account credentials: %v", err)
}
return tokenInfo.AccessToken, "oauth", nil
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
@@ -1053,26 +987,6 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
defer resp.Body.Close()
// 处理 401 错误:刷新 token 重试(仅 OAuth
if resp.StatusCode == http.StatusUnauthorized && tokenType == "oauth" {
resp.Body.Close()
token, tokenType, err = s.forceRefreshToken(ctx, account)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed")
return fmt.Errorf("token refresh failed: %w", err)
}
upstreamReq, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
if err != nil {
return err
}
resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed")
return fmt.Errorf("retry request failed: %w", err)
}
defer resp.Body.Close()
}
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
if err != nil {