From 0b9c4ae69e87289c15043bf8804ba43f8d3539cd Mon Sep 17 00:00:00 2001 From: shaw Date: Sat, 27 Dec 2025 20:42:00 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dclaude=20setup=20token?= =?UTF-8?q?=E6=8E=88=E6=9D=83=E6=95=88=E6=9C=9F=E7=9F=AD=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../repository/claude_oauth_service.go | 7 ++- .../repository/claude_oauth_service_test.go | 46 +++++++++++++++---- backend/internal/service/oauth_service.go | 17 ++++--- 3 files changed, 53 insertions(+), 17 deletions(-) diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index 75699712..f7ff2341 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -145,7 +145,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe return fullCode, nil } -func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) { +func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { client := s.clientFactory(proxyURL) // Parse code which may contain state in format "authCode#state" @@ -168,6 +168,11 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod reqBody["state"] = codeState } + // Setup token requires longer expiration (1 year) + if isSetupToken { + reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds + } + reqBodyJSON, _ := json.Marshal(reqBody) log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) diff --git a/backend/internal/repository/claude_oauth_service_test.go b/backend/internal/repository/claude_oauth_service_test.go index dd9c48b3..3295c222 100644 --- a/backend/internal/repository/claude_oauth_service_test.go +++ b/backend/internal/repository/claude_oauth_service_test.go @@ -191,12 +191,13 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() { func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { tests := []struct { - name string - handler http.HandlerFunc - code string - wantErr bool - wantResp *oauth.TokenResponse - validate func(captured requestCapture) + name string + handler http.HandlerFunc + code string + isSetupToken bool + wantErr bool + wantResp *oauth.TokenResponse + validate func(captured requestCapture) }{ { name: "sends_state_when_embedded", @@ -210,7 +211,8 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { Scope: "s", }) }, - code: "AUTH#STATE2", + code: "AUTH#STATE2", + isSetupToken: false, wantResp: &oauth.TokenResponse{ AccessToken: "at", RefreshToken: "rt", @@ -223,6 +225,29 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"]) require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"]) require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"]) + // Regular OAuth should not include expires_in + require.Nil(s.T(), captured.bodyJSON["expires_in"], "regular OAuth should not include expires_in") + }, + }, + { + name: "setup_token_includes_expires_in", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(oauth.TokenResponse{ + AccessToken: "at", + TokenType: "bearer", + ExpiresIn: 31536000, + }) + }, + code: "AUTH", + isSetupToken: true, + wantResp: &oauth.TokenResponse{ + AccessToken: "at", + }, + validate: func(captured requestCapture) { + // Setup token should include expires_in with 1 year value + require.Equal(s.T(), float64(31536000), captured.bodyJSON["expires_in"], + "setup token should include expires_in: 31536000") }, }, { @@ -231,8 +256,9 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { w.WriteHeader(http.StatusBadRequest) _, _ = w.Write([]byte("bad request")) }, - code: "AUTH", - wantErr: true, + code: "AUTH", + isSetupToken: false, + wantErr: true, }, } @@ -254,7 +280,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { s.client = client s.client.tokenURL = s.srv.URL - resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "") + resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken) if tt.wantErr { require.Error(s.T(), err) diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index f4c149ac..0039cb44 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -20,7 +20,7 @@ type OpenAIOAuthClient interface { type ClaudeOAuthClient interface { GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) - ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) + ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) } @@ -142,8 +142,11 @@ func (s *OAuthService) ExchangeCode(ctx context.Context, input *ExchangeCodeInpu } } + // Determine if this is a setup token (scope is inference only) + isSetupToken := session.Scope == oauth.ScopeInference + // Exchange code for token - tokenInfo, err := s.exchangeCodeForToken(ctx, input.Code, session.CodeVerifier, session.State, proxyURL) + tokenInfo, err := s.exchangeCodeForToken(ctx, input.Code, session.CodeVerifier, session.State, proxyURL, isSetupToken) if err != nil { return nil, err } @@ -172,10 +175,12 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) ( } } - // Determine scope + // Determine scope and if this is a setup token scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference) + isSetupToken := false if input.Scope == "inference" { scope = oauth.ScopeInference + isSetupToken = true } // Step 1: Get organization info using sessionKey @@ -203,7 +208,7 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) ( } // Step 4: Exchange code for token - tokenInfo, err := s.exchangeCodeForToken(ctx, authCode, codeVerifier, state, proxyURL) + tokenInfo, err := s.exchangeCodeForToken(ctx, authCode, codeVerifier, state, proxyURL, isSetupToken) if err != nil { return nil, fmt.Errorf("failed to exchange code: %w", err) } @@ -228,8 +233,8 @@ func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, org } // exchangeCodeForToken exchanges authorization code for tokens -func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) { - tokenResp, err := s.oauthClient.ExchangeCodeForToken(ctx, code, codeVerifier, state, proxyURL) +func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*TokenInfo, error) { + tokenResp, err := s.oauthClient.ExchangeCodeForToken(ctx, code, codeVerifier, state, proxyURL, isSetupToken) if err != nil { return nil, err }