feat(认证): 启用 OpenAI OAuth HTTP/2 并修复清理任务 lint
为共享 req 客户端增加 HTTP/2 选项与缓存隔离 OpenAI OAuth 超时提升到 120s,并按协议控制强制 新增客户端池与 OAuth 客户端单测覆盖 修复 usage cleanup 相关 errcheck/ineffassign/staticcheck 并统一格式 测试: make test
This commit is contained in:
@@ -3,8 +3,8 @@ package admin
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||||
@@ -21,7 +22,7 @@ type openaiOAuthService struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
client := createOpenAIReqClient(proxyURL)
|
client := createOpenAIReqClient(s.tokenURL, proxyURL)
|
||||||
|
|
||||||
if redirectURI == "" {
|
if redirectURI == "" {
|
||||||
redirectURI = openai.DefaultRedirectURI
|
redirectURI = openai.DefaultRedirectURI
|
||||||
@@ -54,7 +55,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||||
client := createOpenAIReqClient(proxyURL)
|
client := createOpenAIReqClient(s.tokenURL, proxyURL)
|
||||||
|
|
||||||
formData := url.Values{}
|
formData := url.Values{}
|
||||||
formData.Set("grant_type", "refresh_token")
|
formData.Set("grant_type", "refresh_token")
|
||||||
@@ -81,9 +82,14 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
|||||||
return &tokenResp, nil
|
return &tokenResp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createOpenAIReqClient(proxyURL string) *req.Client {
|
func createOpenAIReqClient(tokenURL, proxyURL string) *req.Client {
|
||||||
|
forceHTTP2 := false
|
||||||
|
if parsedURL, err := url.Parse(tokenURL); err == nil {
|
||||||
|
forceHTTP2 = strings.EqualFold(parsedURL.Scheme, "https")
|
||||||
|
}
|
||||||
return getSharedReqClient(reqClientOptions{
|
return getSharedReqClient(reqClientOptions{
|
||||||
ProxyURL: proxyURL,
|
ProxyURL: proxyURL,
|
||||||
Timeout: 60 * time.Second,
|
Timeout: 120 * time.Second,
|
||||||
|
ForceHTTP2: forceHTTP2,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -244,6 +244,13 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
|
|||||||
require.ErrorContains(s.T(), err, "status 401")
|
require.ErrorContains(s.T(), err, "status 401")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
|
||||||
|
client := NewOpenAIOAuthClient()
|
||||||
|
svc, ok := client.(*openaiOAuthService)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, openai.TokenURL, svc.tokenURL)
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIOAuthServiceSuite(t *testing.T) {
|
func TestOpenAIOAuthServiceSuite(t *testing.T) {
|
||||||
suite.Run(t, new(OpenAIOAuthServiceSuite))
|
suite.Run(t, new(OpenAIOAuthServiceSuite))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ type reqClientOptions struct {
|
|||||||
ProxyURL string // 代理 URL(支持 http/https/socks5)
|
ProxyURL string // 代理 URL(支持 http/https/socks5)
|
||||||
Timeout time.Duration // 请求超时时间
|
Timeout time.Duration // 请求超时时间
|
||||||
Impersonate bool // 是否模拟 Chrome 浏览器指纹
|
Impersonate bool // 是否模拟 Chrome 浏览器指纹
|
||||||
|
ForceHTTP2 bool // 是否强制使用 HTTP/2
|
||||||
}
|
}
|
||||||
|
|
||||||
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
|
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
|
||||||
@@ -41,6 +42,9 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
client := req.C().SetTimeout(opts.Timeout)
|
client := req.C().SetTimeout(opts.Timeout)
|
||||||
|
if opts.ForceHTTP2 {
|
||||||
|
client = client.EnableForceHTTP2()
|
||||||
|
}
|
||||||
if opts.Impersonate {
|
if opts.Impersonate {
|
||||||
client = client.ImpersonateChrome()
|
client = client.ImpersonateChrome()
|
||||||
}
|
}
|
||||||
@@ -56,9 +60,10 @@ func getSharedReqClient(opts reqClientOptions) *req.Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buildReqClientKey(opts reqClientOptions) string {
|
func buildReqClientKey(opts reqClientOptions) string {
|
||||||
return fmt.Sprintf("%s|%s|%t",
|
return fmt.Sprintf("%s|%s|%t|%t",
|
||||||
strings.TrimSpace(opts.ProxyURL),
|
strings.TrimSpace(opts.ProxyURL),
|
||||||
opts.Timeout.String(),
|
opts.Timeout.String(),
|
||||||
opts.Impersonate,
|
opts.Impersonate,
|
||||||
|
opts.ForceHTTP2,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
102
backend/internal/repository/req_client_pool_test.go
Normal file
102
backend/internal/repository/req_client_pool_test.go
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func forceHTTPVersion(t *testing.T, client *req.Client) string {
|
||||||
|
t.Helper()
|
||||||
|
transport := client.GetTransport()
|
||||||
|
field := reflect.ValueOf(transport).Elem().FieldByName("forceHttpVersion")
|
||||||
|
require.True(t, field.IsValid(), "forceHttpVersion field not found")
|
||||||
|
require.True(t, field.CanAddr(), "forceHttpVersion field not addressable")
|
||||||
|
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
base := reqClientOptions{
|
||||||
|
ProxyURL: "http://proxy.local:8080",
|
||||||
|
Timeout: time.Second,
|
||||||
|
}
|
||||||
|
clientDefault := getSharedReqClient(base)
|
||||||
|
|
||||||
|
force := base
|
||||||
|
force.ForceHTTP2 = true
|
||||||
|
clientForce := getSharedReqClient(force)
|
||||||
|
|
||||||
|
require.NotSame(t, clientDefault, clientForce)
|
||||||
|
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
opts := reqClientOptions{
|
||||||
|
ProxyURL: "http://proxy.local:8080",
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
first := getSharedReqClient(opts)
|
||||||
|
second := getSharedReqClient(opts)
|
||||||
|
require.Same(t, first, second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
opts := reqClientOptions{
|
||||||
|
ProxyURL: " http://proxy.local:8080 ",
|
||||||
|
Timeout: 3 * time.Second,
|
||||||
|
}
|
||||||
|
key := buildReqClientKey(opts)
|
||||||
|
sharedReqClients.Store(key, "invalid")
|
||||||
|
|
||||||
|
client := getSharedReqClient(opts)
|
||||||
|
|
||||||
|
require.NotNil(t, client)
|
||||||
|
loaded, ok := sharedReqClients.Load(key)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.IsType(t, "invalid", loaded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
opts := reqClientOptions{
|
||||||
|
ProxyURL: " http://proxy.local:8080 ",
|
||||||
|
Timeout: 4 * time.Second,
|
||||||
|
Impersonate: true,
|
||||||
|
}
|
||||||
|
client := getSharedReqClient(opts)
|
||||||
|
|
||||||
|
require.NotNil(t, client)
|
||||||
|
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateOpenAIReqClient_ForceHTTP2Enabled(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
|
||||||
|
require.Equal(t, "2", forceHTTPVersion(t, client))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateOpenAIReqClient_ForceHTTP2DisabledForHTTP(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
client := createOpenAIReqClient("http://localhost/oauth/token", "http://proxy.local:8080")
|
||||||
|
require.Equal(t, "", forceHTTPVersion(t, client))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
client := createOpenAIReqClient("https://auth.openai.com/oauth/token", "http://proxy.local:8080")
|
||||||
|
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
|
||||||
|
sharedReqClients = sync.Map{}
|
||||||
|
client := createGeminiReqClient("http://proxy.local:8080")
|
||||||
|
require.Equal(t, "", forceHTTPVersion(t, client))
|
||||||
|
}
|
||||||
@@ -64,7 +64,9 @@ func (r *usageCleanupRepository) ListTasks(ctx context.Context, params paginatio
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer func() {
|
||||||
|
_ = rows.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
tasks := make([]service.UsageCleanupTask, 0)
|
tasks := make([]service.UsageCleanupTask, 0)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
@@ -295,7 +297,9 @@ func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filte
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer func() {
|
||||||
|
_ = rows.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
var deleted int64
|
var deleted int64
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
@@ -357,7 +361,6 @@ func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any)
|
|||||||
if filters.BillingType != nil {
|
if filters.BillingType != nil {
|
||||||
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx))
|
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx))
|
||||||
args = append(args, *filters.BillingType)
|
args = append(args, *filters.BillingType)
|
||||||
idx++
|
|
||||||
}
|
}
|
||||||
return strings.Join(conditions, " AND "), args
|
return strings.Join(conditions, " AND "), args
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ var (
|
|||||||
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
|
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
|
||||||
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
|
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
|
||||||
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
|
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
|
||||||
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
|
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
|
||||||
errDashboardAggregationRunning = errors.New("聚合作业正在运行")
|
errDashboardAggregationRunning = errors.New("聚合作业正在运行")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -151,6 +151,9 @@ func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageClean
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *UsageCleanupService) runOnce() {
|
func (s *UsageCleanupService) runOnce() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||||
log.Printf("[UsageCleanup] run_once skipped: already_running=true")
|
log.Printf("[UsageCleanup] run_once skipped: already_running=true")
|
||||||
return
|
return
|
||||||
@@ -158,7 +161,7 @@ func (s *UsageCleanupService) runOnce() {
|
|||||||
defer atomic.StoreInt32(&s.running, 0)
|
defer atomic.StoreInt32(&s.running, 0)
|
||||||
|
|
||||||
parent := context.Background()
|
parent := context.Background()
|
||||||
if s != nil && s.workerCtx != nil {
|
if s.workerCtx != nil {
|
||||||
parent = s.workerCtx
|
parent = s.workerCtx
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(parent, s.taskTimeout())
|
ctx, cancel := context.WithTimeout(parent, s.taskTimeout())
|
||||||
|
|||||||
@@ -266,9 +266,11 @@ func TestUsageCleanupServiceCreateTaskRepoError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
|
func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
|
||||||
|
start := time.Now()
|
||||||
|
end := start.Add(2 * time.Hour)
|
||||||
repo := &cleanupRepoStub{
|
repo := &cleanupRepoStub{
|
||||||
claimQueue: []*UsageCleanupTask{
|
claimQueue: []*UsageCleanupTask{
|
||||||
{ID: 5, Filters: UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(2 * time.Hour)}},
|
{ID: 5, Filters: UsageCleanupFilters{StartTime: start, EndTime: end}},
|
||||||
},
|
},
|
||||||
deleteQueue: []cleanupDeleteResponse{
|
deleteQueue: []cleanupDeleteResponse{
|
||||||
{deleted: 2},
|
{deleted: 2},
|
||||||
@@ -284,6 +286,9 @@ func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
|
|||||||
repo.mu.Lock()
|
repo.mu.Lock()
|
||||||
defer repo.mu.Unlock()
|
defer repo.mu.Unlock()
|
||||||
require.Len(t, repo.deleteCalls, 3)
|
require.Len(t, repo.deleteCalls, 3)
|
||||||
|
require.Equal(t, 2, repo.deleteCalls[0].limit)
|
||||||
|
require.True(t, repo.deleteCalls[0].filters.StartTime.Equal(start))
|
||||||
|
require.True(t, repo.deleteCalls[0].filters.EndTime.Equal(end))
|
||||||
require.Len(t, repo.markSucceeded, 1)
|
require.Len(t, repo.markSucceeded, 1)
|
||||||
require.Empty(t, repo.markFailed)
|
require.Empty(t, repo.markFailed)
|
||||||
require.Equal(t, int64(5), repo.markSucceeded[0].taskID)
|
require.Equal(t, int64(5), repo.markSucceeded[0].taskID)
|
||||||
|
|||||||
Reference in New Issue
Block a user