From 2101f1d1c899520ed273042a9bcb8717395a7f1d Mon Sep 17 00:00:00 2001 From: shaw Date: Sat, 27 Dec 2025 13:50:35 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dclaude=20OAuth?= =?UTF-8?q?=E8=B4=A6=E6=88=B7=E5=88=B7=E6=96=B0token=E5=A4=B1=E8=B4=A5?= =?UTF-8?q?=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../repository/claude_oauth_service.go | 14 ++++-- .../repository/claude_oauth_service_test.go | 49 +++++++++++++++---- backend/internal/service/account.go | 28 ++++++++--- 3 files changed, 70 insertions(+), 21 deletions(-) diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index 005b1679..75699712 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -199,16 +199,20 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { client := s.clientFactory(proxyURL) - formData := url.Values{} - formData.Set("grant_type", "refresh_token") - formData.Set("refresh_token", refreshToken) - formData.Set("client_id", oauth.ClientID) + // 使用 JSON 格式(与 ExchangeCodeForToken 保持一致) + // Anthropic OAuth API 期望 JSON 格式的请求体 + reqBody := map[string]any{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": oauth.ClientID, + } var tokenResp oauth.TokenResponse resp, err := client.R(). SetContext(ctx). - SetFormDataFromValues(formData). + SetHeader("Content-Type", "application/json"). + SetBody(reqBody). SetSuccessResult(&tokenResp). Post(s.tokenURL) diff --git a/backend/internal/repository/claude_oauth_service_test.go b/backend/internal/repository/claude_oauth_service_test.go index 1d466f48..dd9c48b3 100644 --- a/backend/internal/repository/claude_oauth_service_test.go +++ b/backend/internal/repository/claude_oauth_service_test.go @@ -6,7 +6,6 @@ import ( "io" "net/http" "net/http/httptest" - "net/url" "strings" "testing" @@ -34,7 +33,6 @@ type requestCapture struct { method string cookies []*http.Cookie body []byte - formValues url.Values bodyJSON map[string]any contentType string } @@ -282,24 +280,53 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { validate func(captured requestCapture) }{ { - name: "sends_form", + name: "sends_json_format", handler: func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(oauth.TokenResponse{AccessToken: "at2", TokenType: "bearer", ExpiresIn: 3600}) + _ = json.NewEncoder(w).Encode(oauth.TokenResponse{ + AccessToken: "new_access_token", + TokenType: "bearer", + ExpiresIn: 28800, + RefreshToken: "new_refresh_token", + Scope: "user:profile user:inference", + }) + }, + wantResp: &oauth.TokenResponse{ + AccessToken: "new_access_token", + RefreshToken: "new_refresh_token", }, - wantResp: &oauth.TokenResponse{AccessToken: "at2"}, validate: func(captured requestCapture) { require.Equal(s.T(), http.MethodPost, captured.method, "expected POST") - require.Equal(s.T(), "refresh_token", captured.formValues.Get("grant_type")) - require.Equal(s.T(), "rt", captured.formValues.Get("refresh_token")) - require.Equal(s.T(), oauth.ClientID, captured.formValues.Get("client_id")) + // 验证使用 JSON 格式(不是 form 格式) + require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), + "expected JSON content-type, got: %s", captured.contentType) + // 验证 JSON body 内容 + require.Equal(s.T(), "refresh_token", captured.bodyJSON["grant_type"]) + require.Equal(s.T(), "rt", captured.bodyJSON["refresh_token"]) + require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"]) + }, + }, + { + name: "returns_new_refresh_token", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(oauth.TokenResponse{ + AccessToken: "at", + TokenType: "bearer", + ExpiresIn: 28800, + RefreshToken: "rotated_rt", // Anthropic rotates refresh tokens + }) + }, + wantResp: &oauth.TokenResponse{ + AccessToken: "at", + RefreshToken: "rotated_rt", }, }, { name: "non_200_returns_error", handler: func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write([]byte("unauthorized")) + _, _ = w.Write([]byte(`{"error":"invalid_grant"}`)) }, wantErr: true, }, @@ -311,8 +338,9 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { captured.method = r.Method + captured.contentType = r.Header.Get("Content-Type") captured.body, _ = io.ReadAll(r.Body) - captured.formValues, _ = url.ParseQuery(string(captured.body)) + _ = json.Unmarshal(captured.body, &captured.bodyJSON) tt.handler(w, r) })) defer s.srv.Close() @@ -331,6 +359,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { require.NoError(s.T(), err) require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken) + require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken) if tt.validate != nil { tt.validate(captured) } diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 51b7a4f1..f740cb90 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -1,6 +1,9 @@ package service -import "time" +import ( + "strconv" + "time" +) type Account struct { ID int64 @@ -82,12 +85,25 @@ func (a *Account) GetCredential(key string) string { if a.Credentials == nil { return "" } - if v, ok := a.Credentials[key]; ok { - if s, ok := v.(string); ok { - return s - } + v, ok := a.Credentials[key] + if !ok || v == nil { + return "" + } + + // 支持多种类型(兼容历史数据中 expires_at 等字段可能是数字或字符串) + switch val := v.(type) { + case string: + return val + case float64: + // JSON 解析后数字默认为 float64 + return strconv.FormatInt(int64(val), 10) + case int64: + return strconv.FormatInt(val, 10) + case int: + return strconv.Itoa(val) + default: + return "" } - return "" } func (a *Account) GetModelMapping() map[string]string {