From 73e6b160f8772b7e8ce279cdfe2acf239cba53da Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Mon, 19 Jan 2026 19:50:57 +0800 Subject: [PATCH] =?UTF-8?q?feat(=E8=AE=A4=E8=AF=81):=20=E5=90=AF=E7=94=A8?= =?UTF-8?q?=20OpenAI=20OAuth=20HTTP/2=20=E5=B9=B6=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E6=B8=85=E7=90=86=E4=BB=BB=E5=8A=A1=20lint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为共享 req 客户端增加 HTTP/2 选项与缓存隔离 OpenAI OAuth 超时提升到 120s,并按协议控制强制 新增客户端池与 OAuth 客户端单测覆盖 修复 usage cleanup 相关 errcheck/ineffassign/staticcheck 并统一格式 测试: make test --- .../admin/usage_cleanup_handler_test.go | 2 +- .../repository/openai_oauth_service.go | 16 ++- .../repository/openai_oauth_service_test.go | 7 ++ .../internal/repository/req_client_pool.go | 7 +- .../repository/req_client_pool_test.go | 102 ++++++++++++++++++ .../internal/repository/usage_cleanup_repo.go | 9 +- .../service/dashboard_aggregation_service.go | 2 +- .../internal/service/usage_cleanup_service.go | 5 +- .../service/usage_cleanup_service_test.go | 7 +- 9 files changed, 144 insertions(+), 13 deletions(-) create mode 100644 backend/internal/repository/req_client_pool_test.go diff --git a/backend/internal/handler/admin/usage_cleanup_handler_test.go b/backend/internal/handler/admin/usage_cleanup_handler_test.go index d8684c39..ed1c7cc2 100644 --- a/backend/internal/handler/admin/usage_cleanup_handler_test.go +++ b/backend/internal/handler/admin/usage_cleanup_handler_test.go @@ -3,8 +3,8 @@ package admin import ( "bytes" "context" - "encoding/json" "database/sql" + "encoding/json" "errors" "net/http" "net/http/httptest" diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index 07d57410..b7f3606f 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/url" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" @@ -21,7 +22,7 @@ type openaiOAuthService struct { } func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { - client := createOpenAIReqClient(proxyURL) + client := createOpenAIReqClient(s.tokenURL, proxyURL) if redirectURI == "" { redirectURI = openai.DefaultRedirectURI @@ -54,7 +55,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie } func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { - client := createOpenAIReqClient(proxyURL) + client := createOpenAIReqClient(s.tokenURL, proxyURL) formData := url.Values{} formData.Set("grant_type", "refresh_token") @@ -81,9 +82,14 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro return &tokenResp, nil } -func createOpenAIReqClient(proxyURL string) *req.Client { +func createOpenAIReqClient(tokenURL, proxyURL string) *req.Client { + forceHTTP2 := false + if parsedURL, err := url.Parse(tokenURL); err == nil { + forceHTTP2 = strings.EqualFold(parsedURL.Scheme, "https") + } return getSharedReqClient(reqClientOptions{ - ProxyURL: proxyURL, - Timeout: 60 * time.Second, + ProxyURL: proxyURL, + Timeout: 120 * time.Second, + ForceHTTP2: forceHTTP2, }) } diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go index 51142306..f9df08c8 100644 --- a/backend/internal/repository/openai_oauth_service_test.go +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() { require.ErrorContains(s.T(), err, "status 401") } +func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) { + client := NewOpenAIOAuthClient() + svc, ok := client.(*openaiOAuthService) + require.True(t, ok) + require.Equal(t, openai.TokenURL, svc.tokenURL) +} + func TestOpenAIOAuthServiceSuite(t *testing.T) { suite.Run(t, new(OpenAIOAuthServiceSuite)) } diff --git a/backend/internal/repository/req_client_pool.go b/backend/internal/repository/req_client_pool.go index b23462a4..af71a7ee 100644 --- a/backend/internal/repository/req_client_pool.go +++ b/backend/internal/repository/req_client_pool.go @@ -14,6 +14,7 @@ type reqClientOptions struct { ProxyURL string // 代理 URL(支持 http/https/socks5) Timeout time.Duration // 请求超时时间 Impersonate bool // 是否模拟 Chrome 浏览器指纹 + ForceHTTP2 bool // 是否强制使用 HTTP/2 } // sharedReqClients 存储按配置参数缓存的 req 客户端实例 @@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client { } client := req.C().SetTimeout(opts.Timeout) + if opts.ForceHTTP2 { + client = client.EnableForceHTTP2() + } if opts.Impersonate { client = client.ImpersonateChrome() } @@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client { } func buildReqClientKey(opts reqClientOptions) string { - return fmt.Sprintf("%s|%s|%t", + return fmt.Sprintf("%s|%s|%t|%t", strings.TrimSpace(opts.ProxyURL), opts.Timeout.String(), opts.Impersonate, + opts.ForceHTTP2, ) } diff --git a/backend/internal/repository/req_client_pool_test.go b/backend/internal/repository/req_client_pool_test.go new file mode 100644 index 00000000..cf7e8bd0 --- /dev/null +++ b/backend/internal/repository/req_client_pool_test.go @@ -0,0 +1,102 @@ +package repository + +import ( + "reflect" + "sync" + "testing" + "time" + "unsafe" + + "github.com/imroc/req/v3" + "github.com/stretchr/testify/require" +) + +func forceHTTPVersion(t *testing.T, client *req.Client) string { + t.Helper() + transport := client.GetTransport() + field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion") + require.True(t, field.IsValid(), "forceHttpVersion field not found") + require.True(t, field.CanAddr(), "forceHttpVersion field not addressable") + return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String() +} + +func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) { + sharedReqClients = sync.Map{} + base := reqClientOptions{ + ProxyURL: "http://proxy.local:8080", + Timeout: time.Second, + } + clientDefault := getSharedReqClient(base) + + force := base + force.ForceHTTP2 = true + clientForce := getSharedReqClient(force) + + require.NotSame(t, clientDefault, clientForce) + require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force)) +} + +func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) { + sharedReqClients = sync.Map{} + opts := reqClientOptions{ + ProxyURL: "http://proxy.local:8080", + Timeout: 2 * time.Second, + } + first := getSharedReqClient(opts) + second := getSharedReqClient(opts) + require.Same(t, first, second) +} + +func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) { + sharedReqClients = sync.Map{} + opts := reqClientOptions{ + ProxyURL: " http://proxy.local:8080 ", + Timeout: 3 * time.Second, + } + key := buildReqClientKey(opts) + sharedReqClients.Store(key, "invalid") + + client := getSharedReqClient(opts) + + require.NotNil(t, client) + loaded, ok := sharedReqClients.Load(key) + require.True(t, ok) + require.IsType(t, "invalid", loaded) +} + +func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) { + sharedReqClients = sync.Map{} + opts := reqClientOptions{ + ProxyURL: " http://proxy.local:8080 ", + Timeout: 4 * time.Second, + Impersonate: true, + } + client := getSharedReqClient(opts) + + require.NotNil(t, client) + require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts)) +} + +func TestCreateOpenAIReqClient_ForceHTTP2Enabled(t *testing.T) { + sharedReqClients = sync.Map{} + client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080") + require.Equal(t, "2", forceHTTPVersion(t, client)) +} + +func TestCreateOpenAIReqClient_ForceHTTP2DisabledForHTTP(t *testing.T) { + sharedReqClients = sync.Map{} + client := createOpenAIReqClient("http://localhost/oauth/token", "http://proxy.local:8080") + require.Equal(t, "", forceHTTPVersion(t, client)) +} + +func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) { + sharedReqClients = sync.Map{} + client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080") + require.Equal(t, 120*time.Second, client.GetClient().Timeout) +} + +func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) { + sharedReqClients = sync.Map{} + client := createGeminiReqClient("http://proxy.local:8080") + require.Equal(t, "", forceHTTPVersion(t, client)) +} diff --git a/backend/internal/repository/usage_cleanup_repo.go b/backend/internal/repository/usage_cleanup_repo.go index b703cc9f..c5da2776 100644 --- a/backend/internal/repository/usage_cleanup_repo.go +++ b/backend/internal/repository/usage_cleanup_repo.go @@ -64,7 +64,9 @@ func (r *usageCleanupRepository) ListTasks(ctx context.Context, params paginatio if err != nil { return nil, nil, err } - defer rows.Close() + defer func() { + _ = rows.Close() + }() tasks := make([]service.UsageCleanupTask, 0) for rows.Next() { @@ -295,7 +297,9 @@ func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filte if err != nil { return 0, err } - defer rows.Close() + defer func() { + _ = rows.Close() + }() var deleted int64 for rows.Next() { @@ -357,7 +361,6 @@ func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) if filters.BillingType != nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx)) args = append(args, *filters.BillingType) - idx++ } return strings.Join(conditions, " AND "), args } diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go index 8f7e8144..10c68868 100644 --- a/backend/internal/service/dashboard_aggregation_service.go +++ b/backend/internal/service/dashboard_aggregation_service.go @@ -20,7 +20,7 @@ var ( // ErrDashboardBackfillDisabled 当配置禁用回填时返回。 ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用") // ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。 - ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大") + ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大") errDashboardAggregationRunning = errors.New("聚合作业正在运行") ) diff --git a/backend/internal/service/usage_cleanup_service.go b/backend/internal/service/usage_cleanup_service.go index 8ca02cfc..1b0fde37 100644 --- a/backend/internal/service/usage_cleanup_service.go +++ b/backend/internal/service/usage_cleanup_service.go @@ -151,6 +151,9 @@ func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageClean } func (s *UsageCleanupService) runOnce() { + if s == nil { + return + } if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { log.Printf("[UsageCleanup] run_once skipped: already_running=true") return @@ -158,7 +161,7 @@ func (s *UsageCleanupService) runOnce() { defer atomic.StoreInt32(&s.running, 0) parent := context.Background() - if s != nil && s.workerCtx != nil { + if s.workerCtx != nil { parent = s.workerCtx } ctx, cancel := context.WithTimeout(parent, s.taskTimeout()) diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go index 37d3eb19..1d5cc681 100644 --- a/backend/internal/service/usage_cleanup_service_test.go +++ b/backend/internal/service/usage_cleanup_service_test.go @@ -266,9 +266,11 @@ func TestUsageCleanupServiceCreateTaskRepoError(t *testing.T) { } func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) { + start := time.Now() + end := start.Add(2 * time.Hour) repo := &cleanupRepoStub{ claimQueue: []*UsageCleanupTask{ - {ID: 5, Filters: UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(2 * time.Hour)}}, + {ID: 5, Filters: UsageCleanupFilters{StartTime: start, EndTime: end}}, }, deleteQueue: []cleanupDeleteResponse{ {deleted: 2}, @@ -284,6 +286,9 @@ func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) { repo.mu.Lock() defer repo.mu.Unlock() require.Len(t, repo.deleteCalls, 3) + require.Equal(t, 2, repo.deleteCalls[0].limit) + require.True(t, repo.deleteCalls[0].filters.StartTime.Equal(start)) + require.True(t, repo.deleteCalls[0].filters.EndTime.Equal(end)) require.Len(t, repo.markSucceeded, 1) require.Empty(t, repo.markFailed) require.Equal(t, int64(5), repo.markSucceeded[0].taskID)