fix: 修复claude OAuth账户刷新token失败的bug
This commit is contained in:
@@ -199,16 +199,20 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
|||||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||||
client := s.clientFactory(proxyURL)
|
client := s.clientFactory(proxyURL)
|
||||||
|
|
||||||
formData := url.Values{}
|
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
|
||||||
formData.Set("grant_type", "refresh_token")
|
// Anthropic OAuth API 期望 JSON 格式的请求体
|
||||||
formData.Set("refresh_token", refreshToken)
|
reqBody := map[string]any{
|
||||||
formData.Set("client_id", oauth.ClientID)
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": refreshToken,
|
||||||
|
"client_id": oauth.ClientID,
|
||||||
|
}
|
||||||
|
|
||||||
var tokenResp oauth.TokenResponse
|
var tokenResp oauth.TokenResponse
|
||||||
|
|
||||||
resp, err := client.R().
|
resp, err := client.R().
|
||||||
SetContext(ctx).
|
SetContext(ctx).
|
||||||
SetFormDataFromValues(formData).
|
SetHeader("Content-Type", "application/json").
|
||||||
|
SetBody(reqBody).
|
||||||
SetSuccessResult(&tokenResp).
|
SetSuccessResult(&tokenResp).
|
||||||
Post(s.tokenURL)
|
Post(s.tokenURL)
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -34,7 +33,6 @@ type requestCapture struct {
|
|||||||
method string
|
method string
|
||||||
cookies []*http.Cookie
|
cookies []*http.Cookie
|
||||||
body []byte
|
body []byte
|
||||||
formValues url.Values
|
|
||||||
bodyJSON map[string]any
|
bodyJSON map[string]any
|
||||||
contentType string
|
contentType string
|
||||||
}
|
}
|
||||||
@@ -282,24 +280,53 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
|||||||
validate func(captured requestCapture)
|
validate func(captured requestCapture)
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "sends_form",
|
name: "sends_json_format",
|
||||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{AccessToken: "at2", TokenType: "bearer", ExpiresIn: 3600})
|
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||||
|
AccessToken: "new_access_token",
|
||||||
|
TokenType: "bearer",
|
||||||
|
ExpiresIn: 28800,
|
||||||
|
RefreshToken: "new_refresh_token",
|
||||||
|
Scope: "user:profile user:inference",
|
||||||
|
})
|
||||||
|
},
|
||||||
|
wantResp: &oauth.TokenResponse{
|
||||||
|
AccessToken: "new_access_token",
|
||||||
|
RefreshToken: "new_refresh_token",
|
||||||
},
|
},
|
||||||
wantResp: &oauth.TokenResponse{AccessToken: "at2"},
|
|
||||||
validate: func(captured requestCapture) {
|
validate: func(captured requestCapture) {
|
||||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||||
require.Equal(s.T(), "refresh_token", captured.formValues.Get("grant_type"))
|
// 验证使用 JSON 格式(不是 form 格式)
|
||||||
require.Equal(s.T(), "rt", captured.formValues.Get("refresh_token"))
|
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"),
|
||||||
require.Equal(s.T(), oauth.ClientID, captured.formValues.Get("client_id"))
|
"expected JSON content-type, got: %s", captured.contentType)
|
||||||
|
// 验证 JSON body 内容
|
||||||
|
require.Equal(s.T(), "refresh_token", captured.bodyJSON["grant_type"])
|
||||||
|
require.Equal(s.T(), "rt", captured.bodyJSON["refresh_token"])
|
||||||
|
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns_new_refresh_token",
|
||||||
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||||
|
AccessToken: "at",
|
||||||
|
TokenType: "bearer",
|
||||||
|
ExpiresIn: 28800,
|
||||||
|
RefreshToken: "rotated_rt", // Anthropic rotates refresh tokens
|
||||||
|
})
|
||||||
|
},
|
||||||
|
wantResp: &oauth.TokenResponse{
|
||||||
|
AccessToken: "at",
|
||||||
|
RefreshToken: "rotated_rt",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "non_200_returns_error",
|
name: "non_200_returns_error",
|
||||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
_, _ = w.Write([]byte("unauthorized"))
|
_, _ = w.Write([]byte(`{"error":"invalid_grant"}`))
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
@@ -311,8 +338,9 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
|||||||
|
|
||||||
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
captured.method = r.Method
|
captured.method = r.Method
|
||||||
|
captured.contentType = r.Header.Get("Content-Type")
|
||||||
captured.body, _ = io.ReadAll(r.Body)
|
captured.body, _ = io.ReadAll(r.Body)
|
||||||
captured.formValues, _ = url.ParseQuery(string(captured.body))
|
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||||
tt.handler(w, r)
|
tt.handler(w, r)
|
||||||
}))
|
}))
|
||||||
defer s.srv.Close()
|
defer s.srv.Close()
|
||||||
@@ -331,6 +359,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
|||||||
|
|
||||||
require.NoError(s.T(), err)
|
require.NoError(s.T(), err)
|
||||||
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||||
|
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
|
||||||
if tt.validate != nil {
|
if tt.validate != nil {
|
||||||
tt.validate(captured)
|
tt.validate(captured)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
type Account struct {
|
type Account struct {
|
||||||
ID int64
|
ID int64
|
||||||
@@ -82,12 +85,25 @@ func (a *Account) GetCredential(key string) string {
|
|||||||
if a.Credentials == nil {
|
if a.Credentials == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
if v, ok := a.Credentials[key]; ok {
|
v, ok := a.Credentials[key]
|
||||||
if s, ok := v.(string); ok {
|
if !ok || v == nil {
|
||||||
return s
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 支持多种类型(兼容历史数据中 expires_at 等字段可能是数字或字符串)
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
return val
|
||||||
|
case float64:
|
||||||
|
// JSON 解析后数字默认为 float64
|
||||||
|
return strconv.FormatInt(int64(val), 10)
|
||||||
|
case int64:
|
||||||
|
return strconv.FormatInt(val, 10)
|
||||||
|
case int:
|
||||||
|
return strconv.Itoa(val)
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) GetModelMapping() map[string]string {
|
func (a *Account) GetModelMapping() map[string]string {
|
||||||
|
|||||||
Reference in New Issue
Block a user