Files
sub2api/backend/internal/service/oauth_service_test.go
yangjianbo fc8a39e0f5 test: 删除CI工作流,大幅提升后端单元测试覆盖率至50%+
删除因GitHub计费锁定而失败的CI工作流。
为6个核心Go源文件补充单元测试,全部达到50%以上覆盖率:
- response/response.go: 97.6%
- antigravity/oauth.go: 90.1%
- antigravity/client.go: 88.6% (新增27个HTTP客户端测试)
- geminicli/oauth.go: 91.8%
- service/oauth_service.go: 61.2%
- service/gemini_oauth_service.go: 51.9%

新增/增强8个测试文件,共计5600+行测试代码。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 09:07:58 +08:00

608 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build unit
package service
import (
"context"
"fmt"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// --- mock: ClaudeOAuthClient ---

// mockClaudeOAuthClient is a hand-written test double for the Claude OAuth
// client. Every method delegates to the matching *Func field; when that field
// is nil the method panics, so a test that triggers an unexpected upstream call
// fails loudly instead of quietly receiving zero values.
type mockClaudeOAuthClient struct {
	getOrgUUIDFunc   func(ctx context.Context, sessionKey, proxyURL string) (string, error)
	getAuthCodeFunc  func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
	exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
	refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
}

// GetOrganizationUUID forwards to getOrgUUIDFunc.
func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
	impl := m.getOrgUUIDFunc
	if impl == nil {
		panic("GetOrganizationUUID not implemented")
	}
	return impl(ctx, sessionKey, proxyURL)
}

// GetAuthorizationCode forwards to getAuthCodeFunc.
func (m *mockClaudeOAuthClient) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
	impl := m.getAuthCodeFunc
	if impl == nil {
		panic("GetAuthorizationCode not implemented")
	}
	return impl(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL)
}

// ExchangeCodeForToken forwards to exchangeCodeFunc.
func (m *mockClaudeOAuthClient) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
	impl := m.exchangeCodeFunc
	if impl == nil {
		panic("ExchangeCodeForToken not implemented")
	}
	return impl(ctx, code, codeVerifier, state, proxyURL, isSetupToken)
}

// RefreshToken forwards to refreshTokenFunc.
func (m *mockClaudeOAuthClient) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
	impl := m.refreshTokenFunc
	if impl == nil {
		panic("RefreshToken not implemented")
	}
	return impl(ctx, refreshToken, proxyURL)
}
// --- mock: ProxyRepository ---

// mockProxyRepoForOAuth is a minimal proxy-repository double for OAuthService
// tests. Only GetByID has behavior; every other method panics so that an
// unexpected repository call fails the test immediately.
type mockProxyRepoForOAuth struct {
	// getByIDFunc, when non-nil, answers GetByID.
	getByIDFunc func(ctx context.Context, id int64) (*Proxy, error)
}

// GetByID delegates to getByIDFunc; with no function configured it returns a
// "proxy not found" error.
func (m *mockProxyRepoForOAuth) GetByID(ctx context.Context, id int64) (*Proxy, error) {
	lookup := m.getByIDFunc
	if lookup == nil {
		return nil, fmt.Errorf("proxy not found")
	}
	return lookup(ctx, id)
}

// The methods below exist only to satisfy the repository interface; reaching
// any of them from a test is a bug in that test.

func (m *mockProxyRepoForOAuth) Create(ctx context.Context, proxy *Proxy) error {
	panic("Create not implemented")
}

func (m *mockProxyRepoForOAuth) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) {
	panic("ListByIDs not implemented")
}

func (m *mockProxyRepoForOAuth) Update(ctx context.Context, proxy *Proxy) error {
	panic("Update not implemented")
}

func (m *mockProxyRepoForOAuth) Delete(ctx context.Context, id int64) error {
	panic("Delete not implemented")
}

func (m *mockProxyRepoForOAuth) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) {
	panic("List not implemented")
}

func (m *mockProxyRepoForOAuth) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
	panic("ListWithFilters not implemented")
}

func (m *mockProxyRepoForOAuth) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
	panic("ListWithFiltersAndAccountCount not implemented")
}

func (m *mockProxyRepoForOAuth) ListActive(ctx context.Context) ([]Proxy, error) {
	panic("ListActive not implemented")
}

func (m *mockProxyRepoForOAuth) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
	panic("ListActiveWithAccountCount not implemented")
}

func (m *mockProxyRepoForOAuth) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
	panic("ExistsByHostPortAuth not implemented")
}

func (m *mockProxyRepoForOAuth) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
	panic("CountAccountsByProxyID not implemented")
}

func (m *mockProxyRepoForOAuth) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
	panic("ListAccountSummariesByProxyID not implemented")
}
// =====================
// 测试用例
// =====================
// TestNewOAuthService verifies that the constructor stores the injected proxy
// repository and OAuth client and initializes the session store.
func TestNewOAuthService(t *testing.T) {
	t.Parallel()
	proxyRepo := &mockProxyRepoForOAuth{}
	client := &mockClaudeOAuthClient{}
	svc := NewOAuthService(proxyRepo, client)
	if svc == nil {
		t.Fatal("NewOAuthService 返回 nil")
	}
	// Defer cleanup right after the nil check. t.Fatal exits through
	// runtime.Goexit, so a trailing svc.Stop() would be skipped whenever an
	// assertion below fails, while deferred calls still run.
	defer svc.Stop()

	if svc.proxyRepo != proxyRepo {
		t.Fatal("proxyRepo 未正确设置")
	}
	if svc.oauthClient != client {
		t.Fatal("oauthClient 未正确设置")
	}
	if svc.sessionStore == nil {
		t.Fatal("sessionStore 应被自动初始化")
	}
}
// TestOAuthService_GenerateAuthURL checks that GenerateAuthURL without a proxy
// returns a non-empty URL and session ID, and records a session whose scope is
// oauth.ScopeOAuth.
func TestOAuthService_GenerateAuthURL(t *testing.T) {
	t.Parallel()
	svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
	defer svc.Stop()

	result, err := svc.GenerateAuthURL(context.Background(), nil)
	switch {
	case err != nil:
		t.Fatalf("GenerateAuthURL 返回错误: %v", err)
	case result == nil:
		t.Fatal("GenerateAuthURL 返回 nil")
	case result.AuthURL == "":
		t.Fatal("AuthURL 为空")
	case result.SessionID == "":
		t.Fatal("SessionID 为空")
	}

	// The session created for this URL must be retrievable by its ID.
	stored, found := svc.sessionStore.Get(result.SessionID)
	if !found {
		t.Fatal("session 未在 sessionStore 中找到")
	}
	if stored.Scope != oauth.ScopeOAuth {
		t.Fatalf("scope 不匹配: got=%q want=%q", stored.Scope, oauth.ScopeOAuth)
	}
}
// TestOAuthService_GenerateAuthURL_WithProxy checks two things when a proxy ID
// is supplied: the service looks up exactly that proxy, and the stored session
// carries the proxy URL built from it.
func TestOAuthService_GenerateAuthURL_WithProxy(t *testing.T) {
	t.Parallel()
	const wantProxyID int64 = 1
	// The mock returns the same proxy for any ID, so record the ID it was asked
	// for; otherwise a service that queried the wrong proxy would still pass.
	var requestedProxyID int64
	proxyRepo := &mockProxyRepoForOAuth{
		getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) {
			requestedProxyID = id
			return &Proxy{
				ID:       1,
				Protocol: "http",
				Host:     "proxy.example.com",
				Port:     8080,
			}, nil
		},
	}
	svc := NewOAuthService(proxyRepo, &mockClaudeOAuthClient{})
	defer svc.Stop()

	proxyID := wantProxyID
	result, err := svc.GenerateAuthURL(context.Background(), &proxyID)
	if err != nil {
		t.Fatalf("GenerateAuthURL 返回错误: %v", err)
	}
	if result == nil {
		t.Fatal("GenerateAuthURL 返回 nil")
	}
	if requestedProxyID != wantProxyID {
		t.Fatalf("GetByID 收到的 proxyID 不匹配: got=%d want=%d", requestedProxyID, wantProxyID)
	}
	session, ok := svc.sessionStore.Get(result.SessionID)
	if !ok {
		t.Fatal("session 未在 sessionStore 中找到")
	}
	if session.ProxyURL != "http://proxy.example.com:8080" {
		t.Fatalf("ProxyURL 不匹配: got=%q", session.ProxyURL)
	}
}
// TestOAuthService_GenerateSetupTokenURL checks that the setup-token flow
// returns a usable URL and session ID, and records a session whose scope is
// oauth.ScopeInference rather than the full OAuth scope.
func TestOAuthService_GenerateSetupTokenURL(t *testing.T) {
	t.Parallel()
	svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
	defer svc.Stop()

	result, err := svc.GenerateSetupTokenURL(context.Background(), nil)
	if err != nil {
		t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err)
	}
	if result == nil {
		t.Fatal("GenerateSetupTokenURL 返回 nil")
	}
	// Mirror the GenerateAuthURL test: an empty URL or session ID would make the
	// flow unusable even if a session happened to be stored.
	if result.AuthURL == "" {
		t.Fatal("AuthURL 为空")
	}
	if result.SessionID == "" {
		t.Fatal("SessionID 为空")
	}

	// The stored session must use the inference scope.
	session, ok := svc.sessionStore.Get(result.SessionID)
	if !ok {
		t.Fatal("session 未在 sessionStore 中找到")
	}
	if session.Scope != oauth.ScopeInference {
		t.Fatalf("scope 不匹配: got=%q want=%q", session.Scope, oauth.ScopeInference)
	}
}
// TestOAuthService_ExchangeCode_SessionNotFound checks that exchanging a code
// against an unknown session ID fails with "session not found or expired".
func TestOAuthService_ExchangeCode_SessionNotFound(t *testing.T) {
	t.Parallel()
	svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
	defer svc.Stop()

	_, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
		SessionID: "nonexistent-session",
		Code:      "test-code",
	})
	if err == nil {
		// Failure message repaired: it previously had an unbalanced parenthesis.
		t.Fatal("ExchangeCode 应返回错误 (session 不存在)")
	}
	if err.Error() != "session not found or expired" {
		t.Fatalf("错误信息不匹配: got=%q", err.Error())
	}
}
// TestOAuthService_ExchangeCode_Success walks the full OAuth flow. It generates
// an auth URL to create a session, then exchanges a code for that session. It
// asserts that:
//   - the client is called with isSetupToken=false,
//   - every token field is mapped onto the result,
//   - the session is removed after a successful exchange.
func TestOAuthService_ExchangeCode_Success(t *testing.T) {
	t.Parallel()
	exchangeCalled := false
	client := &mockClaudeOAuthClient{
		exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
			exchangeCalled = true
			if code != "auth-code-123" {
				t.Errorf("code 不匹配: got=%q", code)
			}
			if isSetupToken {
				// A GenerateAuthURL session uses ScopeOAuth, which is not a setup token.
				// Message repaired: "false" and "ScopeOAuth" had been run together.
				t.Error("isSetupToken 应为 false (ScopeOAuth)")
			}
			return &oauth.TokenResponse{
				AccessToken:  "access-token-abc",
				TokenType:    "Bearer",
				ExpiresIn:    3600,
				RefreshToken: "refresh-token-xyz",
				Scope:        oauth.ScopeOAuth,
				Organization: &oauth.OrgInfo{UUID: "org-uuid-111"},
				Account:      &oauth.AccountInfo{UUID: "acc-uuid-222", EmailAddress: "test@example.com"},
			}, nil
		},
	}
	svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
	defer svc.Stop()

	// Generate a URL first so that a session exists to exchange against.
	result, err := svc.GenerateAuthURL(context.Background(), nil)
	if err != nil {
		t.Fatalf("GenerateAuthURL 返回错误: %v", err)
	}

	tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
		SessionID: result.SessionID,
		Code:      "auth-code-123",
	})
	if err != nil {
		t.Fatalf("ExchangeCode 返回错误: %v", err)
	}
	if !exchangeCalled {
		t.Fatal("ExchangeCodeForToken 未被调用")
	}
	if tokenInfo.AccessToken != "access-token-abc" {
		t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken)
	}
	if tokenInfo.TokenType != "Bearer" {
		t.Fatalf("TokenType 不匹配: got=%q", tokenInfo.TokenType)
	}
	if tokenInfo.RefreshToken != "refresh-token-xyz" {
		t.Fatalf("RefreshToken 不匹配: got=%q", tokenInfo.RefreshToken)
	}
	if tokenInfo.OrgUUID != "org-uuid-111" {
		t.Fatalf("OrgUUID 不匹配: got=%q", tokenInfo.OrgUUID)
	}
	if tokenInfo.AccountUUID != "acc-uuid-222" {
		t.Fatalf("AccountUUID 不匹配: got=%q", tokenInfo.AccountUUID)
	}
	if tokenInfo.EmailAddress != "test@example.com" {
		t.Fatalf("EmailAddress 不匹配: got=%q", tokenInfo.EmailAddress)
	}
	if tokenInfo.ExpiresIn != 3600 {
		t.Fatalf("ExpiresIn 不匹配: got=%d", tokenInfo.ExpiresIn)
	}
	if tokenInfo.ExpiresAt == 0 {
		t.Fatal("ExpiresAt 不应为 0")
	}

	// The session must be gone after a successful exchange.
	if _, ok := svc.sessionStore.Get(result.SessionID); ok {
		t.Fatal("session 应在交换成功后被删除")
	}
}
// TestOAuthService_ExchangeCode_SetupToken checks that a session created by
// GenerateSetupTokenURL (inference scope) is exchanged with isSetupToken=true,
// and that the returned access token is passed through.
func TestOAuthService_ExchangeCode_SetupToken(t *testing.T) {
	t.Parallel()
	client := &mockClaudeOAuthClient{
		exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
			if !isSetupToken {
				// Message repaired: "true" and "ScopeInference" had been run together.
				t.Error("isSetupToken 应为 true (ScopeInference)")
			}
			return &oauth.TokenResponse{
				AccessToken: "setup-token",
				TokenType:   "Bearer",
				ExpiresIn:   3600,
				Scope:       oauth.ScopeInference,
			}, nil
		},
	}
	svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
	defer svc.Stop()

	// Start from a setup-token URL so the session uses the inference scope.
	result, err := svc.GenerateSetupTokenURL(context.Background(), nil)
	if err != nil {
		t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err)
	}
	tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
		SessionID: result.SessionID,
		Code:      "setup-code",
	})
	if err != nil {
		t.Fatalf("ExchangeCode 返回错误: %v", err)
	}
	if tokenInfo.AccessToken != "setup-token" {
		t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken)
	}
}
// TestOAuthService_ExchangeCode_ClientError checks that an error from the
// OAuth client's code exchange reaches the caller with its message unchanged.
func TestOAuthService_ExchangeCode_ClientError(t *testing.T) {
	t.Parallel()
	client := &mockClaudeOAuthClient{
		exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
			return nil, fmt.Errorf("upstream error: invalid code")
		},
	}
	svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
	defer svc.Stop()

	// Check the setup step's error. Discarding it would turn a setup failure
	// into a nil-pointer panic on result.SessionID below.
	result, err := svc.GenerateAuthURL(context.Background(), nil)
	if err != nil {
		t.Fatalf("GenerateAuthURL 返回错误: %v", err)
	}
	_, err = svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
		SessionID: result.SessionID,
		Code:      "bad-code",
	})
	if err == nil {
		t.Fatal("ExchangeCode 应返回错误")
	}
	if err.Error() != "upstream error: invalid code" {
		t.Fatalf("错误信息不匹配: got=%q", err.Error())
	}
}
// TestOAuthService_RefreshToken checks that RefreshToken forwards the refresh
// token and an empty proxy URL to the OAuth client, and maps the response
// (access token, rotated refresh token, lifetime, computed expiry) onto the
// returned token info.
func TestOAuthService_RefreshToken(t *testing.T) {
	t.Parallel()
	const inputRefreshToken = "my-refresh-token"

	fake := &mockClaudeOAuthClient{
		refreshTokenFunc: func(_ context.Context, rt, proxy string) (*oauth.TokenResponse, error) {
			if rt != inputRefreshToken {
				t.Errorf("refreshToken 不匹配: got=%q", rt)
			}
			if proxy != "" {
				t.Errorf("proxyURL 应为空: got=%q", proxy)
			}
			resp := &oauth.TokenResponse{
				AccessToken:  "new-access-token",
				TokenType:    "Bearer",
				ExpiresIn:    7200,
				RefreshToken: "new-refresh-token",
				Scope:        oauth.ScopeOAuth,
			}
			return resp, nil
		},
	}
	svc := NewOAuthService(&mockProxyRepoForOAuth{}, fake)
	defer svc.Stop()

	info, err := svc.RefreshToken(context.Background(), inputRefreshToken, "")
	if err != nil {
		t.Fatalf("RefreshToken 返回错误: %v", err)
	}
	switch {
	case info.AccessToken != "new-access-token":
		t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken)
	case info.RefreshToken != "new-refresh-token":
		t.Fatalf("RefreshToken 不匹配: got=%q", info.RefreshToken)
	case info.ExpiresIn != 7200:
		t.Fatalf("ExpiresIn 不匹配: got=%d", info.ExpiresIn)
	case info.ExpiresAt == 0:
		t.Fatal("ExpiresAt 不应为 0")
	}
}
func TestOAuthService_RefreshToken_Error(t *testing.T) {
t.Parallel()
client := &mockClaudeOAuthClient{
refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
return nil, fmt.Errorf("invalid_grant: token expired")
},
}
svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
defer svc.Stop()
_, err := svc.RefreshToken(context.Background(), "expired-token", "")
if err == nil {
t.Fatal("RefreshToken 应返回错误")
}
}
// TestOAuthService_RefreshAccountToken_NoRefreshToken checks that an account
// whose credentials have no "refresh_token" key is rejected with
// "no refresh token available".
func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) {
	t.Parallel()
	svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
	defer svc.Stop()

	// Account credentials without a refresh_token entry.
	account := &Account{
		ID:       1,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "some-token",
		},
	}
	_, err := svc.RefreshAccountToken(context.Background(), account)
	if err == nil {
		// Failure message repaired: the closing parenthesis was missing.
		t.Fatal("RefreshAccountToken 应返回错误 (无 refresh_token)")
	}
	if err.Error() != "no refresh token available" {
		t.Fatalf("错误信息不匹配: got=%q", err.Error())
	}
}
// TestOAuthService_RefreshAccountToken_EmptyRefreshToken checks that an account
// whose "refresh_token" credential is the empty string is rejected with an error.
func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) {
	t.Parallel()
	svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
	defer svc.Stop()

	account := &Account{
		ID:       2,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "some-token",
			"refresh_token": "",
		},
	}
	_, err := svc.RefreshAccountToken(context.Background(), account)
	if err == nil {
		// Failure message repaired: the opening parenthesis was missing.
		t.Fatal("RefreshAccountToken 应返回错误 (refresh_token 为空)")
	}
}
// TestOAuthService_RefreshAccountToken_Success checks that the account's stored
// refresh token is sent to the OAuth client and that the refreshed access token
// is returned.
func TestOAuthService_RefreshAccountToken_Success(t *testing.T) {
	t.Parallel()
	const storedRefreshToken = "account-refresh-token"

	fake := &mockClaudeOAuthClient{
		refreshTokenFunc: func(_ context.Context, rt, _ string) (*oauth.TokenResponse, error) {
			if rt != storedRefreshToken {
				t.Errorf("refreshToken 不匹配: got=%q", rt)
			}
			return &oauth.TokenResponse{
				AccessToken:  "refreshed-access",
				TokenType:    "Bearer",
				ExpiresIn:    3600,
				RefreshToken: "new-refresh",
			}, nil
		},
	}
	svc := NewOAuthService(&mockProxyRepoForOAuth{}, fake)
	defer svc.Stop()

	creds := map[string]any{
		"access_token":  "old-access",
		"refresh_token": storedRefreshToken,
	}
	acct := &Account{
		ID:          3,
		Platform:    PlatformAnthropic,
		Type:        AccountTypeOAuth,
		Credentials: creds,
	}

	info, err := svc.RefreshAccountToken(context.Background(), acct)
	if err != nil {
		t.Fatalf("RefreshAccountToken 返回错误: %v", err)
	}
	if got := info.AccessToken; got != "refreshed-access" {
		t.Fatalf("AccessToken 不匹配: got=%q", got)
	}
}
// TestOAuthService_RefreshAccountToken_WithProxy checks how an account bound to
// a proxy is refreshed:
//   - the account's proxy ID is looked up,
//   - the resulting authenticated SOCKS5 URL and the account's refresh token
//     are both forwarded to the OAuth client,
//   - the refreshed access token is returned.
func TestOAuthService_RefreshAccountToken_WithProxy(t *testing.T) {
	t.Parallel()
	const wantProxyID int64 = 10
	// The mock answers any ID with the same proxy, so record what was requested;
	// otherwise a lookup with the wrong ID would go unnoticed.
	var requestedProxyID int64
	proxyRepo := &mockProxyRepoForOAuth{
		getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) {
			requestedProxyID = id
			return &Proxy{
				Protocol: "socks5",
				Host:     "socks.example.com",
				Port:     1080,
				Username: "user",
				Password: "pass",
			}, nil
		},
	}
	refreshCalled := false
	client := &mockClaudeOAuthClient{
		refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
			refreshCalled = true
			if refreshToken != "rt-with-proxy" {
				t.Errorf("refreshToken 不匹配: got=%q", refreshToken)
			}
			if proxyURL != "socks5://user:pass@socks.example.com:1080" {
				t.Errorf("proxyURL 不匹配: got=%q", proxyURL)
			}
			return &oauth.TokenResponse{
				AccessToken: "refreshed",
				ExpiresIn:   3600,
			}, nil
		},
	}
	svc := NewOAuthService(proxyRepo, client)
	defer svc.Stop()

	proxyID := wantProxyID
	account := &Account{
		ID:       4,
		Platform: PlatformAnthropic,
		Type:     AccountTypeOAuth,
		ProxyID:  &proxyID,
		Credentials: map[string]any{
			"refresh_token": "rt-with-proxy",
		},
	}
	tokenInfo, err := svc.RefreshAccountToken(context.Background(), account)
	if err != nil {
		t.Fatalf("RefreshAccountToken 返回错误: %v", err)
	}
	if requestedProxyID != wantProxyID {
		t.Fatalf("GetByID 收到的 proxyID 不匹配: got=%d want=%d", requestedProxyID, wantProxyID)
	}
	if !refreshCalled {
		t.Fatal("RefreshToken 未被调用")
	}
	if tokenInfo.AccessToken != "refreshed" {
		t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken)
	}
}
// TestOAuthService_ExchangeCode_NilOrg checks that a token response with nil
// Organization and Account is handled without panicking, and leaves the
// org/account-derived fields empty.
func TestOAuthService_ExchangeCode_NilOrg(t *testing.T) {
	t.Parallel()
	client := &mockClaudeOAuthClient{
		exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
			return &oauth.TokenResponse{
				AccessToken:  "token-no-org",
				TokenType:    "Bearer",
				ExpiresIn:    3600,
				Organization: nil,
				Account:      nil,
			}, nil
		},
	}
	svc := NewOAuthService(&mockProxyRepoForOAuth{}, client)
	defer svc.Stop()

	// Check the setup step's error instead of discarding it; a nil result would
	// otherwise panic on result.SessionID.
	result, err := svc.GenerateAuthURL(context.Background(), nil)
	if err != nil {
		t.Fatalf("GenerateAuthURL 返回错误: %v", err)
	}
	tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{
		SessionID: result.SessionID,
		Code:      "code",
	})
	if err != nil {
		t.Fatalf("ExchangeCode 返回错误: %v", err)
	}
	if tokenInfo.OrgUUID != "" {
		t.Fatalf("OrgUUID 应为空: got=%q", tokenInfo.OrgUUID)
	}
	if tokenInfo.AccountUUID != "" {
		t.Fatalf("AccountUUID 应为空: got=%q", tokenInfo.AccountUUID)
	}
	// EmailAddress also comes from Account, so it must be empty when Account is nil.
	if tokenInfo.EmailAddress != "" {
		t.Fatalf("EmailAddress 应为空: got=%q", tokenInfo.EmailAddress)
	}
}
// TestOAuthService_Stop_NoPanic verifies that Stop is safe to call more than
// once: both the first call and a redundant second call must return without
// panicking.
func TestOAuthService_Stop_NoPanic(t *testing.T) {
	t.Parallel()
	svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{})
	for attempt := 0; attempt < 2; attempt++ {
		svc.Stop()
	}
}