From 0d5a8a95c890b910c175407bf1ca65941ff4bcc0 Mon Sep 17 00:00:00 2001 From: shaw Date: Sat, 27 Dec 2025 20:13:39 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dclaude=20token?= =?UTF-8?q?=E5=88=B7=E6=96=B0=E5=A4=B1=E6=95=88=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/service/token_refresher.go | 23 +- .../internal/service/token_refresher_test.go | 214 ++++++++++++++++++ 2 files changed, 228 insertions(+), 9 deletions(-) create mode 100644 backend/internal/service/token_refresher_test.go diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 8857a416..a43a525e 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -43,18 +43,23 @@ func (r *ClaudeTokenRefresher) CanRefresh(account *Account) bool { // NeedsRefresh 检查token是否需要刷新 // 基于 expires_at 字段判断是否在刷新窗口内 func (r *ClaudeTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { - expiresAtStr := account.GetCredential("expires_at") - if expiresAtStr == "" { + var expiresAt int64 + + // 方式1: 通过 GetCredential 获取(处理字符串和部分数字类型) + if s := account.GetCredential("expires_at"); s != "" { + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return false + } + expiresAt = v + } else if v, ok := account.Credentials["expires_at"].(float64); ok { + // 方式2: 直接获取 float64(处理某些 JSON 解码器将数字解析为 float64 的情况) + expiresAt = int64(v) + } else { return false } - expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) - if err != nil { - return false - } - - expiryTime := time.Unix(expiresAt, 0) - return time.Until(expiryTime) < refreshWindow + return time.Until(time.Unix(expiresAt, 0)) < refreshWindow } // Refresh 执行token刷新 diff --git a/backend/internal/service/token_refresher_test.go b/backend/internal/service/token_refresher_test.go new file mode 100644 index 00000000..c00fcfa3 --- /dev/null +++ b/backend/internal/service/token_refresher_test.go @@ -0,0 +1,214 @@ +//go:build unit + +package service + +import ( + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) { + refresher := &ClaudeTokenRefresher{} + refreshWindow := 30 * time.Minute + + tests := []struct { + name string + credentials map[string]any + wantRefresh bool + }{ + { + name: "expires_at as string - expired", + credentials: map[string]any{ + "expires_at": "1000", // 1970-01-01 00:16:40 UTC, 已过期 + }, + wantRefresh: true, + }, + { + name: "expires_at as float64 - expired", + credentials: map[string]any{ + "expires_at": float64(1000), // 数字类型,已过期 + }, + wantRefresh: true, + }, + { + name: "expires_at as string - far future", + credentials: map[string]any{ + "expires_at": "9999999999", // 远未来 + }, + wantRefresh: false, + }, + { + name: "expires_at as float64 - far future", + credentials: map[string]any{ + "expires_at": float64(9999999999), // 远未来,数字类型 + }, + wantRefresh: false, + }, + { + name: "expires_at missing", + credentials: map[string]any{}, + wantRefresh: false, + }, + { + name: "expires_at is nil", + credentials: map[string]any{ + "expires_at": nil, + }, + wantRefresh: false, + }, + { + name: "expires_at is invalid string", + credentials: map[string]any{ + "expires_at": "invalid", + }, + wantRefresh: false, + }, + { + name: "credentials is nil", + credentials: nil, + wantRefresh: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: tt.credentials, + } + + got := refresher.NeedsRefresh(account, refreshWindow) + require.Equal(t, tt.wantRefresh, got) + }) + } +} + +func TestClaudeTokenRefresher_NeedsRefresh_WithinWindow(t *testing.T) { + refresher := &ClaudeTokenRefresher{} + refreshWindow := 30 * time.Minute + + // 设置一个在刷新窗口内的时间(当前时间 + 15分钟) + expiresAt := time.Now().Add(15 * time.Minute).Unix() + + tests := []struct { + name string + credentials map[string]any + }{ + { + name: "string type - within refresh window", + credentials: map[string]any{ + "expires_at": strconv.FormatInt(expiresAt, 10), + }, + }, + { + name: "float64 type - within refresh window", + credentials: map[string]any{ + "expires_at": float64(expiresAt), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: tt.credentials, + } + + got := refresher.NeedsRefresh(account, refreshWindow) + require.True(t, got, "should need refresh when within window") + }) + } +} + +func TestClaudeTokenRefresher_NeedsRefresh_OutsideWindow(t *testing.T) { + refresher := &ClaudeTokenRefresher{} + refreshWindow := 30 * time.Minute + + // 设置一个在刷新窗口外的时间(当前时间 + 1小时) + expiresAt := time.Now().Add(1 * time.Hour).Unix() + + tests := []struct { + name string + credentials map[string]any + }{ + { + name: "string type - outside refresh window", + credentials: map[string]any{ + "expires_at": strconv.FormatInt(expiresAt, 10), + }, + }, + { + name: "float64 type - outside refresh window", + credentials: map[string]any{ + "expires_at": float64(expiresAt), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: tt.credentials, + } + + got := refresher.NeedsRefresh(account, refreshWindow) + require.False(t, got, "should not need refresh when outside window") + }) + } +} + +func TestClaudeTokenRefresher_CanRefresh(t *testing.T) { + refresher := &ClaudeTokenRefresher{} + + tests := []struct { + name string + platform string + accType string + want bool + }{ + { + name: "anthropic oauth - can refresh", + platform: PlatformAnthropic, + accType: AccountTypeOAuth, + want: true, + }, + { + name: "anthropic api-key - cannot refresh", + platform: PlatformAnthropic, + accType: AccountTypeApiKey, + want: false, + }, + { + name: "openai oauth - cannot refresh", + platform: PlatformOpenAI, + accType: AccountTypeOAuth, + want: false, + }, + { + name: "gemini oauth - cannot refresh", + platform: PlatformGemini, + accType: AccountTypeOAuth, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: tt.platform, + Type: tt.accType, + } + + got := refresher.CanRefresh(account) + require.Equal(t, tt.want, got) + }) + } +}