feat(认证): 启用 OpenAI OAuth HTTP/2 并修复清理任务 lint

为共享 req 客户端增加 HTTP/2 选项与缓存隔离
OpenAI OAuth 超时提升到 120s,并按协议控制强制
新增客户端池与 OAuth 客户端单测覆盖
修复 usage cleanup 相关 errcheck/ineffassign/staticcheck 并统一格式

测试: make test
This commit is contained in:
yangjianbo
2026-01-19 19:50:57 +08:00
parent ef5a41057f
commit 73e6b160f8
9 changed files with 144 additions and 13 deletions

View File

@@ -3,8 +3,8 @@ package admin
import (
"bytes"
"context"
"encoding/json"
"database/sql"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -21,7 +22,7 @@ type openaiOAuthService struct {
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
client := createOpenAIReqClient(s.tokenURL, proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
@@ -54,7 +55,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
client := createOpenAIReqClient(s.tokenURL, proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
@@ -81,9 +82,14 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
return &tokenResp, nil
}
func createOpenAIReqClient(proxyURL string) *req.Client {
func createOpenAIReqClient(tokenURL, proxyURL string) *req.Client {
forceHTTP2 := false
if parsedURL, err := url.Parse(tokenURL); err == nil {
forceHTTP2 = strings.EqualFold(parsedURL.Scheme, "https")
}
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
ForceHTTP2: forceHTTP2,
})
}

View File

@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
require.ErrorContains(s.T(), err, "status 401")
}
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
client := NewOpenAIOAuthClient()
svc, ok := client.(*openaiOAuthService)
require.True(t, ok)
require.Equal(t, openai.TokenURL, svc.tokenURL)
}
func TestOpenAIOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(OpenAIOAuthServiceSuite))
}

View File

@@ -14,6 +14,7 @@ type reqClientOptions struct {
ProxyURL string // 代理 URL支持 http/https/socks5
Timeout time.Duration // 请求超时时间
Impersonate bool // 是否模拟 Chrome 浏览器指纹
ForceHTTP2 bool // 是否强制使用 HTTP/2
}
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
client := req.C().SetTimeout(opts.Timeout)
if opts.ForceHTTP2 {
client = client.EnableForceHTTP2()
}
if opts.Impersonate {
client = client.ImpersonateChrome()
}
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
}
func buildReqClientKey(opts reqClientOptions) string {
return fmt.Sprintf("%s|%s|%t",
return fmt.Sprintf("%s|%s|%t|%t",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.Impersonate,
opts.ForceHTTP2,
)
}

View File

@@ -0,0 +1,102 @@
package repository
import (
"reflect"
"sync"
"testing"
"time"
"unsafe"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
)
func forceHTTPVersion(t *testing.T, client *req.Client) string {
t.Helper()
transport := client.GetTransport()
field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion")
require.True(t, field.IsValid(), "forceHttpVersion field not found")
require.True(t, field.CanAddr(), "forceHttpVersion field not addressable")
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String()
}
func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
sharedReqClients = sync.Map{}
base := reqClientOptions{
ProxyURL: "http://proxy.local:8080",
Timeout: time.Second,
}
clientDefault := getSharedReqClient(base)
force := base
force.ForceHTTP2 = true
clientForce := getSharedReqClient(force)
require.NotSame(t, clientDefault, clientForce)
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
}
func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: "http://proxy.local:8080",
Timeout: 2 * time.Second,
}
first := getSharedReqClient(opts)
second := getSharedReqClient(opts)
require.Same(t, first, second)
}
func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: " http://proxy.local:8080 ",
Timeout: 3 * time.Second,
}
key := buildReqClientKey(opts)
sharedReqClients.Store(key, "invalid")
client := getSharedReqClient(opts)
require.NotNil(t, client)
loaded, ok := sharedReqClients.Load(key)
require.True(t, ok)
require.IsType(t, "invalid", loaded)
}
func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
sharedReqClients = sync.Map{}
opts := reqClientOptions{
ProxyURL: " http://proxy.local:8080 ",
Timeout: 4 * time.Second,
Impersonate: true,
}
client := getSharedReqClient(opts)
require.NotNil(t, client)
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}
func TestCreateOpenAIReqClient_ForceHTTP2Enabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
require.Equal(t, "2", forceHTTPVersion(t, client))
}
func TestCreateOpenAIReqClient_ForceHTTP2DisabledForHTTP(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("http://localhost/oauth/token", "http://proxy.local:8080")
require.Equal(t, "", forceHTTPVersion(t, client))
}
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{}
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
sharedReqClients = sync.Map{}
client := createGeminiReqClient("http://proxy.local:8080")
require.Equal(t, "", forceHTTPVersion(t, client))
}

View File

@@ -64,7 +64,9 @@ func (r *usageCleanupRepository) ListTasks(ctx context.Context, params paginatio
if err != nil {
return nil, nil, err
}
defer rows.Close()
defer func() {
_ = rows.Close()
}()
tasks := make([]service.UsageCleanupTask, 0)
for rows.Next() {
@@ -295,7 +297,9 @@ func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filte
if err != nil {
return 0, err
}
defer rows.Close()
defer func() {
_ = rows.Close()
}()
var deleted int64
for rows.Next() {
@@ -357,7 +361,6 @@ func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any)
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx))
args = append(args, *filters.BillingType)
idx++
}
return strings.Join(conditions, " AND "), args
}

View File

@@ -20,7 +20,7 @@ var (
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
errDashboardAggregationRunning = errors.New("聚合作业正在运行")
)

View File

@@ -151,6 +151,9 @@ func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageClean
}
func (s *UsageCleanupService) runOnce() {
if s == nil {
return
}
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
log.Printf("[UsageCleanup] run_once skipped: already_running=true")
return
@@ -158,7 +161,7 @@ func (s *UsageCleanupService) runOnce() {
defer atomic.StoreInt32(&s.running, 0)
parent := context.Background()
if s != nil && s.workerCtx != nil {
if s.workerCtx != nil {
parent = s.workerCtx
}
ctx, cancel := context.WithTimeout(parent, s.taskTimeout())

View File

@@ -266,9 +266,11 @@ func TestUsageCleanupServiceCreateTaskRepoError(t *testing.T) {
}
func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
start := time.Now()
end := start.Add(2 * time.Hour)
repo := &cleanupRepoStub{
claimQueue: []*UsageCleanupTask{
{ID: 5, Filters: UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(2 * time.Hour)}},
{ID: 5, Filters: UsageCleanupFilters{StartTime: start, EndTime: end}},
},
deleteQueue: []cleanupDeleteResponse{
{deleted: 2},
@@ -284,6 +286,9 @@ func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.deleteCalls, 3)
require.Equal(t, 2, repo.deleteCalls[0].limit)
require.True(t, repo.deleteCalls[0].filters.StartTime.Equal(start))
require.True(t, repo.deleteCalls[0].filters.EndTime.Equal(end))
require.Len(t, repo.markSucceeded, 1)
require.Empty(t, repo.markFailed)
require.Equal(t, int64(5), repo.markSucceeded[0].taskID)