diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index 75699712..f7ff2341 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -145,7 +145,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe return fullCode, nil } -func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) { +func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { client := s.clientFactory(proxyURL) // Parse code which may contain state in format "authCode#state" @@ -168,6 +168,11 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod reqBody["state"] = codeState } + // Setup token requires longer expiration (1 year) + if isSetupToken { + reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds + } + reqBodyJSON, _ := json.Marshal(reqBody) log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) diff --git a/backend/internal/repository/claude_oauth_service_test.go b/backend/internal/repository/claude_oauth_service_test.go index dd9c48b3..3295c222 100644 --- a/backend/internal/repository/claude_oauth_service_test.go +++ b/backend/internal/repository/claude_oauth_service_test.go @@ -191,12 +191,13 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() { func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { tests := []struct { - name string - handler http.HandlerFunc - code string - wantErr bool - wantResp *oauth.TokenResponse - validate func(captured requestCapture) + name string + handler http.HandlerFunc + code string + isSetupToken bool + wantErr bool + wantResp *oauth.TokenResponse + validate func(captured requestCapture) }{ { name: "sends_state_when_embedded", @@ -210,7 +211,8 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { Scope: "s", }) }, - code: "AUTH#STATE2", + code: "AUTH#STATE2", + isSetupToken: false, wantResp: &oauth.TokenResponse{ AccessToken: "at", RefreshToken: "rt", @@ -223,6 +225,29 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"]) require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"]) require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"]) + // Regular OAuth should not include expires_in + require.Nil(s.T(), captured.bodyJSON["expires_in"], "regular OAuth should not include expires_in") + }, + }, + { + name: "setup_token_includes_expires_in", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(oauth.TokenResponse{ + AccessToken: "at", + TokenType: "bearer", + ExpiresIn: 31536000, + }) + }, + code: "AUTH", + isSetupToken: true, + wantResp: &oauth.TokenResponse{ + AccessToken: "at", + }, + validate: func(captured requestCapture) { + // Setup token should include expires_in with 1 year value + require.Equal(s.T(), float64(31536000), captured.bodyJSON["expires_in"], + "setup token should include expires_in: 31536000") }, }, { @@ -231,8 +256,9 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { w.WriteHeader(http.StatusBadRequest) _, _ = w.Write([]byte("bad request")) }, - code: "AUTH", - wantErr: true, + code: "AUTH", + isSetupToken: false, + wantErr: true, }, } @@ -254,7 +280,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { s.client = client s.client.tokenURL = s.srv.URL - resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "") + resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken) if tt.wantErr { require.Error(s.T(), err) diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index f4c149ac..0039cb44 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -20,7 +20,7 @@ type OpenAIOAuthClient interface { type ClaudeOAuthClient interface { GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) - ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) + ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) } @@ -142,8 +142,11 @@ func (s *OAuthService) ExchangeCode(ctx context.Context, input *ExchangeCodeInpu } } + // Determine if this is a setup token (scope is inference only) + isSetupToken := session.Scope == oauth.ScopeInference + // Exchange code for token - tokenInfo, err := s.exchangeCodeForToken(ctx, input.Code, session.CodeVerifier, session.State, proxyURL) + tokenInfo, err := s.exchangeCodeForToken(ctx, input.Code, session.CodeVerifier, session.State, proxyURL, isSetupToken) if err != nil { return nil, err } @@ -172,10 +175,12 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) ( } } - // Determine scope + // Determine scope and if this is a setup token scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference) + isSetupToken := false if input.Scope == "inference" { scope = oauth.ScopeInference + isSetupToken = true } // Step 1: Get organization info using sessionKey @@ -203,7 +208,7 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) ( } // Step 4: Exchange code for token - tokenInfo, err := s.exchangeCodeForToken(ctx, authCode, codeVerifier, state, proxyURL) + tokenInfo, err := s.exchangeCodeForToken(ctx, authCode, codeVerifier, state, proxyURL, isSetupToken) if err != nil { return nil, fmt.Errorf("failed to exchange code: %w", err) } @@ -228,8 +233,8 @@ func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, org } // exchangeCodeForToken exchanges authorization code for tokens -func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) { - tokenResp, err := s.oauthClient.ExchangeCodeForToken(ctx, code, codeVerifier, state, proxyURL) +func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*TokenInfo, error) { + tokenResp, err := s.oauthClient.ExchangeCodeForToken(ctx, code, codeVerifier, state, proxyURL, isSetupToken) if err != nil { return nil, err } diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 8857a416..a43a525e 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -43,18 +43,23 @@ func (r *ClaudeTokenRefresher) CanRefresh(account *Account) bool { // NeedsRefresh 检查token是否需要刷新 // 基于 expires_at 字段判断是否在刷新窗口内 func (r *ClaudeTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { - expiresAtStr := account.GetCredential("expires_at") - if expiresAtStr == "" { + var expiresAt int64 + + // 方式1: 通过 GetCredential 获取(处理字符串和部分数字类型) + if s := account.GetCredential("expires_at"); s != "" { + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return false + } + expiresAt = v + } else if v, ok := account.Credentials["expires_at"].(float64); ok { + // 方式2: 直接获取 float64(处理某些 JSON 解码器将数字解析为 float64 的情况) + expiresAt = int64(v) + } else { return false } - expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) - if err != nil { - return false - } - - expiryTime := time.Unix(expiresAt, 0) - return time.Until(expiryTime) < refreshWindow + return time.Until(time.Unix(expiresAt, 0)) < refreshWindow } // Refresh 执行token刷新 diff --git a/backend/internal/service/token_refresher_test.go b/backend/internal/service/token_refresher_test.go new file mode 100644 index 00000000..c00fcfa3 --- /dev/null +++ b/backend/internal/service/token_refresher_test.go @@ -0,0 +1,214 @@ +//go:build unit + +package service + +import ( + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) { + refresher := &ClaudeTokenRefresher{} + refreshWindow := 30 * time.Minute + + tests := []struct { + name string + credentials map[string]any + wantRefresh bool + }{ + { + name: "expires_at as string - expired", + credentials: map[string]any{ + "expires_at": "1000", // 1970-01-01 00:16:40 UTC, 已过期 + }, + wantRefresh: true, + }, + { + name: "expires_at as float64 - expired", + credentials: map[string]any{ + "expires_at": float64(1000), // 数字类型,已过期 + }, + wantRefresh: true, + }, + { + name: "expires_at as string - far future", + credentials: map[string]any{ + "expires_at": "9999999999", // 远未来 + }, + wantRefresh: false, + }, + { + name: "expires_at as float64 - far future", + credentials: map[string]any{ + "expires_at": float64(9999999999), // 远未来,数字类型 + }, + wantRefresh: false, + }, + { + name: "expires_at missing", + credentials: map[string]any{}, + wantRefresh: false, + }, + { + name: "expires_at is nil", + credentials: map[string]any{ + "expires_at": nil, + }, + wantRefresh: false, + }, + { + name: "expires_at is invalid string", + credentials: map[string]any{ + "expires_at": "invalid", + }, + wantRefresh: false, + }, + { + name: "credentials is nil", + credentials: nil, + wantRefresh: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: tt.credentials, + } + + got := refresher.NeedsRefresh(account, refreshWindow) + require.Equal(t, tt.wantRefresh, got) + }) + } +} + +func TestClaudeTokenRefresher_NeedsRefresh_WithinWindow(t *testing.T) { + refresher := &ClaudeTokenRefresher{} + refreshWindow := 30 * time.Minute + + // 设置一个在刷新窗口内的时间(当前时间 + 15分钟) + expiresAt := time.Now().Add(15 * time.Minute).Unix() + + tests := []struct { + name string + credentials map[string]any + }{ + { + name: "string type - within refresh window", + credentials: map[string]any{ + "expires_at": strconv.FormatInt(expiresAt, 10), + }, + }, + { + name: "float64 type - within refresh window", + credentials: map[string]any{ + "expires_at": float64(expiresAt), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: tt.credentials, + } + + got := refresher.NeedsRefresh(account, refreshWindow) + require.True(t, got, "should need refresh when within window") + }) + } +} + +func TestClaudeTokenRefresher_NeedsRefresh_OutsideWindow(t *testing.T) { + refresher := &ClaudeTokenRefresher{} + refreshWindow := 30 * time.Minute + + // 设置一个在刷新窗口外的时间(当前时间 + 1小时) + expiresAt := time.Now().Add(1 * time.Hour).Unix() + + tests := []struct { + name string + credentials map[string]any + }{ + { + name: "string type - outside refresh window", + credentials: map[string]any{ + "expires_at": strconv.FormatInt(expiresAt, 10), + }, + }, + { + name: "float64 type - outside refresh window", + credentials: map[string]any{ + "expires_at": float64(expiresAt), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: tt.credentials, + } + + got := refresher.NeedsRefresh(account, refreshWindow) + require.False(t, got, "should not need refresh when outside window") + }) + } +} + +func TestClaudeTokenRefresher_CanRefresh(t *testing.T) { + refresher := &ClaudeTokenRefresher{} + + tests := []struct { + name string + platform string + accType string + want bool + }{ + { + name: "anthropic oauth - can refresh", + platform: PlatformAnthropic, + accType: AccountTypeOAuth, + want: true, + }, + { + name: "anthropic api-key - cannot refresh", + platform: PlatformAnthropic, + accType: AccountTypeApiKey, + want: false, + }, + { + name: "openai oauth - cannot refresh", + platform: PlatformOpenAI, + accType: AccountTypeOAuth, + want: false, + }, + { + name: "gemini oauth - cannot refresh", + platform: PlatformGemini, + accType: AccountTypeOAuth, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: tt.platform, + Type: tt.accType, + } + + got := refresher.CanRefresh(account) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/deploy/README.md b/deploy/README.md index 5b127fc1..86f88f19 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -281,6 +281,30 @@ To change after installation: sudo systemctl restart sub2api ``` +#### Gemini OAuth Configuration + +If you need to use AI Studio OAuth for Gemini accounts, add the OAuth client credentials to the systemd service file: + +1. Edit the service file: + ```bash + sudo nano /etc/systemd/system/sub2api.service + ``` + +2. Add your OAuth credentials in the `[Service]` section (after the existing `Environment=` lines): + ```ini + Environment=GEMINI_OAUTH_CLIENT_ID=your-client-id.apps.googleusercontent.com + Environment=GEMINI_OAUTH_CLIENT_SECRET=GOCSPX-your-client-secret + ``` + +3. Reload and restart: + ```bash + sudo systemctl daemon-reload + sudo systemctl restart sub2api + ``` + +> **Note:** Code Assist OAuth does not require any configuration - it uses the built-in Gemini CLI client. +> See the [Gemini OAuth Configuration](#gemini-oauth-configuration) section above for detailed setup instructions. + #### Application Configuration The main config file is at `/etc/sub2api/config.yaml` (created by Setup Wizard). diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 4e2de67f..9e10ec54 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -121,8 +121,8 @@ services: timeout: 5s retries: 5 start_period: 10s - ports: - - 5433:5432 + # 注意:不暴露端口到宿主机,应用通过内部网络连接 + # 如需调试,可临时添加:ports: ["127.0.0.1:5433:5432"] # =========================================================================== # Redis Cache diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 366033ea..d239e97f 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1,5 +1,5 @@