Merge branch 'dev-release'
This commit is contained in:
@@ -36,8 +36,8 @@ type UsageLogRepository interface {
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
||||
|
||||
// User dashboard stats
|
||||
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
|
||||
|
||||
@@ -1582,6 +1582,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
// ForwardUpstream 透传请求到上游 Antigravity 服务
|
||||
// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token
|
||||
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
sessionID := getSessionID(c)
|
||||
prefix := logPrefix(sessionID, account.Name)
|
||||
|
||||
// 获取上游配置
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if baseURL == "" || apiKey == "" {
|
||||
return nil, fmt.Errorf("upstream account missing base_url or api_key")
|
||||
}
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
|
||||
// 解析请求获取模型信息
|
||||
var claudeReq antigravity.ClaudeRequest
|
||||
if err := json.Unmarshal(body, &claudeReq); err != nil {
|
||||
return nil, fmt.Errorf("parse claude request: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(claudeReq.Model) == "" {
|
||||
return nil, fmt.Errorf("missing model")
|
||||
}
|
||||
originalModel := claudeReq.Model
|
||||
billingModel := originalModel
|
||||
|
||||
// 构建上游请求 URL
|
||||
upstreamURL := baseURL + "/v1/messages"
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
|
||||
|
||||
// 透传 Claude 相关 headers
|
||||
if v := c.GetHeader("anthropic-version"); v != "" {
|
||||
req.Header.Set("anthropic-version", v)
|
||||
}
|
||||
if v := c.GetHeader("anthropic-beta"); v != "" {
|
||||
req.Header.Set("anthropic-beta", v)
|
||||
}
|
||||
|
||||
// 代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
log.Printf("%s upstream request failed: %v", prefix, err)
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
// 429 错误时标记账号限流
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude)
|
||||
}
|
||||
|
||||
// 透传上游错误
|
||||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||
c.Status(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 处理成功响应(流式/非流式)
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
|
||||
if claudeReq.Stream {
|
||||
// 流式响应:透传
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime)
|
||||
} else {
|
||||
// 非流式响应:直接透传
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read upstream response: %w", err)
|
||||
}
|
||||
|
||||
// 提取 usage
|
||||
usage = s.extractClaudeUsage(respBody)
|
||||
|
||||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||
c.Status(http.StatusOK)
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
}
|
||||
|
||||
// 构建计费结果
|
||||
duration := time.Since(startTime)
|
||||
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: duration,
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: usage.InputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheReadInputTokens: usage.CacheReadInputTokens,
|
||||
CacheCreationInputTokens: usage.CacheCreationInputTokens,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// streamUpstreamResponse 透传上游流式响应并提取 usage
|
||||
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*ClaudeUsage, *int) {
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
var firstTokenRecorded bool
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
|
||||
// 记录首 token 时间
|
||||
if !firstTokenRecorded && len(line) > 0 {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
firstTokenRecorded = true
|
||||
}
|
||||
|
||||
// 尝试从 message_delta 或 message_stop 事件提取 usage
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
dataStr := bytes.TrimPrefix(line, []byte("data: "))
|
||||
var event map[string]any
|
||||
if json.Unmarshal(dataStr, &event) == nil {
|
||||
if u, ok := event["usage"].(map[string]any); ok {
|
||||
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.InputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.OutputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.CacheReadInputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.CacheCreationInputTokens = int(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 透传行
|
||||
_, _ = c.Writer.Write(line)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
return usage, firstTokenMs
|
||||
}
|
||||
|
||||
// extractClaudeUsage 从非流式 Claude 响应提取 usage
|
||||
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
|
||||
usage := &ClaudeUsage{}
|
||||
var resp map[string]any
|
||||
if json.Unmarshal(body, &resp) != nil {
|
||||
return usage
|
||||
}
|
||||
if u, ok := resp["usage"].(map[string]any); ok {
|
||||
if v, ok := u["input_tokens"].(float64); ok {
|
||||
usage.InputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["output_tokens"].(float64); ok {
|
||||
usage.OutputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_read_input_tokens"].(float64); ok {
|
||||
usage.CacheReadInputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
|
||||
usage.CacheCreationInputTokens = int(v)
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
// ForwardGemini 转发 Gemini 协议请求
|
||||
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -1613,7 +1815,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
Usage: ClaudeUsage{},
|
||||
Model: originalModel,
|
||||
Stream: false,
|
||||
Duration: time.Since(time.Now()),
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: nil,
|
||||
}, nil
|
||||
default:
|
||||
@@ -2288,7 +2490,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
|
||||
@@ -2309,7 +2512,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
go func(scanBuf *sseScannerBuf64K) {
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
@@ -2320,7 +2524,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
}(scanBuf)
|
||||
defer close(done)
|
||||
|
||||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||||
@@ -2445,7 +2649,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
@@ -2473,7 +2678,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
|
||||
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
go func(scanBuf *sseScannerBuf64K) {
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
@@ -2484,7 +2690,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
}(scanBuf)
|
||||
defer close(done)
|
||||
|
||||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||||
@@ -2888,7 +3094,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
|
||||
var firstTokenMs *int
|
||||
var last map[string]any
|
||||
@@ -2914,7 +3121,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
go func(scanBuf *sseScannerBuf64K) {
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
@@ -2925,7 +3133,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
}(scanBuf)
|
||||
defer close(done)
|
||||
|
||||
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
|
||||
@@ -3068,7 +3276,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
|
||||
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
|
||||
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
|
||||
@@ -3100,7 +3309,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
go func(scanBuf *sseScannerBuf64K) {
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
@@ -3111,7 +3321,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
}(scanBuf)
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -391,3 +392,37 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
|
||||
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
|
||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||
}
|
||||
|
||||
func TestAntigravityStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"cache_read_input_tokens\":3,\"cache_creation_input_tokens\":4}}\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"usage\":{\"output_tokens\":5}}\n"))
|
||||
}()
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
start := time.Now().Add(-10 * time.Millisecond)
|
||||
usage, firstTokenMs := svc.streamUpstreamResponse(c, resp, start)
|
||||
_ = pr.Close()
|
||||
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 1, usage.InputTokens)
|
||||
// 第二次事件覆盖 output_tokens
|
||||
require.Equal(t, 5, usage.OutputTokens)
|
||||
require.Equal(t, 3, usage.CacheReadInputTokens)
|
||||
require.Equal(t, 4, usage.CacheCreationInputTokens)
|
||||
|
||||
if firstTokenMs == nil {
|
||||
t.Fatalf("expected firstTokenMs to be set")
|
||||
}
|
||||
// 确保有透传输出
|
||||
require.True(t, strings.Contains(writer.Body.String(), "data:"))
|
||||
}
|
||||
|
||||
@@ -6,8 +6,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -23,12 +22,6 @@ type apiKeyAuthCacheConfig struct {
|
||||
singleflight bool
|
||||
}
|
||||
|
||||
var (
|
||||
jitterRandMu sync.Mutex
|
||||
// 认证缓存抖动使用独立随机源,避免全局 Seed
|
||||
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
)
|
||||
|
||||
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
|
||||
if cfg == nil {
|
||||
return apiKeyAuthCacheConfig{}
|
||||
@@ -56,6 +49,8 @@ func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
|
||||
return c.negativeTTL > 0
|
||||
}
|
||||
|
||||
// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。
|
||||
// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。
|
||||
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
|
||||
if ttl <= 0 {
|
||||
return ttl
|
||||
@@ -68,9 +63,7 @@ func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
|
||||
percent = 100
|
||||
}
|
||||
delta := float64(percent) / 100
|
||||
jitterRandMu.Lock()
|
||||
randVal := jitterRand.Float64()
|
||||
jitterRandMu.Unlock()
|
||||
randVal := rand.Float64()
|
||||
factor := 1 - delta + randVal*(2*delta)
|
||||
if factor <= 0 {
|
||||
return ttl
|
||||
|
||||
@@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs)
|
||||
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch user usage stats: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
|
||||
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||
}
|
||||
|
||||
@@ -4145,7 +4145,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
@@ -4164,7 +4165,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
go func(scanBuf *sseScannerBuf64K) {
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
@@ -4175,7 +4177,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
}(scanBuf)
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
@@ -4481,24 +4483,16 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
}
|
||||
|
||||
// replaceModelInResponseBody 替换响应体中的model字段
|
||||
// 使用 gjson/sjson 精确替换,避免全量 JSON 反序列化
|
||||
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body
|
||||
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
|
||||
newBody, err := sjson.SetBytes(body, "model", toModel)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
model, ok := resp["model"].(string)
|
||||
if !ok || model != fromModel {
|
||||
return body
|
||||
}
|
||||
|
||||
resp["model"] = toModel
|
||||
newBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
return newBody
|
||||
return body
|
||||
}
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
|
||||
52
backend/internal/service/gateway_service_streaming_test.go
Normal file
52
backend/internal/service/gateway_service_streaming_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: cfg,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
// Minimal SSE event to trigger parseSSEUsage
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":3}}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n"))
|
||||
_, _ = pw.Write([]byte("data: [DONE]\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", nil, false)
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 3, result.usage.InputTokens)
|
||||
require.Equal(t, 7, result.usage.OutputTokens)
|
||||
}
|
||||
@@ -2,19 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
|
||||
codexCacheTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
//go:embed prompts/codex_cli_instructions.md
|
||||
@@ -77,12 +65,6 @@ type codexTransformResult struct {
|
||||
PromptCacheKey string
|
||||
}
|
||||
|
||||
type opencodeCacheMetadata struct {
|
||||
ETag string `json:"etag"`
|
||||
LastFetch string `json:"lastFetch,omitempty"`
|
||||
LastChecked int64 `json:"lastChecked"`
|
||||
}
|
||||
|
||||
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult {
|
||||
result := codexTransformResult{}
|
||||
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||
@@ -216,54 +198,9 @@ func getNormalizedCodexModel(modelID string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
|
||||
cacheDir := codexCachePath("")
|
||||
if cacheDir == "" {
|
||||
return ""
|
||||
}
|
||||
cacheFile := filepath.Join(cacheDir, cacheFileName)
|
||||
metaFile := filepath.Join(cacheDir, metaFileName)
|
||||
|
||||
var cachedContent string
|
||||
if content, ok := readFile(cacheFile); ok {
|
||||
cachedContent = content
|
||||
}
|
||||
|
||||
var meta opencodeCacheMetadata
|
||||
if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
|
||||
if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
|
||||
return cachedContent
|
||||
}
|
||||
}
|
||||
|
||||
content, etag, status, err := fetchWithETag(url, meta.ETag)
|
||||
if err == nil && status == http.StatusNotModified && cachedContent != "" {
|
||||
return cachedContent
|
||||
}
|
||||
if err == nil && status >= 200 && status < 300 && content != "" {
|
||||
_ = writeFile(cacheFile, content)
|
||||
meta = opencodeCacheMetadata{
|
||||
ETag: etag,
|
||||
LastFetch: time.Now().UTC().Format(time.RFC3339),
|
||||
LastChecked: time.Now().UnixMilli(),
|
||||
}
|
||||
_ = writeJSON(metaFile, meta)
|
||||
return content
|
||||
}
|
||||
|
||||
return cachedContent
|
||||
}
|
||||
|
||||
func getOpenCodeCodexHeader() string {
|
||||
// 优先从 opencode 仓库缓存获取指令。
|
||||
opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
|
||||
|
||||
// 若 opencode 指令可用,直接返回。
|
||||
if opencodeInstructions != "" {
|
||||
return opencodeInstructions
|
||||
}
|
||||
|
||||
// 否则回退使用本地 Codex CLI 指令。
|
||||
// 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt。
|
||||
// 现在我们与 Codex CLI 一致,直接使用仓库内置的 instructions,避免读写缓存与外网依赖。
|
||||
return getCodexCLIInstructions()
|
||||
}
|
||||
|
||||
@@ -281,8 +218,8 @@ func GetCodexCLIInstructions() string {
|
||||
}
|
||||
|
||||
// applyInstructions 处理 instructions 字段
|
||||
// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令)
|
||||
// isCodexCLI=false: 优先使用 opencode 指令覆盖
|
||||
// isCodexCLI=true: 仅补充缺失的 instructions(使用内置 Codex CLI 指令)
|
||||
// isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖
|
||||
func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
|
||||
if isCodexCLI {
|
||||
return applyCodexCLIInstructions(reqBody)
|
||||
@@ -291,13 +228,13 @@ func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
|
||||
}
|
||||
|
||||
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
|
||||
// 仅在 instructions 为空时添加 opencode 指令
|
||||
// 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源)
|
||||
func applyCodexCLIInstructions(reqBody map[string]any) bool {
|
||||
if !isInstructionsEmpty(reqBody) {
|
||||
return false // 已有有效 instructions,不修改
|
||||
}
|
||||
|
||||
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||
instructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||
if instructions != "" {
|
||||
reqBody["instructions"] = instructions
|
||||
return true
|
||||
@@ -306,8 +243,8 @@ func applyCodexCLIInstructions(reqBody map[string]any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令
|
||||
// 优先使用 opencode 指令覆盖
|
||||
// applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名)
|
||||
// 优先使用内置 Codex CLI 指令覆盖
|
||||
func applyOpenCodeInstructions(reqBody map[string]any) bool {
|
||||
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
@@ -489,85 +426,3 @@ func normalizeCodexTools(reqBody map[string]any) bool {
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
func codexCachePath(filename string) string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
cacheDir := filepath.Join(home, ".opencode", "cache")
|
||||
if filename == "" {
|
||||
return cacheDir
|
||||
}
|
||||
return filepath.Join(cacheDir, filename)
|
||||
}
|
||||
|
||||
func readFile(path string) (string, bool) {
|
||||
if path == "" {
|
||||
return "", false
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return string(data), true
|
||||
}
|
||||
|
||||
func writeFile(path, content string) error {
|
||||
if path == "" {
|
||||
return fmt.Errorf("empty cache path")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, []byte(content), 0o644)
|
||||
}
|
||||
|
||||
func loadJSON(path string, target any) bool {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if err := json.Unmarshal(data, target); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func writeJSON(path string, value any) error {
|
||||
if path == "" {
|
||||
return fmt.Errorf("empty json path")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o644)
|
||||
}
|
||||
|
||||
func fetchWithETag(url, etag string) (string, string, int, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
req.Header.Set("User-Agent", "sub2api-codex")
|
||||
if etag != "" {
|
||||
req.Header.Set("If-None-Match", etag)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", "", resp.StatusCode, err
|
||||
}
|
||||
return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
|
||||
}
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||
// 续链场景:保留 item_reference 与 id,但不再强制 store=true。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
@@ -48,7 +43,6 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||
// 续链场景:显式 store=false 不再强制为 true,保持 false。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
@@ -68,7 +62,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
|
||||
// 显式 store=true 也会强制为 false。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
@@ -88,7 +81,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
|
||||
|
||||
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
|
||||
// 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
@@ -130,8 +122,6 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) {
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"tools": []any{
|
||||
@@ -162,7 +152,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
|
||||
|
||||
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
// 空 input 应保持为空且不触发异常。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
@@ -187,88 +176,27 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
|
||||
for input, expected := range cases {
|
||||
require.Equal(t, expected, normalizeCodexModel(input))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
|
||||
// Codex CLI 场景:已有 instructions 时保持不变
|
||||
setupCodexCache(t)
|
||||
// Codex CLI 场景:已有 instructions 时不修改
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"instructions": "user custom instructions",
|
||||
"input": []any{},
|
||||
"instructions": "existing instructions",
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, true)
|
||||
result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true
|
||||
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "user custom instructions", instructions)
|
||||
// instructions 未变,但其他字段(如 store、stream)可能被修改
|
||||
require.True(t, result.Modified)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_CodexCLI_AddsInstructionsWhenEmpty(t *testing.T) {
|
||||
// Codex CLI 场景:无 instructions 时补充内置指令
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"input": []any{},
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, true)
|
||||
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, instructions)
|
||||
require.True(t, result.Modified)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NonCodexCLI_UsesOpenCodeInstructions(t *testing.T) {
|
||||
// 非 Codex CLI 场景:使用 opencode 指令(缓存中有 header)
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"input": []any{},
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, false)
|
||||
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "header", instructions) // setupCodexCache 设置的缓存内容
|
||||
require.True(t, result.Modified)
|
||||
}
|
||||
|
||||
func setupCodexCache(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// 使用临时 HOME 避免触发网络拉取 header。
|
||||
// Windows 使用 USERPROFILE,Unix 使用 HOME。
|
||||
tempDir := t.TempDir()
|
||||
t.Setenv("HOME", tempDir)
|
||||
t.Setenv("USERPROFILE", tempDir)
|
||||
|
||||
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
|
||||
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
|
||||
|
||||
meta := map[string]any{
|
||||
"etag": "",
|
||||
"lastFetch": time.Now().UTC().Format(time.RFC3339),
|
||||
"lastChecked": time.Now().UnixMilli(),
|
||||
}
|
||||
data, err := json.Marshal(meta)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
|
||||
require.Equal(t, "existing instructions", instructions)
|
||||
// Modified 仍可能为 true(因为其他字段被修改),但 instructions 应保持不变
|
||||
_ = result
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) {
|
||||
// Codex CLI 场景:无 instructions 时补充默认值
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
@@ -284,8 +212,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) {
|
||||
// 非 Codex CLI 场景:使用 opencode 指令覆盖
|
||||
setupCodexCache(t)
|
||||
// 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
|
||||
@@ -24,6 +24,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -765,7 +767,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
bodyModified := false
|
||||
originalModel := reqModel
|
||||
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
|
||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
@@ -969,6 +971,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
}
|
||||
|
||||
if usage == nil {
|
||||
usage = &OpenAIUsage{}
|
||||
}
|
||||
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
@@ -1053,6 +1059,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
req.Header.Set("user-agent", customUA)
|
||||
}
|
||||
|
||||
// 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。
|
||||
// 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。
|
||||
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
|
||||
req.Header.Set("user-agent", "codex_cli_rs/0.98.0")
|
||||
}
|
||||
|
||||
// Ensure required headers exist
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
@@ -1233,7 +1245,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
@@ -1252,7 +1265,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
}
|
||||
var lastReadAt int64
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
go func() {
|
||||
go func(scanBuf *sseScannerBuf64K) {
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
defer close(events)
|
||||
for scanner.Scan() {
|
||||
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
|
||||
@@ -1263,7 +1277,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
if err := scanner.Err(); err != nil {
|
||||
_ = sendEvent(scanEvent{err: err})
|
||||
}
|
||||
}()
|
||||
}(scanBuf)
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
@@ -1442,31 +1456,22 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
|
||||
return line
|
||||
}
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
return line
|
||||
}
|
||||
|
||||
// Replace model in response
|
||||
if m, ok := event["model"].(string); ok && m == fromModel {
|
||||
event["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
// 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化
|
||||
if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel {
|
||||
newData, err := sjson.Set(data, "model", toModel)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
return "data: " + newData
|
||||
}
|
||||
|
||||
// Check nested response
|
||||
if response, ok := event["response"].(map[string]any); ok {
|
||||
if m, ok := response["model"].(string); ok && m == fromModel {
|
||||
response["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
// 检查嵌套的 response.model 字段
|
||||
if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel {
|
||||
newData, err := sjson.Set(data, "response.model", toModel)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + newData
|
||||
}
|
||||
|
||||
return line
|
||||
@@ -1686,23 +1691,15 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body
|
||||
// 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化
|
||||
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
|
||||
newBody, err := sjson.SetBytes(body, "model", toModel)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
model, ok := resp["model"].(string)
|
||||
if !ok || model != fromModel {
|
||||
return body
|
||||
}
|
||||
|
||||
resp["model"] = toModel
|
||||
newBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
return newBody
|
||||
return body
|
||||
}
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type stubOpenAIAccountRepo struct {
|
||||
@@ -1082,6 +1083,43 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingReuseScannerBufferAndStillWorks(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n"))
|
||||
}()
|
||||
|
||||
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.usage)
|
||||
require.Equal(t, 1, result.usage.InputTokens)
|
||||
require.Equal(t, 2, result.usage.OutputTokens)
|
||||
require.Equal(t, 3, result.usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
@@ -1165,3 +1203,226 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
|
||||
t.Fatalf("expected non-allowlisted host to fail")
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== P1-08 修复:model 替换性能优化测试 ====================
|
||||
|
||||
func TestReplaceModelInSSELine(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
line string
|
||||
from string
|
||||
to string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "顶层 model 字段替换",
|
||||
line: `data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
|
||||
from: "gpt-4o",
|
||||
to: "my-custom-model",
|
||||
expected: `data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`,
|
||||
},
|
||||
{
|
||||
name: "嵌套 response.model 替换",
|
||||
line: `data: {"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`,
|
||||
},
|
||||
{
|
||||
name: "model 不匹配时不替换",
|
||||
line: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
||||
},
|
||||
{
|
||||
name: "无 model 字段时不替换",
|
||||
line: `data: {"id":"chatcmpl-123","choices":[]}`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: {"id":"chatcmpl-123","choices":[]}`,
|
||||
},
|
||||
{
|
||||
name: "空 data 行",
|
||||
line: `data: `,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: `,
|
||||
},
|
||||
{
|
||||
name: "[DONE] 行",
|
||||
line: `data: [DONE]`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: [DONE]`,
|
||||
},
|
||||
{
|
||||
name: "非 data: 前缀行",
|
||||
line: `event: message`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `event: message`,
|
||||
},
|
||||
{
|
||||
name: "非法 JSON 不替换",
|
||||
line: `data: {invalid json}`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: {invalid json}`,
|
||||
},
|
||||
{
|
||||
name: "无空格 data: 格式",
|
||||
line: `data:{"id":"x","model":"gpt-4o"}`,
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: `data: {"id":"x","model":"my-model"}`,
|
||||
},
|
||||
{
|
||||
name: "model 名含特殊字符",
|
||||
line: `data: {"model":"org/model-v2.1-beta"}`,
|
||||
from: "org/model-v2.1-beta",
|
||||
to: "custom/alias",
|
||||
expected: `data: {"model":"custom/alias"}`,
|
||||
},
|
||||
{
|
||||
name: "空行",
|
||||
line: "",
|
||||
from: "gpt-4o",
|
||||
to: "my-model",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "保持其他字段不变",
|
||||
line: `data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: `data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
|
||||
},
|
||||
{
|
||||
name: "顶层优先于嵌套:同时存在两个 model",
|
||||
line: `data: {"model":"gpt-4o","response":{"model":"gpt-4o"}}`,
|
||||
from: "gpt-4o",
|
||||
to: "replaced",
|
||||
expected: `data: {"model":"replaced","response":{"model":"gpt-4o"}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.replaceModelInSSELine(tt.line, tt.from, tt.to)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceModelInSSEBody(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
from string
|
||||
to string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "多行 SSE body 替换",
|
||||
body: "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
|
||||
},
|
||||
{
|
||||
name: "无需替换的 body",
|
||||
body: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
|
||||
},
|
||||
{
|
||||
name: "混合 event 和 data 行",
|
||||
body: "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n",
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: "event: message\ndata: {\"model\":\"alias\"}\n\n",
|
||||
},
|
||||
{
|
||||
name: "空 body",
|
||||
body: "",
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.replaceModelInSSEBody(tt.body, tt.from, tt.to)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceModelInResponseBody(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
from string
|
||||
to string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "替换顶层 model",
|
||||
body: `{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: `{"id":"chatcmpl-123","model":"alias","choices":[]}`,
|
||||
},
|
||||
{
|
||||
name: "model 不匹配不替换",
|
||||
body: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
|
||||
},
|
||||
{
|
||||
name: "无 model 字段不替换",
|
||||
body: `{"id":"chatcmpl-123","choices":[]}`,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: `{"id":"chatcmpl-123","choices":[]}`,
|
||||
},
|
||||
{
|
||||
name: "非法 JSON 返回原值",
|
||||
body: `not json`,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: `not json`,
|
||||
},
|
||||
{
|
||||
name: "空 body 返回原值",
|
||||
body: ``,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: ``,
|
||||
},
|
||||
{
|
||||
name: "保持嵌套结构不变",
|
||||
body: `{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
|
||||
from: "gpt-4o",
|
||||
to: "alias",
|
||||
expected: `{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.replaceModelInResponseBody([]byte(tt.body), tt.from, tt.to)
|
||||
require.Equal(t, tt.expected, string(got))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
24
backend/internal/service/sse_scanner_buffer_pool.go
Normal file
24
backend/internal/service/sse_scanner_buffer_pool.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package service
|
||||
|
||||
import "sync"
|
||||
|
||||
const sseScannerBuf64KSize = 64 * 1024
|
||||
|
||||
type sseScannerBuf64K [sseScannerBuf64KSize]byte
|
||||
|
||||
var sseScannerBuf64KPool = sync.Pool{
|
||||
New: func() any {
|
||||
return new(sseScannerBuf64K)
|
||||
},
|
||||
}
|
||||
|
||||
func getSSEScannerBuf64K() *sseScannerBuf64K {
|
||||
return sseScannerBuf64KPool.Get().(*sseScannerBuf64K)
|
||||
}
|
||||
|
||||
func putSSEScannerBuf64K(buf *sseScannerBuf64K) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
sseScannerBuf64KPool.Put(buf)
|
||||
}
|
||||
19
backend/internal/service/sse_scanner_buffer_pool_test.go
Normal file
19
backend/internal/service/sse_scanner_buffer_pool_test.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSSEScannerBuf64KPool_GetPutDoesNotPanic(t *testing.T) {
|
||||
buf := getSSEScannerBuf64K()
|
||||
require.NotNil(t, buf)
|
||||
require.Equal(t, sseScannerBuf64KSize, len(buf[:]))
|
||||
|
||||
buf[0] = 1
|
||||
putSSEScannerBuf64K(buf)
|
||||
|
||||
// 允许传入 nil,确保不会 panic
|
||||
putSSEScannerBuf64K(nil)
|
||||
}
|
||||
@@ -4,10 +4,15 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/dgraph-io/ristretto"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// MaxExpiresAt is the maximum allowed expiration date (year 2099)
|
||||
@@ -35,15 +40,76 @@ type SubscriptionService struct {
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
|
||||
// L1 缓存:加速中间件热路径的订阅查询
|
||||
subCacheL1 *ristretto.Cache
|
||||
subCacheGroup singleflight.Group
|
||||
subCacheTTL time.Duration
|
||||
subCacheJitter int // 抖动百分比
|
||||
}
|
||||
|
||||
// NewSubscriptionService 创建订阅服务
|
||||
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService {
|
||||
return &SubscriptionService{
|
||||
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, cfg *config.Config) *SubscriptionService {
|
||||
svc := &SubscriptionService{
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
}
|
||||
svc.initSubCache(cfg)
|
||||
return svc
|
||||
}
|
||||
|
||||
// initSubCache 初始化订阅 L1 缓存
|
||||
func (s *SubscriptionService) initSubCache(cfg *config.Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
sc := cfg.SubscriptionCache
|
||||
if sc.L1Size <= 0 || sc.L1TTLSeconds <= 0 {
|
||||
return
|
||||
}
|
||||
cache, err := ristretto.NewCache(&ristretto.Config{
|
||||
NumCounters: int64(sc.L1Size) * 10,
|
||||
MaxCost: int64(sc.L1Size),
|
||||
BufferItems: 64,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to init subscription L1 cache: %v", err)
|
||||
return
|
||||
}
|
||||
s.subCacheL1 = cache
|
||||
s.subCacheTTL = time.Duration(sc.L1TTLSeconds) * time.Second
|
||||
s.subCacheJitter = sc.JitterPercent
|
||||
}
|
||||
|
||||
// subCacheKey 生成订阅缓存 key(热路径,避免 fmt.Sprintf 开销)
|
||||
func subCacheKey(userID, groupID int64) string {
|
||||
return "sub:" + strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(groupID, 10)
|
||||
}
|
||||
|
||||
// jitteredTTL 为 TTL 添加抖动,避免集中过期
|
||||
func (s *SubscriptionService) jitteredTTL(ttl time.Duration) time.Duration {
|
||||
if ttl <= 0 || s.subCacheJitter <= 0 {
|
||||
return ttl
|
||||
}
|
||||
pct := s.subCacheJitter
|
||||
if pct > 100 {
|
||||
pct = 100
|
||||
}
|
||||
delta := float64(pct) / 100
|
||||
factor := 1 - delta + rand.Float64()*(2*delta)
|
||||
if factor <= 0 {
|
||||
return ttl
|
||||
}
|
||||
return time.Duration(float64(ttl) * factor)
|
||||
}
|
||||
|
||||
// InvalidateSubCache 失效指定用户+分组的订阅 L1 缓存
|
||||
func (s *SubscriptionService) InvalidateSubCache(userID, groupID int64) {
|
||||
if s.subCacheL1 == nil {
|
||||
return
|
||||
}
|
||||
s.subCacheL1.Del(subCacheKey(userID, groupID))
|
||||
}
|
||||
|
||||
// AssignSubscriptionInput 分配订阅输入
|
||||
@@ -81,6 +147,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(input.UserID, input.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := input.UserID, input.GroupID
|
||||
go func() {
|
||||
@@ -167,6 +234,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(input.UserID, input.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := input.UserID, input.GroupID
|
||||
go func() {
|
||||
@@ -188,6 +256,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(input.UserID, input.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := input.UserID, input.GroupID
|
||||
go func() {
|
||||
@@ -297,6 +366,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(sub.UserID, sub.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := sub.UserID, sub.GroupID
|
||||
go func() {
|
||||
@@ -363,6 +433,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
||||
}
|
||||
|
||||
// 失效订阅缓存
|
||||
s.InvalidateSubCache(sub.UserID, sub.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
userID, groupID := sub.UserID, sub.GroupID
|
||||
go func() {
|
||||
@@ -381,12 +452,39 @@ func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubsc
|
||||
}
|
||||
|
||||
// GetActiveSubscription 获取用户对特定分组的有效订阅
|
||||
// 使用 L1 缓存 + singleflight 加速中间件热路径。
|
||||
// 返回缓存对象的浅拷贝,调用方可安全修改字段而不会污染缓存或触发 data race。
|
||||
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) {
|
||||
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
key := subCacheKey(userID, groupID)
|
||||
|
||||
// L1 缓存命中:返回浅拷贝
|
||||
if s.subCacheL1 != nil {
|
||||
if v, ok := s.subCacheL1.Get(key); ok {
|
||||
if sub, ok := v.(*UserSubscription); ok {
|
||||
cp := *sub
|
||||
return &cp, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return sub, nil
|
||||
|
||||
// singleflight 防止并发击穿
|
||||
value, err, _ := s.subCacheGroup.Do(key, func() (any, error) {
|
||||
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
// 写入 L1 缓存
|
||||
if s.subCacheL1 != nil {
|
||||
_ = s.subCacheL1.SetWithTTL(key, sub, 1, s.jitteredTTL(s.subCacheTTL))
|
||||
}
|
||||
return sub, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// singleflight 返回的也是缓存指针,需要浅拷贝
|
||||
cp := *value.(*UserSubscription)
|
||||
return &cp, nil
|
||||
}
|
||||
|
||||
// ListUserSubscriptions 获取用户的所有订阅
|
||||
@@ -521,9 +619,12 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use
|
||||
needsInvalidateCache = true
|
||||
}
|
||||
|
||||
// 如果有窗口被重置,失效 Redis 缓存以保持一致性
|
||||
if needsInvalidateCache && s.billingCacheService != nil {
|
||||
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
|
||||
// 如果有窗口被重置,失效缓存以保持一致性
|
||||
if needsInvalidateCache {
|
||||
s.InvalidateSubCache(sub.UserID, sub.GroupID)
|
||||
if s.billingCacheService != nil {
|
||||
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -544,6 +645,78 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSub
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAndCheckLimits 合并验证+限额检查(中间件热路径专用)
|
||||
// 仅做内存检查,不触发 DB 写入。窗口重置的 DB 写入由 DoWindowMaintenance 异步完成。
|
||||
// 返回 needsMaintenance 表示是否需要异步执行窗口维护。
|
||||
func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, group *Group) (needsMaintenance bool, err error) {
|
||||
// 1. 验证订阅状态
|
||||
if sub.Status == SubscriptionStatusExpired {
|
||||
return false, ErrSubscriptionExpired
|
||||
}
|
||||
if sub.Status == SubscriptionStatusSuspended {
|
||||
return false, ErrSubscriptionSuspended
|
||||
}
|
||||
if sub.IsExpired() {
|
||||
return false, ErrSubscriptionExpired
|
||||
}
|
||||
|
||||
// 2. 内存中修正过期窗口的用量,确保 CheckUsageLimits 不会误拒绝用户
|
||||
// 实际的 DB 窗口重置由 DoWindowMaintenance 异步完成
|
||||
if sub.NeedsDailyReset() {
|
||||
sub.DailyUsageUSD = 0
|
||||
needsMaintenance = true
|
||||
}
|
||||
if sub.NeedsWeeklyReset() {
|
||||
sub.WeeklyUsageUSD = 0
|
||||
needsMaintenance = true
|
||||
}
|
||||
if sub.NeedsMonthlyReset() {
|
||||
sub.MonthlyUsageUSD = 0
|
||||
needsMaintenance = true
|
||||
}
|
||||
if !sub.IsWindowActivated() {
|
||||
needsMaintenance = true
|
||||
}
|
||||
|
||||
// 3. 检查用量限额
|
||||
if !sub.CheckDailyLimit(group, 0) {
|
||||
return needsMaintenance, ErrDailyLimitExceeded
|
||||
}
|
||||
if !sub.CheckWeeklyLimit(group, 0) {
|
||||
return needsMaintenance, ErrWeeklyLimitExceeded
|
||||
}
|
||||
if !sub.CheckMonthlyLimit(group, 0) {
|
||||
return needsMaintenance, ErrMonthlyLimitExceeded
|
||||
}
|
||||
|
||||
return needsMaintenance, nil
|
||||
}
|
||||
|
||||
// DoWindowMaintenance 异步执行窗口维护(激活+重置)
|
||||
// 使用独立 context,不受请求取消影响。
|
||||
// 注意:此方法仅在 ValidateAndCheckLimits 返回 needsMaintenance=true 时调用,
|
||||
// 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误,
|
||||
// 因此进入此方法的订阅一定未过期,无需处理过期状态同步。
|
||||
func (s *SubscriptionService) DoWindowMaintenance(sub *UserSubscription) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 激活窗口(首次使用时)
|
||||
if !sub.IsWindowActivated() {
|
||||
if err := s.CheckAndActivateWindow(ctx, sub); err != nil {
|
||||
log.Printf("Failed to activate subscription windows: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 重置过期窗口
|
||||
if err := s.CheckAndResetWindows(ctx, sub); err != nil {
|
||||
log.Printf("Failed to reset subscription windows: %v", err)
|
||||
}
|
||||
|
||||
// 失效 L1 缓存,确保后续请求拿到更新后的数据
|
||||
s.InvalidateSubCache(sub.UserID, sub.GroupID)
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量到订阅
|
||||
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
|
||||
return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)
|
||||
|
||||
@@ -316,8 +316,8 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
|
||||
}
|
||||
|
||||
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
|
||||
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
|
||||
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@@ -62,13 +64,15 @@ type ChangePasswordRequest struct {
|
||||
type UserService struct {
|
||||
userRepo UserRepository
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
billingCache BillingCache
|
||||
}
|
||||
|
||||
// NewUserService 创建用户服务实例
|
||||
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService {
|
||||
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
|
||||
return &UserService{
|
||||
userRepo: userRepo,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
billingCache: billingCache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,6 +187,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||
}
|
||||
if s.billingCache != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
|
||||
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
186
backend/internal/service/user_service_test.go
Normal file
186
backend/internal/service/user_service_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- mock: UserRepository ---
|
||||
|
||||
type mockUserRepo struct {
|
||||
updateBalanceErr error
|
||||
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
|
||||
}
|
||||
|
||||
func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
|
||||
func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
|
||||
func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
|
||||
func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
|
||||
func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
|
||||
func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
|
||||
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
if m.updateBalanceFn != nil {
|
||||
return m.updateBalanceFn(ctx, id, amount)
|
||||
}
|
||||
return m.updateBalanceErr
|
||||
}
|
||||
func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||
func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||
func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
|
||||
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
|
||||
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
|
||||
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
|
||||
|
||||
// --- mock: APIKeyAuthCacheInvalidator ---
|
||||
|
||||
type mockAuthCacheInvalidator struct {
|
||||
invalidatedUserIDs []int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {}
|
||||
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {}
|
||||
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
|
||||
}
|
||||
|
||||
// --- mock: BillingCache ---
|
||||
|
||||
type mockBillingCache struct {
|
||||
invalidateErr error
|
||||
invalidateCallCount atomic.Int64
|
||||
invalidatedUserIDs []int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil }
|
||||
func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error { return nil }
|
||||
func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error { return nil }
|
||||
func (m *mockBillingCache) InvalidateUserBalance(_ context.Context, userID int64) error {
|
||||
m.invalidateCallCount.Add(1)
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
|
||||
return m.invalidateErr
|
||||
}
|
||||
func (m *mockBillingCache) GetSubscriptionCache(context.Context, int64, int64) (*SubscriptionCacheData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockBillingCache) SetSubscriptionCache(context.Context, int64, int64, *SubscriptionCacheData) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockBillingCache) UpdateSubscriptionUsage(context.Context, int64, int64, float64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockBillingCache) InvalidateSubscriptionCache(context.Context, int64, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- 测试 ---
|
||||
|
||||
func TestUpdateBalance_Success(t *testing.T) {
|
||||
repo := &mockUserRepo{}
|
||||
cache := &mockBillingCache{}
|
||||
svc := NewUserService(repo, nil, cache)
|
||||
|
||||
err := svc.UpdateBalance(context.Background(), 42, 100.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 等待异步 goroutine 完成
|
||||
require.Eventually(t, func() bool {
|
||||
return cache.invalidateCallCount.Load() == 1
|
||||
}, 2*time.Second, 10*time.Millisecond, "应异步调用 InvalidateUserBalance")
|
||||
|
||||
cache.mu.Lock()
|
||||
defer cache.mu.Unlock()
|
||||
require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存")
|
||||
}
|
||||
|
||||
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
|
||||
repo := &mockUserRepo{}
|
||||
svc := NewUserService(repo, nil, nil) // billingCache = nil
|
||||
|
||||
err := svc.UpdateBalance(context.Background(), 1, 50.0)
|
||||
require.NoError(t, err, "billingCache 为 nil 时不应 panic")
|
||||
}
|
||||
|
||||
func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
|
||||
repo := &mockUserRepo{}
|
||||
cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")}
|
||||
svc := NewUserService(repo, nil, cache)
|
||||
|
||||
err := svc.UpdateBalance(context.Background(), 99, 200.0)
|
||||
require.NoError(t, err, "缓存失效失败不应影响主流程返回值")
|
||||
|
||||
// 等待异步 goroutine 完成(即使失败也应调用)
|
||||
require.Eventually(t, func() bool {
|
||||
return cache.invalidateCallCount.Load() == 1
|
||||
}, 2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance")
|
||||
}
|
||||
|
||||
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
|
||||
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
|
||||
cache := &mockBillingCache{}
|
||||
svc := NewUserService(repo, nil, cache)
|
||||
|
||||
err := svc.UpdateBalance(context.Background(), 1, 100.0)
|
||||
require.Error(t, err, "repo 失败时应返回错误")
|
||||
require.Contains(t, err.Error(), "update balance")
|
||||
|
||||
// repo 失败时不应触发缓存失效
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
require.Equal(t, int64(0), cache.invalidateCallCount.Load(),
|
||||
"repo 失败时不应调用 InvalidateUserBalance")
|
||||
}
|
||||
|
||||
func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) {
|
||||
repo := &mockUserRepo{}
|
||||
auth := &mockAuthCacheInvalidator{}
|
||||
cache := &mockBillingCache{}
|
||||
svc := NewUserService(repo, auth, cache)
|
||||
|
||||
err := svc.UpdateBalance(context.Background(), 77, 300.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证 auth cache 同步失效
|
||||
auth.mu.Lock()
|
||||
require.Equal(t, []int64{77}, auth.invalidatedUserIDs)
|
||||
auth.mu.Unlock()
|
||||
|
||||
// 验证 billing cache 异步失效
|
||||
require.Eventually(t, func() bool {
|
||||
return cache.invalidateCallCount.Load() == 1
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestNewUserService_FieldsAssignment(t *testing.T) {
|
||||
repo := &mockUserRepo{}
|
||||
auth := &mockAuthCacheInvalidator{}
|
||||
cache := &mockBillingCache{}
|
||||
|
||||
svc := NewUserService(repo, auth, cache)
|
||||
require.NotNil(t, svc)
|
||||
require.Equal(t, repo, svc.userRepo)
|
||||
require.Equal(t, auth, svc.authCacheInvalidator)
|
||||
require.Equal(t, cache, svc.billingCache)
|
||||
}
|
||||
Reference in New Issue
Block a user