perf(后端): 完成性能优化与连接池配置
新增 DB/Redis 连接池配置与校验,并补充单测 网关请求体大小限制与 413 处理 HTTP/req 客户端池化并调整上游连接池默认值 并发槽位改为 ZSET+Lua 与指数退避 用量统计改 SQL 聚合并新增索引迁移 计费缓存写入改工作池并补测试/基准 测试: 在 backend/ 下运行 go test ./...
This commit is contained in:
@@ -52,6 +52,9 @@ type UsageLogRepository interface {
|
||||
// Aggregated stats (optimized)
|
||||
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
|
||||
}
|
||||
|
||||
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -27,6 +28,45 @@ type subscriptionCacheData struct {
|
||||
Version int64
|
||||
}
|
||||
|
||||
// 缓存写入任务类型
|
||||
type cacheWriteKind int
|
||||
|
||||
const (
|
||||
cacheWriteSetBalance cacheWriteKind = iota
|
||||
cacheWriteSetSubscription
|
||||
cacheWriteUpdateSubscriptionUsage
|
||||
cacheWriteDeductBalance
|
||||
)
|
||||
|
||||
// 异步缓存写入工作池配置
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现在请求热路径中使用 goroutine 异步更新缓存,存在以下问题:
|
||||
// 1. 每次请求创建新 goroutine,高并发下产生大量短生命周期 goroutine
|
||||
// 2. 无法控制并发数量,可能导致 Redis 连接耗尽
|
||||
// 3. goroutine 创建/销毁带来额外开销
|
||||
//
|
||||
// 新实现使用固定大小的工作池:
|
||||
// 1. 预创建 10 个 worker goroutine,避免频繁创建销毁
|
||||
// 2. 使用带缓冲的 channel(1000)作为任务队列,平滑写入峰值
|
||||
// 3. 非阻塞写入,队列满时丢弃任务(缓存最终一致性可接受)
|
||||
// 4. 统一超时控制,避免慢操作阻塞工作池
|
||||
const (
|
||||
cacheWriteWorkerCount = 10 // 工作协程数量
|
||||
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
|
||||
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
|
||||
)
|
||||
|
||||
// cacheWriteTask 缓存写入任务
|
||||
type cacheWriteTask struct {
|
||||
kind cacheWriteKind
|
||||
userID int64
|
||||
groupID int64
|
||||
balance float64
|
||||
amount float64
|
||||
subscriptionData *subscriptionCacheData
|
||||
}
|
||||
|
||||
// BillingCacheService 计费缓存服务
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
@@ -34,16 +74,81 @@ type BillingCacheService struct {
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
cfg *config.Config
|
||||
|
||||
cacheWriteChan chan cacheWriteTask
|
||||
cacheWriteWg sync.WaitGroup
|
||||
cacheWriteStopOnce sync.Once
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
|
||||
return &BillingCacheService{
|
||||
svc := &BillingCacheService{
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.startCacheWriteWorkers()
|
||||
return svc
|
||||
}
|
||||
|
||||
// Stop 关闭缓存写入工作池
|
||||
func (s *BillingCacheService) Stop() {
|
||||
s.cacheWriteStopOnce.Do(func() {
|
||||
if s.cacheWriteChan == nil {
|
||||
return
|
||||
}
|
||||
close(s.cacheWriteChan)
|
||||
s.cacheWriteWg.Wait()
|
||||
s.cacheWriteChan = nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) startCacheWriteWorkers() {
|
||||
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
|
||||
for i := 0; i < cacheWriteWorkerCount; i++ {
|
||||
s.cacheWriteWg.Add(1)
|
||||
go s.cacheWriteWorker()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) {
|
||||
if s.cacheWriteChan == nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = recover()
|
||||
}()
|
||||
select {
|
||||
case s.cacheWriteChan <- task:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) cacheWriteWorker() {
|
||||
defer s.cacheWriteWg.Done()
|
||||
for task := range s.cacheWriteChan {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
switch task.kind {
|
||||
case cacheWriteSetBalance:
|
||||
s.setBalanceCache(ctx, task.userID, task.balance)
|
||||
case cacheWriteSetSubscription:
|
||||
s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData)
|
||||
case cacheWriteUpdateSubscriptionUsage:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil {
|
||||
log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
|
||||
}
|
||||
}
|
||||
case cacheWriteDeductBalance:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================
|
||||
@@ -70,11 +175,11 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
s.setBalanceCache(cacheCtx, userID, balance)
|
||||
}()
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetBalance,
|
||||
userID: userID,
|
||||
balance: balance,
|
||||
})
|
||||
|
||||
return balance, nil
|
||||
}
|
||||
@@ -98,7 +203,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64,
|
||||
}
|
||||
}
|
||||
|
||||
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
|
||||
// DeductBalanceCache 扣减余额缓存(同步调用)
|
||||
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
@@ -106,6 +211,15 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int
|
||||
return s.cache.DeductUserBalance(ctx, userID, amount)
|
||||
}
|
||||
|
||||
// QueueDeductBalance 异步扣减余额缓存
|
||||
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteDeductBalance,
|
||||
userID: userID,
|
||||
amount: amount,
|
||||
})
|
||||
}
|
||||
|
||||
// InvalidateUserBalance 失效用户余额缓存
|
||||
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
if s.cache == nil {
|
||||
@@ -141,11 +255,12 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
s.setSubscriptionCache(cacheCtx, userID, groupID, data)
|
||||
}()
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetSubscription,
|
||||
userID: userID,
|
||||
groupID: groupID,
|
||||
subscriptionData: data,
|
||||
})
|
||||
|
||||
return data, nil
|
||||
}
|
||||
@@ -199,7 +314,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
|
||||
// UpdateSubscriptionUsage 更新订阅用量缓存(同步调用)
|
||||
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
@@ -207,6 +322,16 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI
|
||||
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
|
||||
}
|
||||
|
||||
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
|
||||
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteUpdateSubscriptionUsage,
|
||||
userID: userID,
|
||||
groupID: groupID,
|
||||
amount: costUSD,
|
||||
})
|
||||
}
|
||||
|
||||
// InvalidateSubscription 失效指定订阅缓存
|
||||
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
|
||||
if s.cache == nil {
|
||||
|
||||
75
backend/internal/service/billing_cache_service_test.go
Normal file
75
backend/internal/service/billing_cache_service_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type billingCacheWorkerStub struct {
|
||||
balanceUpdates int64
|
||||
subscriptionUpdates int64
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
atomic.AddInt64(&b.balanceUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
atomic.AddInt64(&b.balanceUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
|
||||
atomic.AddInt64(&b.subscriptionUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
atomic.AddInt64(&b.subscriptionUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
start := time.Now()
|
||||
for i := 0; i < cacheWriteBufferSize*2; i++ {
|
||||
svc.QueueDeductBalance(1, 1)
|
||||
}
|
||||
require.Less(t, time.Since(start), 2*time.Second)
|
||||
|
||||
svc.QueueUpdateSubscriptionUsage(1, 2, 1.5)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt64(&cache.balanceUpdates) > 0
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
@@ -9,22 +9,22 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConcurrencyCache defines cache operations for concurrency service
|
||||
// Uses independent keys per request slot with native Redis TTL for automatic cleanup
|
||||
// ConcurrencyCache 定义并发控制的缓存接口
|
||||
// 使用有序集合存储槽位,按时间戳清理过期条目
|
||||
type ConcurrencyCache interface {
|
||||
// Account slot management - each slot is a separate key with independent TTL
|
||||
// Key format: concurrency:account:{accountID}:{requestID}
|
||||
// 账号槽位管理
|
||||
// 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID)
|
||||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// User slot management - each slot is a separate key with independent TTL
|
||||
// Key format: concurrency:user:{userID}:{requestID}
|
||||
// 用户槽位管理
|
||||
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
|
||||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
|
||||
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
|
||||
|
||||
// Wait queue - uses counter with TTL set only on creation
|
||||
// 等待队列计数(只在首次创建时设置 TTL)
|
||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
)
|
||||
|
||||
type CRSSyncService struct {
|
||||
@@ -193,7 +195,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
return nil, errors.New("username and password are required")
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 20 * time.Second}
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 20 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 20 * time.Second}
|
||||
}
|
||||
|
||||
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
|
||||
if err != nil {
|
||||
|
||||
70
backend/internal/service/gateway_request.go
Normal file
70
backend/internal/service/gateway_request.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ParsedRequest 保存网关请求的预解析结果
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现在多个位置重复解析请求体(Handler、Service 各解析一次):
|
||||
// 1. gateway_handler.go 解析获取 model 和 stream
|
||||
// 2. gateway_service.go 再次解析获取 system、messages、metadata
|
||||
// 3. GenerateSessionHash 又一次解析获取会话哈希所需字段
|
||||
//
|
||||
// 新实现一次解析,多处复用:
|
||||
// 1. 在 Handler 层统一调用 ParseGatewayRequest 一次性解析
|
||||
// 2. 将解析结果 ParsedRequest 传递给 Service 层
|
||||
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
|
||||
type ParsedRequest struct {
|
||||
Body []byte // 原始请求体(保留用于转发)
|
||||
Model string // 请求的模型名称
|
||||
Stream bool // 是否为流式请求
|
||||
MetadataUserID string // metadata.user_id(用于会话亲和)
|
||||
System any // system 字段内容
|
||||
Messages []any // messages 数组
|
||||
HasSystem bool // 是否包含 system 字段
|
||||
}
|
||||
|
||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果
|
||||
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
|
||||
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
}
|
||||
|
||||
if rawModel, exists := req["model"]; exists {
|
||||
model, ok := rawModel.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid model field type")
|
||||
}
|
||||
parsed.Model = model
|
||||
}
|
||||
if rawStream, exists := req["stream"]; exists {
|
||||
stream, ok := rawStream.(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid stream field type")
|
||||
}
|
||||
parsed.Stream = stream
|
||||
}
|
||||
if metadata, ok := req["metadata"].(map[string]any); ok {
|
||||
if userID, ok := metadata["user_id"].(string); ok {
|
||||
parsed.MetadataUserID = userID
|
||||
}
|
||||
}
|
||||
if system, ok := req["system"]; ok && system != nil {
|
||||
parsed.HasSystem = true
|
||||
parsed.System = system
|
||||
}
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
parsed.Messages = messages
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
38
backend/internal/service/gateway_request_test.go
Normal file
38
backend/internal/service/gateway_request_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseGatewayRequest(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
|
||||
require.True(t, parsed.Stream)
|
||||
require.Equal(t, "session_123e4567-e89b-12d3-a456-426614174000", parsed.MetadataUserID)
|
||||
require.True(t, parsed.HasSystem)
|
||||
require.NotNil(t, parsed.System)
|
||||
require.Len(t, parsed.Messages, 1)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3","system":null}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.False(t, parsed.HasSystem)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
|
||||
body := []byte(`{"model":123}`)
|
||||
_, err := ParseGatewayRequest(body)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
|
||||
body := []byte(`{"stream":"true"}`)
|
||||
_, err := ParseGatewayRequest(body)
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -33,7 +32,10 @@ const (
|
||||
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
var (
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
)
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
var allowedHeaders = map[string]bool{
|
||||
@@ -141,40 +143,36 @@ func NewGatewayService(
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSessionHash 从请求体计算粘性会话hash
|
||||
func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
// GenerateSessionHash 从预解析请求计算粘性会话 hash
|
||||
func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
if parsed == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 1. 最高优先级:从metadata.user_id提取session_xxx
|
||||
if metadata, ok := req["metadata"].(map[string]any); ok {
|
||||
if userID, ok := metadata["user_id"].(string); ok {
|
||||
re := regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
if match := re.FindStringSubmatch(userID); len(match) > 1 {
|
||||
return match[1]
|
||||
}
|
||||
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
|
||||
if parsed.MetadataUserID != "" {
|
||||
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 {
|
||||
return match[1]
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 提取带cache_control: {type: "ephemeral"}的内容
|
||||
cacheableContent := s.extractCacheableContent(req)
|
||||
// 2. 提取带 cache_control: {type: "ephemeral"} 的内容
|
||||
cacheableContent := s.extractCacheableContent(parsed)
|
||||
if cacheableContent != "" {
|
||||
return s.hashContent(cacheableContent)
|
||||
}
|
||||
|
||||
// 3. Fallback: 使用system内容
|
||||
if system := req["system"]; system != nil {
|
||||
systemText := s.extractTextFromSystem(system)
|
||||
// 3. Fallback: 使用 system 内容
|
||||
if parsed.System != nil {
|
||||
systemText := s.extractTextFromSystem(parsed.System)
|
||||
if systemText != "" {
|
||||
return s.hashContent(systemText)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 最后fallback: 使用第一条消息
|
||||
if messages, ok := req["messages"].([]any); ok && len(messages) > 0 {
|
||||
if firstMsg, ok := messages[0].(map[string]any); ok {
|
||||
// 4. 最后 fallback: 使用第一条消息
|
||||
if len(parsed.Messages) > 0 {
|
||||
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
|
||||
msgText := s.extractTextFromContent(firstMsg["content"])
|
||||
if msgText != "" {
|
||||
return s.hashContent(msgText)
|
||||
@@ -185,36 +183,38 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractCacheableContent(req map[string]any) string {
|
||||
var content string
|
||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||
if parsed == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 检查system中的cacheable内容
|
||||
if system, ok := req["system"].([]any); ok {
|
||||
var builder strings.Builder
|
||||
|
||||
// 检查 system 中的 cacheable 内容
|
||||
if system, ok := parsed.System.([]any); ok {
|
||||
for _, part := range system {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||
if cc["type"] == "ephemeral" {
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
content += text
|
||||
builder.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
systemText := builder.String()
|
||||
|
||||
// 检查messages中的cacheable内容
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
for _, msg := range messages {
|
||||
if msgMap, ok := msg.(map[string]any); ok {
|
||||
if msgContent, ok := msgMap["content"].([]any); ok {
|
||||
for _, part := range msgContent {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||
if cc["type"] == "ephemeral" {
|
||||
// 找到cacheable内容,提取第一条消息的文本
|
||||
return s.extractTextFromContent(msgMap["content"])
|
||||
}
|
||||
// 检查 messages 中的 cacheable 内容
|
||||
for _, msg := range parsed.Messages {
|
||||
if msgMap, ok := msg.(map[string]any); ok {
|
||||
if msgContent, ok := msgMap["content"].([]any); ok {
|
||||
for _, part := range msgContent {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||
if cc["type"] == "ephemeral" {
|
||||
return s.extractTextFromContent(msgMap["content"])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -223,7 +223,7 @@ func (s *GatewayService) extractCacheableContent(req map[string]any) string {
|
||||
}
|
||||
}
|
||||
|
||||
return content
|
||||
return systemText
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractTextFromSystem(system any) string {
|
||||
@@ -588,19 +588,17 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 解析请求获取model和stream
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
if parsed == nil {
|
||||
return nil, fmt.Errorf("parse request: empty request")
|
||||
}
|
||||
|
||||
if !gjson.GetBytes(body, "system").Exists() {
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
reqStream := parsed.Stream
|
||||
|
||||
if !parsed.HasSystem {
|
||||
body, _ = sjson.SetBytes(body, "system", []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
@@ -613,13 +611,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := req.Model
|
||||
originalModel := reqModel
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mappedModel := account.GetMappedModel(req.Model)
|
||||
if mappedModel != req.Model {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
// 替换请求体中的模型名
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
req.Model = mappedModel
|
||||
reqModel = mappedModel
|
||||
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
|
||||
}
|
||||
}
|
||||
@@ -640,7 +638,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -692,8 +690,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// 处理正常响应
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
if req.Stream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model)
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
|
||||
if err != nil {
|
||||
if err.Error() == "have error in stream" {
|
||||
return nil, &UpstreamFailoverError{
|
||||
@@ -705,7 +703,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, req.Model)
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -715,13 +713,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Stream: req.Stream,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
@@ -787,7 +785,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
|
||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -795,7 +793,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
|
||||
// getBetaHeader 处理anthropic-beta header
|
||||
// 对于OAuth账号,需要确保包含oauth-2025-04-20
|
||||
func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string {
|
||||
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
|
||||
// 如果客户端传了anthropic-beta
|
||||
if clientBetaHeader != "" {
|
||||
// 已包含oauth beta则直接返回
|
||||
@@ -832,15 +830,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
}
|
||||
|
||||
// 客户端没传,根据模型生成
|
||||
var modelID string
|
||||
var reqMap map[string]any
|
||||
if json.Unmarshal(body, &reqMap) == nil {
|
||||
if m, ok := reqMap["model"].(string); ok {
|
||||
modelID = m
|
||||
}
|
||||
}
|
||||
|
||||
// haiku模型不需要claude-code beta
|
||||
// haiku 模型不需要 claude-code beta
|
||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
return claude.HaikuBetaHeader
|
||||
}
|
||||
@@ -1248,13 +1238,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
log.Printf("Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost); err != nil {
|
||||
log.Printf("Update subscription cache failed: %v", err)
|
||||
}
|
||||
}()
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
@@ -1263,13 +1247,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
log.Printf("Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Update balance cache failed: %v", err)
|
||||
}
|
||||
}()
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1281,7 +1259,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
|
||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||
// 特点:不记录使用量、仅支持非流式响应
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
|
||||
if parsed == nil {
|
||||
s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return fmt.Errorf("parse request: empty request")
|
||||
}
|
||||
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
|
||||
// Antigravity 账户不支持 count_tokens 转发,返回估算值
|
||||
// 参考 Antigravity-Manager 和 proxycast 实现
|
||||
if account.Platform == PlatformAntigravity {
|
||||
@@ -1291,14 +1277,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
|
||||
// 应用模型映射(仅对 apikey 类型账号)
|
||||
if account.Type == AccountTypeApiKey {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
|
||||
mappedModel := account.GetMappedModel(req.Model)
|
||||
if mappedModel != req.Model {
|
||||
if reqModel != "" {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name)
|
||||
reqModel = mappedModel
|
||||
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1311,7 +1295,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 构建上游请求
|
||||
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
|
||||
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
||||
return err
|
||||
@@ -1363,7 +1347,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// buildCountTokensRequest 构建 count_tokens 上游请求
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
// 确定目标 URL
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
@@ -1424,7 +1408,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
|
||||
// OAuth 账号:处理 anthropic-beta header
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
}
|
||||
|
||||
return req, nil
|
||||
|
||||
50
backend/internal/service/gateway_service_benchmark_test.go
Normal file
50
backend/internal/service/gateway_service_benchmark_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var benchmarkStringSink string
|
||||
|
||||
// BenchmarkGenerateSessionHash_Metadata 关注 JSON 解析与正则匹配开销。
|
||||
func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
|
||||
svc := &GatewayService{}
|
||||
body := []byte(`{"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"messages":[{"content":"hello"}]}`)
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
if err != nil {
|
||||
b.Fatalf("解析请求失败: %v", err)
|
||||
}
|
||||
benchmarkStringSink = svc.GenerateSessionHash(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkExtractCacheableContent_System 关注字符串拼接路径的性能。
|
||||
func BenchmarkExtractCacheableContent_System(b *testing.B) {
|
||||
svc := &GatewayService{}
|
||||
req := buildSystemCacheableRequest(12)
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkStringSink = svc.extractCacheableContent(req)
|
||||
}
|
||||
}
|
||||
|
||||
func buildSystemCacheableRequest(parts int) *ParsedRequest {
|
||||
systemParts := make([]any, 0, parts)
|
||||
for i := 0; i < parts; i++ {
|
||||
systemParts = append(systemParts, map[string]any{
|
||||
"text": "system_part_" + strconv.Itoa(i),
|
||||
"cache_control": map[string]any{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
})
|
||||
}
|
||||
return &ParsedRequest{
|
||||
System: systemParts,
|
||||
HasSystem: true,
|
||||
}
|
||||
}
|
||||
@@ -921,7 +921,10 @@ func sleepGeminiBackoff(attempt int) {
|
||||
time.Sleep(sleepFor)
|
||||
}
|
||||
|
||||
var sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)
|
||||
var (
|
||||
sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)
|
||||
retryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`)
|
||||
)
|
||||
|
||||
func sanitizeUpstreamErrorMessage(msg string) string {
|
||||
if msg == "" {
|
||||
@@ -1925,7 +1928,6 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 {
|
||||
}
|
||||
|
||||
// Match "Please retry in Xs"
|
||||
retryInRegex := regexp.MustCompile(`Please retry in ([0-9.]+)s`)
|
||||
matches := retryInRegex.FindStringSubmatch(string(body))
|
||||
if len(matches) == 2 {
|
||||
if dur, err := time.ParseDuration(matches[1] + "s"); err == nil {
|
||||
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
)
|
||||
|
||||
type GeminiOAuthService struct {
|
||||
@@ -497,11 +497,12 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
if proxyURLParsed, err := url.Parse(strings.TrimSpace(proxyURL)); err == nil {
|
||||
client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURLParsed)}
|
||||
}
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: strings.TrimSpace(proxyURL),
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
||||
@@ -768,20 +768,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
if isSubscriptionBilling {
|
||||
if cost.TotalCost > 0 {
|
||||
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}()
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost)
|
||||
}()
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
var (
|
||||
openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`)
|
||||
openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`)
|
||||
)
|
||||
|
||||
// LiteLLMModelPricing LiteLLM价格数据结构
|
||||
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
||||
type LiteLLMModelPricing struct {
|
||||
@@ -595,11 +600,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
// 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号)
|
||||
// 3. 最终回退到 DefaultTestModel (gpt-5.1-codex)
|
||||
func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
// 正则匹配日期后缀 (如 -20251222)
|
||||
datePattern := regexp.MustCompile(`-\d{8}$`)
|
||||
|
||||
// 尝试的回退变体
|
||||
variants := s.generateOpenAIModelVariants(model, datePattern)
|
||||
variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern)
|
||||
|
||||
for _, variant := range variants {
|
||||
if pricing, ok := s.pricingData[variant]; ok {
|
||||
@@ -638,14 +640,13 @@ func (s *PricingService) generateOpenAIModelVariants(model string, datePattern *
|
||||
|
||||
// 2. 提取基础版本号: gpt-5.2-codex -> gpt-5.2
|
||||
// 只匹配纯数字版本号格式 gpt-X 或 gpt-X.Y,不匹配 gpt-4o 这种带字母后缀的
|
||||
basePattern := regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`)
|
||||
if matches := basePattern.FindStringSubmatch(model); len(matches) > 1 {
|
||||
if matches := openAIModelBasePattern.FindStringSubmatch(model); len(matches) > 1 {
|
||||
addVariant(matches[1])
|
||||
}
|
||||
|
||||
// 3. 同时去掉日期后再提取基础版本号
|
||||
if withoutDate != model {
|
||||
if matches := basePattern.FindStringSubmatch(withoutDate); len(matches) > 1 {
|
||||
if matches := openAIModelBasePattern.FindStringSubmatch(withoutDate); len(matches) > 1 {
|
||||
addVariant(matches[1])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,22 +186,40 @@ func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, sta
|
||||
|
||||
// GetStatsByAccount 获取账号的使用统计
|
||||
func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
logs, _, err := s.usageRepo.ListByAccountAndTimeRange(ctx, accountID, startTime, endTime)
|
||||
stats, err := s.usageRepo.GetAccountStatsAggregated(ctx, accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
return nil, fmt.Errorf("get account stats: %w", err)
|
||||
}
|
||||
|
||||
return s.calculateStats(logs), nil
|
||||
return &UsageStats{
|
||||
TotalRequests: stats.TotalRequests,
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheTokens: stats.TotalCacheTokens,
|
||||
TotalTokens: stats.TotalTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
TotalActualCost: stats.TotalActualCost,
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStatsByModel 获取模型的使用统计
|
||||
func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
logs, _, err := s.usageRepo.ListByModelAndTimeRange(ctx, modelName, startTime, endTime)
|
||||
stats, err := s.usageRepo.GetModelStatsAggregated(ctx, modelName, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
return nil, fmt.Errorf("get model stats: %w", err)
|
||||
}
|
||||
|
||||
return s.calculateStats(logs), nil
|
||||
return &UsageStats{
|
||||
TotalRequests: stats.TotalRequests,
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheTokens: stats.TotalCacheTokens,
|
||||
TotalTokens: stats.TotalTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
TotalActualCost: stats.TotalActualCost,
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetDailyStats 获取每日使用统计(最近N天)
|
||||
@@ -209,54 +227,12 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
|
||||
endTime := time.Now()
|
||||
startTime := endTime.AddDate(0, 0, -days)
|
||||
|
||||
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
|
||||
stats, err := s.usageRepo.GetDailyStatsAggregated(ctx, userID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
return nil, fmt.Errorf("get daily stats: %w", err)
|
||||
}
|
||||
|
||||
// 按日期分组统计
|
||||
dailyStats := make(map[string]*UsageStats)
|
||||
for _, log := range logs {
|
||||
dateKey := log.CreatedAt.Format("2006-01-02")
|
||||
if _, exists := dailyStats[dateKey]; !exists {
|
||||
dailyStats[dateKey] = &UsageStats{}
|
||||
}
|
||||
|
||||
stats := dailyStats[dateKey]
|
||||
stats.TotalRequests++
|
||||
stats.TotalInputTokens += int64(log.InputTokens)
|
||||
stats.TotalOutputTokens += int64(log.OutputTokens)
|
||||
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
|
||||
stats.TotalTokens += int64(log.TotalTokens())
|
||||
stats.TotalCost += log.TotalCost
|
||||
stats.TotalActualCost += log.ActualCost
|
||||
|
||||
if log.DurationMs != nil {
|
||||
stats.AverageDurationMs += float64(*log.DurationMs)
|
||||
}
|
||||
}
|
||||
|
||||
// 计算平均值并转换为数组
|
||||
result := make([]map[string]any, 0, len(dailyStats))
|
||||
for date, stats := range dailyStats {
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs /= float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
result = append(result, map[string]any{
|
||||
"date": date,
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
"total_output_tokens": stats.TotalOutputTokens,
|
||||
"total_cache_tokens": stats.TotalCacheTokens,
|
||||
"total_tokens": stats.TotalTokens,
|
||||
"total_cost": stats.TotalCost,
|
||||
"total_actual_cost": stats.TotalActualCost,
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// calculateStats 计算统计数据
|
||||
|
||||
Reference in New Issue
Block a user