perf(后端): 完成性能优化与连接池配置
新增 DB/Redis 连接池配置与校验,并补充单测 网关请求体大小限制与 413 处理 HTTP/req 客户端池化并调整上游连接池默认值 并发槽位改为 ZSET+Lua 与指数退避 用量统计改 SQL 聚合并新增索引迁移 计费缓存写入改工作池并补测试/基准 测试: 在 backend/ 下运行 go test ./...
This commit is contained in:
@@ -233,15 +233,11 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
}
|
||||
|
||||
func createReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().
|
||||
ImpersonateChrome().
|
||||
SetTimeout(60 * time.Second)
|
||||
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
|
||||
return client
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 60 * time.Second,
|
||||
Impersonate: true,
|
||||
})
|
||||
}
|
||||
|
||||
func prefix(s string, n int) string {
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
@@ -23,20 +23,12 @@ func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||
}
|
||||
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to get default transport")
|
||||
}
|
||||
transport = transport.Clone()
|
||||
if proxyURL != "" {
|
||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
||||
|
||||
@@ -3,67 +3,90 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 并发控制缓存常量定义
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现使用 SCAN 命令遍历独立的槽位键(concurrency:account:{id}:{requestID}),
|
||||
// 在高并发场景下 SCAN 需要多次往返,且遍历大量键时性能下降明显。
|
||||
//
|
||||
// 新实现改用 Redis 有序集合(Sorted Set):
|
||||
// 1. 每个账号/用户只有一个键,成员为 requestID,分数为时间戳
|
||||
// 2. 使用 ZCARD 原子获取并发数,时间复杂度 O(1)
|
||||
// 3. 使用 ZREMRANGEBYSCORE 清理过期槽位,避免手动管理 TTL
|
||||
// 4. 单次 Redis 调用完成计数,减少网络往返
|
||||
const (
|
||||
// Key prefixes for independent slot keys
|
||||
// Format: concurrency:account:{accountID}:{requestID}
|
||||
// 并发槽位键前缀(有序集合)
|
||||
// 格式: concurrency:account:{accountID}
|
||||
accountSlotKeyPrefix = "concurrency:account:"
|
||||
// Format: concurrency:user:{userID}:{requestID}
|
||||
// 格式: concurrency:user:{userID}
|
||||
userSlotKeyPrefix = "concurrency:user:"
|
||||
// Wait queue keeps counter format: concurrency:wait:{userID}
|
||||
// 等待队列计数器格式: concurrency:wait:{userID}
|
||||
waitQueueKeyPrefix = "concurrency:wait:"
|
||||
|
||||
// Slot TTL - each slot expires independently
|
||||
slotTTL = 5 * time.Minute
|
||||
// 默认槽位过期时间(分钟),可通过配置覆盖
|
||||
defaultSlotTTLMinutes = 15
|
||||
)
|
||||
|
||||
var (
|
||||
// acquireScript uses SCAN to count existing slots and creates new slot if under limit
|
||||
// KEYS[1] = pattern for SCAN (e.g., "concurrency:account:2:*")
|
||||
// KEYS[2] = full slot key (e.g., "concurrency:account:2:req_xxx")
|
||||
// acquireScript 使用有序集合计数并在未达上限时添加槽位
|
||||
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题
|
||||
// KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id})
|
||||
// ARGV[1] = maxConcurrency
|
||||
// ARGV[2] = TTL in seconds
|
||||
// ARGV[2] = TTL(秒)
|
||||
// ARGV[3] = requestID
|
||||
acquireScript = redis.NewScript(`
|
||||
local pattern = KEYS[1]
|
||||
local slotKey = KEYS[2]
|
||||
local key = KEYS[1]
|
||||
local maxConcurrency = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local requestID = ARGV[3]
|
||||
|
||||
-- Count existing slots using SCAN
|
||||
local cursor = "0"
|
||||
local count = 0
|
||||
repeat
|
||||
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
|
||||
cursor = result[1]
|
||||
count = count + #result[2]
|
||||
until cursor == "0"
|
||||
-- 使用 Redis 服务器时间,确保多实例时钟一致
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
|
||||
-- Check if we can acquire a slot
|
||||
-- 清理过期槽位
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
|
||||
-- 检查是否已存在(支持重试场景刷新时间戳)
|
||||
local exists = redis.call('ZSCORE', key, requestID)
|
||||
if exists ~= false then
|
||||
redis.call('ZADD', key, now, requestID)
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
return 1
|
||||
end
|
||||
|
||||
-- 检查是否达到并发上限
|
||||
local count = redis.call('ZCARD', key)
|
||||
if count < maxConcurrency then
|
||||
redis.call('SET', slotKey, '1', 'EX', ttl)
|
||||
redis.call('ZADD', key, now, requestID)
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
return 1
|
||||
end
|
||||
|
||||
return 0
|
||||
`)
|
||||
|
||||
// getCountScript counts slots using SCAN
|
||||
// KEYS[1] = pattern for SCAN
|
||||
// getCountScript 统计有序集合中的槽位数量并清理过期条目
|
||||
// 使用 Redis TIME 命令获取服务器时间
|
||||
// KEYS[1] = 有序集合键
|
||||
// ARGV[1] = TTL(秒)
|
||||
getCountScript = redis.NewScript(`
|
||||
local pattern = KEYS[1]
|
||||
local cursor = "0"
|
||||
local count = 0
|
||||
repeat
|
||||
local result = redis.call('SCAN', cursor, 'MATCH', pattern, 'COUNT', 100)
|
||||
cursor = result[1]
|
||||
count = count + #result[2]
|
||||
until cursor == "0"
|
||||
return count
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
|
||||
-- 使用 Redis 服务器时间
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
return redis.call('ZCARD', key)
|
||||
`)
|
||||
|
||||
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
|
||||
@@ -103,28 +126,29 @@ var (
|
||||
)
|
||||
|
||||
type concurrencyCache struct {
|
||||
rdb *redis.Client
|
||||
rdb *redis.Client
|
||||
slotTTLSeconds int // 槽位过期时间(秒)
|
||||
}
|
||||
|
||||
func NewConcurrencyCache(rdb *redis.Client) service.ConcurrencyCache {
|
||||
return &concurrencyCache{rdb: rdb}
|
||||
// NewConcurrencyCache 创建并发控制缓存
|
||||
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
|
||||
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache {
|
||||
if slotTTLMinutes <= 0 {
|
||||
slotTTLMinutes = defaultSlotTTLMinutes
|
||||
}
|
||||
return &concurrencyCache{
|
||||
rdb: rdb,
|
||||
slotTTLSeconds: slotTTLMinutes * 60,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for key generation
|
||||
func accountSlotKey(accountID int64, requestID string) string {
|
||||
return fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, requestID)
|
||||
func accountSlotKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
func accountSlotPattern(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
func userSlotKey(userID int64, requestID string) string {
|
||||
return fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, requestID)
|
||||
}
|
||||
|
||||
func userSlotPattern(userID int64) string {
|
||||
return fmt.Sprintf("%s%d:*", userSlotKeyPrefix, userID)
|
||||
func userSlotKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func waitQueueKey(userID int64) string {
|
||||
@@ -134,10 +158,9 @@ func waitQueueKey(userID int64) string {
|
||||
// Account slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
pattern := accountSlotPattern(accountID)
|
||||
slotKey := accountSlotKey(accountID, requestID)
|
||||
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
|
||||
key := accountSlotKey(accountID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -145,13 +168,14 @@ func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
slotKey := accountSlotKey(accountID, requestID)
|
||||
return c.rdb.Del(ctx, slotKey).Err()
|
||||
key := accountSlotKey(accountID)
|
||||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
pattern := accountSlotPattern(accountID)
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
|
||||
key := accountSlotKey(accountID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -161,10 +185,9 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
|
||||
// User slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
pattern := userSlotPattern(userID)
|
||||
slotKey := userSlotKey(userID, requestID)
|
||||
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{pattern, slotKey}, maxConcurrency, int(slotTTL.Seconds())).Int()
|
||||
key := userSlotKey(userID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -172,13 +195,14 @@ func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, ma
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
slotKey := userSlotKey(userID, requestID)
|
||||
return c.rdb.Del(ctx, slotKey).Err()
|
||||
key := userSlotKey(userID)
|
||||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
pattern := userSlotPattern(userID)
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{pattern}).Int()
|
||||
key := userSlotKey(userID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -189,7 +213,7 @@ func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64)
|
||||
|
||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
key := waitQueueKey(userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, int(slotTTL.Seconds())).Int()
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
135
backend/internal/repository/concurrency_cache_benchmark_test.go
Normal file
135
backend/internal/repository/concurrency_cache_benchmark_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 基准测试用 TTL 配置
|
||||
const benchSlotTTLMinutes = 15
|
||||
|
||||
var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute
|
||||
|
||||
// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。
|
||||
func BenchmarkAccountConcurrency(b *testing.B) {
|
||||
rdb := newBenchmarkRedisClient(b)
|
||||
defer func() {
|
||||
_ = rdb.Close()
|
||||
}()
|
||||
|
||||
cache := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache)
|
||||
ctx := context.Background()
|
||||
|
||||
for _, size := range []int{10, 100, 1000} {
|
||||
size := size
|
||||
b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) {
|
||||
accountID := time.Now().UnixNano()
|
||||
key := accountSlotKey(accountID)
|
||||
|
||||
b.StopTimer()
|
||||
members := make([]redis.Z, 0, size)
|
||||
now := float64(time.Now().Unix())
|
||||
for i := 0; i < size; i++ {
|
||||
members = append(members, redis.Z{
|
||||
Score: now,
|
||||
Member: fmt.Sprintf("req_%d", i),
|
||||
})
|
||||
}
|
||||
if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil {
|
||||
b.Fatalf("初始化有序集合失败: %v", err)
|
||||
}
|
||||
if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil {
|
||||
b.Fatalf("设置有序集合 TTL 失败: %v", err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil {
|
||||
b.Fatalf("获取并发数量失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
if err := rdb.Del(ctx, key).Err(); err != nil {
|
||||
b.Fatalf("清理有序集合失败: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) {
|
||||
accountID := time.Now().UnixNano()
|
||||
pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
|
||||
keys := make([]string, 0, size)
|
||||
|
||||
b.StopTimer()
|
||||
pipe := rdb.Pipeline()
|
||||
for i := 0; i < size; i++ {
|
||||
key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i)
|
||||
keys = append(keys, key)
|
||||
pipe.Set(ctx, key, "1", benchSlotTTL)
|
||||
}
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
b.Fatalf("初始化扫描键失败: %v", err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := scanSlotCount(ctx, rdb, pattern); err != nil {
|
||||
b.Fatalf("SCAN 计数失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
if err := rdb.Del(ctx, keys...).Err(); err != nil {
|
||||
b.Fatalf("清理扫描键失败: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) {
|
||||
var cursor uint64
|
||||
count := 0
|
||||
for {
|
||||
keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
count += len(keys)
|
||||
if nextCursor == 0 {
|
||||
break
|
||||
}
|
||||
cursor = nextCursor
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func newBenchmarkRedisClient(b *testing.B) *redis.Client {
|
||||
b.Helper()
|
||||
|
||||
redisURL := os.Getenv("TEST_REDIS_URL")
|
||||
if redisURL == "" {
|
||||
b.Skip("未设置 TEST_REDIS_URL,跳过 Redis 基准测试")
|
||||
}
|
||||
|
||||
opt, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err)
|
||||
}
|
||||
|
||||
client := redis.NewClient(opt)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
b.Fatalf("Redis 连接失败: %v", err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
@@ -14,6 +14,12 @@ import (
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// 测试用 TTL 配置(15 分钟,与默认值一致)
|
||||
const testSlotTTLMinutes = 15
|
||||
|
||||
// 测试用 TTL Duration,用于 TTL 断言
|
||||
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
|
||||
|
||||
type ConcurrencyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.ConcurrencyCache
|
||||
@@ -21,7 +27,7 @@ type ConcurrencyCacheSuite struct {
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewConcurrencyCache(s.rdb)
|
||||
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
||||
@@ -54,7 +60,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
|
||||
accountID := int64(11)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID)
|
||||
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot")
|
||||
@@ -62,7 +68,7 @@ func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
|
||||
@@ -139,7 +145,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
||||
userID := int64(200)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID)
|
||||
slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
@@ -147,7 +153,7 @@ func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||
@@ -168,7 +174,7 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
|
||||
@@ -109,9 +109,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
|
||||
}
|
||||
|
||||
func createGeminiReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().SetTimeout(60 * time.Second)
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
return client
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 60 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -76,11 +76,10 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
|
||||
}
|
||||
|
||||
func createGeminiCliReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().SetTimeout(30 * time.Second)
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
return client
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
@@ -17,10 +18,14 @@ type githubReleaseClient struct {
|
||||
}
|
||||
|
||||
func NewGitHubReleaseClient() service.GitHubReleaseClient {
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
return &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
httpClient: sharedClient,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,8 +63,13 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
return err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Minute}
|
||||
resp, err := client.Do(req)
|
||||
downloadClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 10 * time.Minute,
|
||||
})
|
||||
if err != nil {
|
||||
downloadClient = &http.Client{Timeout: 10 * time.Minute}
|
||||
}
|
||||
resp, err := downloadClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,65 +3,104 @@ package repository
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// httpUpstreamService is a generic HTTP upstream service that can be used for
|
||||
// making requests to any HTTP API (Claude, OpenAI, etc.) with optional proxy support.
|
||||
// httpUpstreamService 通用 HTTP 上游服务
|
||||
// 用于向任意 HTTP API(Claude、OpenAI 等)发送请求,支持可选代理
|
||||
//
|
||||
// 性能优化:
|
||||
// 1. 使用 sync.Map 缓存代理客户端实例,避免每次请求都创建新的 http.Client
|
||||
// 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销
|
||||
// 3. 原实现每次请求都 new 一个 http.Client,导致连接无法复用
|
||||
type httpUpstreamService struct {
|
||||
// defaultClient: 无代理时使用的默认客户端(单例复用)
|
||||
defaultClient *http.Client
|
||||
cfg *config.Config
|
||||
// proxyClients: 按代理 URL 缓存的客户端池,避免重复创建
|
||||
proxyClients sync.Map
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewHTTPUpstream creates a new generic HTTP upstream service
|
||||
// NewHTTPUpstream 创建通用 HTTP 上游服务
|
||||
// 使用配置中的连接池参数构建 Transport
|
||||
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
|
||||
return &httpUpstreamService{
|
||||
defaultClient: &http.Client{Transport: transport},
|
||||
defaultClient: &http.Client{Transport: buildUpstreamTransport(cfg, nil)},
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
|
||||
if proxyURL == "" {
|
||||
if strings.TrimSpace(proxyURL) == "" {
|
||||
return s.defaultClient.Do(req)
|
||||
}
|
||||
client := s.createProxyClient(proxyURL)
|
||||
client := s.getOrCreateClient(proxyURL)
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) createProxyClient(proxyURL string) *http.Client {
|
||||
// getOrCreateClient 获取或创建代理客户端
|
||||
// 性能优化:使用 sync.Map 实现无锁缓存,相同代理 URL 复用同一客户端
|
||||
// LoadOrStore 保证并发安全,避免重复创建
|
||||
func (s *httpUpstreamService) getOrCreateClient(proxyURL string) *http.Client {
|
||||
proxyURL = strings.TrimSpace(proxyURL)
|
||||
if proxyURL == "" {
|
||||
return s.defaultClient
|
||||
}
|
||||
// 优先从缓存获取,命中则直接返回
|
||||
if cached, ok := s.proxyClients.Load(proxyURL); ok {
|
||||
return cached.(*http.Client)
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return s.defaultClient
|
||||
}
|
||||
|
||||
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
// 创建新客户端并缓存,LoadOrStore 保证只有一个实例被存储
|
||||
client := &http.Client{Transport: buildUpstreamTransport(s.cfg, parsedURL)}
|
||||
actual, _ := s.proxyClients.LoadOrStore(proxyURL, client)
|
||||
return actual.(*http.Client)
|
||||
}
|
||||
|
||||
// buildUpstreamTransport 构建上游请求的 Transport
|
||||
// 使用配置文件中的连接池参数,支持生产环境调优
|
||||
func buildUpstreamTransport(cfg *config.Config, proxyURL *url.URL) *http.Transport {
|
||||
// 读取配置,使用合理的默认值
|
||||
maxIdleConns := cfg.Gateway.MaxIdleConns
|
||||
if maxIdleConns <= 0 {
|
||||
maxIdleConns = 240
|
||||
}
|
||||
maxIdleConnsPerHost := cfg.Gateway.MaxIdleConnsPerHost
|
||||
if maxIdleConnsPerHost <= 0 {
|
||||
maxIdleConnsPerHost = 120
|
||||
}
|
||||
maxConnsPerHost := cfg.Gateway.MaxConnsPerHost
|
||||
if maxConnsPerHost < 0 {
|
||||
maxConnsPerHost = 240
|
||||
}
|
||||
idleConnTimeout := time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second
|
||||
if idleConnTimeout <= 0 {
|
||||
idleConnTimeout = 300 * time.Second
|
||||
}
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout <= 0 {
|
||||
responseHeaderTimeout = 300 * time.Second
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyURL(parsedURL),
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
MaxIdleConns: maxIdleConns, // 最大空闲连接总数
|
||||
MaxIdleConnsPerHost: maxIdleConnsPerHost, // 每主机最大空闲连接
|
||||
MaxConnsPerHost: maxConnsPerHost, // 每主机最大连接数(含活跃)
|
||||
IdleConnTimeout: idleConnTimeout, // 空闲连接超时
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
|
||||
return &http.Client{Transport: transport}
|
||||
if proxyURL != nil {
|
||||
transport.Proxy = http.ProxyURL(proxyURL)
|
||||
}
|
||||
return transport
|
||||
}
|
||||
|
||||
46
backend/internal/repository/http_upstream_benchmark_test.go
Normal file
46
backend/internal/repository/http_upstream_benchmark_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
var httpClientSink *http.Client
|
||||
|
||||
// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销。
|
||||
func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300},
|
||||
}
|
||||
upstream := NewHTTPUpstream(cfg)
|
||||
svc, ok := upstream.(*httpUpstreamService)
|
||||
if !ok {
|
||||
b.Fatalf("类型断言失败,无法获取 httpUpstreamService")
|
||||
}
|
||||
|
||||
proxyURL := "http://127.0.0.1:8080"
|
||||
b.ReportAllocs()
|
||||
|
||||
b.Run("新建", func(b *testing.B) {
|
||||
parsedProxy, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
b.Fatalf("解析代理地址失败: %v", err)
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
httpClientSink = &http.Client{
|
||||
Transport: buildUpstreamTransport(cfg, parsedProxy),
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("复用", func(b *testing.B) {
|
||||
client := svc.getOrCreateClient(proxyURL)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
httpClientSink = client
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -40,13 +40,13 @@ func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
|
||||
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
|
||||
}
|
||||
|
||||
func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() {
|
||||
func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDefault() {
|
||||
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5}
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
svc, ok := up.(*httpUpstreamService)
|
||||
require.True(s.T(), ok, "expected *httpUpstreamService")
|
||||
|
||||
got := svc.createProxyClient("://bad-proxy-url")
|
||||
got := svc.getOrCreateClient("://bad-proxy-url")
|
||||
require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback")
|
||||
}
|
||||
|
||||
|
||||
@@ -82,12 +82,8 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
}
|
||||
|
||||
func createOpenAIReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().
|
||||
SetTimeout(60 * time.Second)
|
||||
|
||||
if proxyURL != "" {
|
||||
client.SetProxyURL(proxyURL)
|
||||
}
|
||||
|
||||
return client
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 60 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
@@ -16,10 +17,14 @@ type pricingRemoteClient struct {
|
||||
}
|
||||
|
||||
func NewPricingRemoteClient() service.PricingRemoteClient {
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
return &pricingRemoteClient{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
httpClient: sharedClient,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,18 +2,14 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
|
||||
@@ -27,14 +23,14 @@ type proxyProbeService struct {
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
transport, err := createProxyTransport(proxyURL)
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 15 * time.Second,
|
||||
InsecureSkipVerify: true,
|
||||
ProxyStrict: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 15 * time.Second,
|
||||
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
@@ -78,31 +74,3 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
Country: ipInfo.Country,
|
||||
}, latencyMs, nil
|
||||
}
|
||||
|
||||
func createProxyTransport(proxyURL string) (*http.Transport, error) {
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
switch parsedURL.Scheme {
|
||||
case "http", "https":
|
||||
transport.Proxy = http.ProxyURL(parsedURL)
|
||||
case "socks5":
|
||||
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
@@ -34,22 +34,16 @@ func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
|
||||
s.proxySrv = httptest.NewServer(handler)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() {
|
||||
_, err := createProxyTransport("://bad")
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, "://bad")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "invalid proxy URL")
|
||||
require.ErrorContains(s.T(), err, "failed to create proxy client")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() {
|
||||
_, err := createProxyTransport("ftp://127.0.0.1:1")
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, "ftp://127.0.0.1:1")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "unsupported proxy protocol")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() {
|
||||
tr, err := createProxyTransport("socks5://127.0.0.1:1080")
|
||||
require.NoError(s.T(), err, "createProxyTransport")
|
||||
require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5")
|
||||
require.ErrorContains(s.T(), err, "failed to create proxy client")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
||||
|
||||
59
backend/internal/repository/req_client_pool.go
Normal file
59
backend/internal/repository/req_client_pool.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
// reqClientOptions 定义 req 客户端的构建参数
|
||||
type reqClientOptions struct {
|
||||
ProxyURL string // 代理 URL(支持 http/https/socks5)
|
||||
Timeout time.Duration // 请求超时时间
|
||||
Impersonate bool // 是否模拟 Chrome 浏览器指纹
|
||||
}
|
||||
|
||||
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现在每次 OAuth 刷新时都创建新的 req.Client:
|
||||
// 1. claude_oauth_service.go: 每次刷新创建新客户端
|
||||
// 2. openai_oauth_service.go: 每次刷新创建新客户端
|
||||
// 3. gemini_oauth_client.go: 每次刷新创建新客户端
|
||||
//
|
||||
// 新实现使用 sync.Map 缓存客户端:
|
||||
// 1. 相同配置(代理+超时+模拟设置)复用同一客户端
|
||||
// 2. 复用底层连接池,减少 TLS 握手开销
|
||||
// 3. LoadOrStore 保证并发安全,避免重复创建
|
||||
var sharedReqClients sync.Map
|
||||
|
||||
// getSharedReqClient 获取共享的 req 客户端实例
|
||||
// 性能优化:相同配置复用同一客户端,避免重复创建
|
||||
func getSharedReqClient(opts reqClientOptions) *req.Client {
|
||||
key := buildReqClientKey(opts)
|
||||
if cached, ok := sharedReqClients.Load(key); ok {
|
||||
return cached.(*req.Client)
|
||||
}
|
||||
|
||||
client := req.C().SetTimeout(opts.Timeout)
|
||||
if opts.Impersonate {
|
||||
client = client.ImpersonateChrome()
|
||||
}
|
||||
if strings.TrimSpace(opts.ProxyURL) != "" {
|
||||
client.SetProxyURL(strings.TrimSpace(opts.ProxyURL))
|
||||
}
|
||||
|
||||
actual, _ := sharedReqClients.LoadOrStore(key, client)
|
||||
return actual.(*req.Client)
|
||||
}
|
||||
|
||||
func buildReqClientKey(opts reqClientOptions) string {
|
||||
return fmt.Sprintf("%s|%s|%t",
|
||||
strings.TrimSpace(opts.ProxyURL),
|
||||
opts.Timeout.String(),
|
||||
opts.Impersonate,
|
||||
)
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
@@ -20,11 +21,15 @@ type turnstileVerifier struct {
|
||||
}
|
||||
|
||||
func NewTurnstileVerifier() service.TurnstileVerifier {
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 10 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 10 * time.Second}
|
||||
}
|
||||
return &turnstileVerifier{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
verifyURL: turnstileVerifyURL,
|
||||
httpClient: sharedClient,
|
||||
verifyURL: turnstileVerifyURL,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -452,6 +452,161 @@ func (r *usageLogRepository) GetApiKeyStatsAggregated(ctx context.Context, apiKe
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetAccountStatsAggregated 使用 SQL 聚合统计账号使用数据
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现先查询所有日志记录,再在应用层循环计算统计值:
|
||||
// 1. 需要传输大量数据到应用层
|
||||
// 2. 应用层循环计算增加 CPU 和内存开销
|
||||
//
|
||||
// 新实现使用 SQL 聚合函数:
|
||||
// 1. 在数据库层完成 COUNT/SUM/AVG 计算
|
||||
// 2. 只返回单行聚合结果,大幅减少数据传输量
|
||||
// 3. 利用数据库索引优化聚合查询性能
|
||||
func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
||||
`
|
||||
|
||||
var stats usagestats.UsageStats
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{accountID, startTime, endTime},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据
|
||||
// 性能优化:数据库层聚合计算,避免应用层循环统计
|
||||
func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE model = $1 AND created_at >= $2 AND created_at < $3
|
||||
`
|
||||
|
||||
var stats usagestats.UsageStats
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
query,
|
||||
[]any{modelName, startTime, endTime},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据
|
||||
// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计
|
||||
func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) {
|
||||
query := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM-DD') as date,
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
||||
GROUP BY 1
|
||||
ORDER BY 1
|
||||
`
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
result = nil
|
||||
}
|
||||
}()
|
||||
|
||||
result = make([]map[string]any, 0)
|
||||
for rows.Next() {
|
||||
var (
|
||||
date string
|
||||
totalRequests int64
|
||||
totalInputTokens int64
|
||||
totalOutputTokens int64
|
||||
totalCacheTokens int64
|
||||
totalCost float64
|
||||
totalActualCost float64
|
||||
avgDurationMs float64
|
||||
)
|
||||
if err = rows.Scan(
|
||||
&date,
|
||||
&totalRequests,
|
||||
&totalInputTokens,
|
||||
&totalOutputTokens,
|
||||
&totalCacheTokens,
|
||||
&totalCost,
|
||||
&totalActualCost,
|
||||
&avgDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, map[string]any{
|
||||
"date": date,
|
||||
"total_requests": totalRequests,
|
||||
"total_input_tokens": totalInputTokens,
|
||||
"total_output_tokens": totalOutputTokens,
|
||||
"total_cache_tokens": totalCacheTokens,
|
||||
"total_tokens": totalInputTokens + totalOutputTokens + totalCacheTokens,
|
||||
"total_cost": totalCost,
|
||||
"total_actual_cost": totalActualCost,
|
||||
"average_duration_ms": avgDurationMs,
|
||||
})
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
|
||||
// 性能优化:TTL 可配置,支持长时间运行的 LLM 请求场景
|
||||
func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache {
|
||||
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes)
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all repositories
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewUserRepository,
|
||||
@@ -20,7 +29,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewGatewayCache,
|
||||
NewBillingCache,
|
||||
NewApiKeyCache,
|
||||
NewConcurrencyCache,
|
||||
ProvideConcurrencyCache,
|
||||
NewEmailCache,
|
||||
NewIdentityCache,
|
||||
NewRedeemCache,
|
||||
|
||||
Reference in New Issue
Block a user