Merge branch 'dev-release'

This commit is contained in:
yangjianbo
2026-02-07 19:58:00 +08:00
62 changed files with 2674 additions and 564 deletions

View File

@@ -36,8 +36,8 @@ type UsageLogRepository interface {
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)

View File

@@ -1582,6 +1582,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
return changed, nil
}
// ForwardUpstream 透传请求到上游 Antigravity 服务
// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
sessionID := getSessionID(c)
prefix := logPrefix(sessionID, account.Name)
// 获取上游配置
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if baseURL == "" || apiKey == "" {
return nil, fmt.Errorf("upstream account missing base_url or api_key")
}
baseURL = strings.TrimSuffix(baseURL, "/")
// 解析请求获取模型信息
var claudeReq antigravity.ClaudeRequest
if err := json.Unmarshal(body, &claudeReq); err != nil {
return nil, fmt.Errorf("parse claude request: %w", err)
}
if strings.TrimSpace(claudeReq.Model) == "" {
return nil, fmt.Errorf("missing model")
}
originalModel := claudeReq.Model
billingModel := originalModel
// 构建上游请求 URL
upstreamURL := baseURL + "/v1/messages"
// 创建请求
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("create upstream request: %w", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
// 透传 Claude 相关 headers
if v := c.GetHeader("anthropic-version"); v != "" {
req.Header.Set("anthropic-version", v)
}
if v := c.GetHeader("anthropic-beta"); v != "" {
req.Header.Set("anthropic-beta", v)
}
// 代理 URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
if err != nil {
log.Printf("%s upstream request failed: %v", prefix, err)
return nil, fmt.Errorf("upstream request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 429 错误时标记账号限流
if resp.StatusCode == http.StatusTooManyRequests {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude)
}
// 透传上游错误
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(resp.StatusCode)
_, _ = c.Writer.Write(respBody)
return &ForwardResult{
Model: billingModel,
}, nil
}
// 处理成功响应(流式/非流式)
var usage *ClaudeUsage
var firstTokenMs *int
if claudeReq.Stream {
// 流式响应:透传
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
c.Status(http.StatusOK)
usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime)
} else {
// 非流式响应:直接透传
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read upstream response: %w", err)
}
// 提取 usage
usage = s.extractClaudeUsage(respBody)
c.Header("Content-Type", resp.Header.Get("Content-Type"))
c.Status(http.StatusOK)
_, _ = c.Writer.Write(respBody)
}
// 构建计费结果
duration := time.Since(startTime)
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
return &ForwardResult{
Model: billingModel,
Stream: claudeReq.Stream,
Duration: duration,
FirstTokenMs: firstTokenMs,
Usage: ClaudeUsage{
InputTokens: usage.InputTokens,
OutputTokens: usage.OutputTokens,
CacheReadInputTokens: usage.CacheReadInputTokens,
CacheCreationInputTokens: usage.CacheCreationInputTokens,
},
}, nil
}
// streamUpstreamResponse 透传上游流式响应并提取 usage
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*ClaudeUsage, *int) {
usage := &ClaudeUsage{}
var firstTokenMs *int
var firstTokenRecorded bool
scanner := bufio.NewScanner(resp.Body)
buf := make([]byte, 0, 64*1024)
scanner.Buffer(buf, 1024*1024)
for scanner.Scan() {
line := scanner.Bytes()
// 记录首 token 时间
if !firstTokenRecorded && len(line) > 0 {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
firstTokenRecorded = true
}
// 尝试从 message_delta 或 message_stop 事件提取 usage
if bytes.HasPrefix(line, []byte("data: ")) {
dataStr := bytes.TrimPrefix(line, []byte("data: "))
var event map[string]any
if json.Unmarshal(dataStr, &event) == nil {
if u, ok := event["usage"].(map[string]any); ok {
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheCreationInputTokens = int(v)
}
}
}
}
// 透传行
_, _ = c.Writer.Write(line)
_, _ = c.Writer.Write([]byte("\n"))
c.Writer.Flush()
}
return usage, firstTokenMs
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
usage := &ClaudeUsage{}
var resp map[string]any
if json.Unmarshal(body, &resp) != nil {
return usage
}
if u, ok := resp["usage"].(map[string]any); ok {
if v, ok := u["input_tokens"].(float64); ok {
usage.InputTokens = int(v)
}
if v, ok := u["output_tokens"].(float64); ok {
usage.OutputTokens = int(v)
}
if v, ok := u["cache_read_input_tokens"].(float64); ok {
usage.CacheReadInputTokens = int(v)
}
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
usage.CacheCreationInputTokens = int(v)
}
}
return usage
}
// ForwardGemini 转发 Gemini 协议请求
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) {
startTime := time.Now()
@@ -1613,7 +1815,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(time.Now()),
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
default:
@@ -2288,7 +2490,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
@@ -2309,7 +2512,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
@@ -2320,7 +2524,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
@@ -2445,7 +2649,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
usage := &ClaudeUsage{}
var firstTokenMs *int
@@ -2473,7 +2678,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
@@ -2484,7 +2690,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
@@ -2888,7 +3094,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
var firstTokenMs *int
var last map[string]any
@@ -2914,7 +3121,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
@@ -2925,7 +3133,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
// 上游数据间隔超时保护(防止上游挂起长期占用连接)
@@ -3068,7 +3276,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.settingService.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
// 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage
convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage {
@@ -3100,7 +3309,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
@@ -3111,7 +3321,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
streamInterval := time.Duration(0)

View File

@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
@@ -391,3 +392,37 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
require.Equal(t, http.StatusServiceUnavailable, failoverErr.StatusCode)
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
}
func TestAntigravityStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
gin.SetMode(gin.TestMode)
writer := httptest.NewRecorder()
c, _ := gin.CreateTestContext(writer)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"cache_read_input_tokens\":3,\"cache_creation_input_tokens\":4}}\n"))
_, _ = pw.Write([]byte("data: {\"usage\":{\"output_tokens\":5}}\n"))
}()
svc := &AntigravityGatewayService{}
start := time.Now().Add(-10 * time.Millisecond)
usage, firstTokenMs := svc.streamUpstreamResponse(c, resp, start)
_ = pr.Close()
require.NotNil(t, usage)
require.Equal(t, 1, usage.InputTokens)
// 第二次事件覆盖 output_tokens
require.Equal(t, 5, usage.OutputTokens)
require.Equal(t, 3, usage.CacheReadInputTokens)
require.Equal(t, 4, usage.CacheCreationInputTokens)
if firstTokenMs == nil {
t.Fatalf("expected firstTokenMs to be set")
}
// 确保有透传输出
require.True(t, strings.Contains(writer.Body.String(), "data:"))
}

View File

@@ -6,8 +6,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"math/rand"
"sync"
"math/rand/v2"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -23,12 +22,6 @@ type apiKeyAuthCacheConfig struct {
singleflight bool
}
var (
jitterRandMu sync.Mutex
// 认证缓存抖动使用独立随机源,避免全局 Seed
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
)
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
if cfg == nil {
return apiKeyAuthCacheConfig{}
@@ -56,6 +49,8 @@ func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
return c.negativeTTL > 0
}
// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。
// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
if ttl <= 0 {
return ttl
@@ -68,9 +63,7 @@ func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
percent = 100
}
delta := float64(percent) / 100
jitterRandMu.Lock()
randVal := jitterRand.Float64()
jitterRandMu.Unlock()
randVal := rand.Float64()
factor := 1 - delta + randVal*(2*delta)
if factor <= 0 {
return ttl

View File

@@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
return trend, nil
}
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs)
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get batch user usage stats: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}

View File

@@ -4145,7 +4145,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
type scanEvent struct {
line string
@@ -4164,7 +4165,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
@@ -4175,7 +4177,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
streamInterval := time.Duration(0)
@@ -4481,24 +4483,16 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
// replaceModelInResponseBody 替换响应体中的model字段
// 使用 gjson/sjson 精确替换,避免全量 JSON 反序列化
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return body
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
newBody, err := sjson.SetBytes(body, "model", toModel)
if err != nil {
return body
}
return newBody
}
model, ok := resp["model"].(string)
if !ok || model != fromModel {
return body
}
resp["model"] = toModel
newBody, err := json.Marshal(resp)
if err != nil {
return body
}
return newBody
return body
}
// RecordUsageInput 记录使用量的输入参数

View File

@@ -0,0 +1,52 @@
package service
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &GatewayService{
cfg: cfg,
rateLimitService: &RateLimitService{},
}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
pr, pw := io.Pipe()
resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr}
go func() {
defer func() { _ = pw.Close() }()
// Minimal SSE event to trigger parseSSEUsage
_, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":3}}}\n\n"))
_, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n"))
_, _ = pw.Write([]byte("data: [DONE]\n\n"))
}()
result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", nil, false)
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
require.Equal(t, 3, result.usage.InputTokens)
require.Equal(t, 7, result.usage.OutputTokens)
}

View File

@@ -2,19 +2,7 @@ package service
import (
_ "embed"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
const (
opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
codexCacheTTL = 15 * time.Minute
)
//go:embed prompts/codex_cli_instructions.md
@@ -77,12 +65,6 @@ type codexTransformResult struct {
PromptCacheKey string
}
type opencodeCacheMetadata struct {
ETag string `json:"etag"`
LastFetch string `json:"lastFetch,omitempty"`
LastChecked int64 `json:"lastChecked"`
}
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult {
result := codexTransformResult{}
// 工具续链需求会影响存储策略与 input 过滤逻辑。
@@ -216,54 +198,9 @@ func getNormalizedCodexModel(modelID string) string {
return ""
}
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
cacheDir := codexCachePath("")
if cacheDir == "" {
return ""
}
cacheFile := filepath.Join(cacheDir, cacheFileName)
metaFile := filepath.Join(cacheDir, metaFileName)
var cachedContent string
if content, ok := readFile(cacheFile); ok {
cachedContent = content
}
var meta opencodeCacheMetadata
if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
return cachedContent
}
}
content, etag, status, err := fetchWithETag(url, meta.ETag)
if err == nil && status == http.StatusNotModified && cachedContent != "" {
return cachedContent
}
if err == nil && status >= 200 && status < 300 && content != "" {
_ = writeFile(cacheFile, content)
meta = opencodeCacheMetadata{
ETag: etag,
LastFetch: time.Now().UTC().Format(time.RFC3339),
LastChecked: time.Now().UnixMilli(),
}
_ = writeJSON(metaFile, meta)
return content
}
return cachedContent
}
func getOpenCodeCodexHeader() string {
// 优先从 opencode 仓库缓存获取指令
opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
// 若 opencode 指令可用,直接返回。
if opencodeInstructions != "" {
return opencodeInstructions
}
// 否则回退使用本地 Codex CLI 指令。
// 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt
// 现在我们与 Codex CLI 一致,直接使用仓库内置的 instructions避免读写缓存与外网依赖。
return getCodexCLIInstructions()
}
@@ -281,8 +218,8 @@ func GetCodexCLIInstructions() string {
}
// applyInstructions 处理 instructions 字段
// isCodexCLI=true: 仅补充缺失的 instructions使用 opencode 指令)
// isCodexCLI=false: 优先使用 opencode 指令覆盖
// isCodexCLI=true: 仅补充缺失的 instructions使用内置 Codex CLI 指令)
// isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖
func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
if isCodexCLI {
return applyCodexCLIInstructions(reqBody)
@@ -291,13 +228,13 @@ func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
}
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
// 仅在 instructions 为空时添加 opencode 指令
// 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源)
func applyCodexCLIInstructions(reqBody map[string]any) bool {
if !isInstructionsEmpty(reqBody) {
return false // 已有有效 instructions不修改
}
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
instructions := strings.TrimSpace(getCodexCLIInstructions())
if instructions != "" {
reqBody["instructions"] = instructions
return true
@@ -306,8 +243,8 @@ func applyCodexCLIInstructions(reqBody map[string]any) bool {
return false
}
// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令
// 优先使用 opencode 指令覆盖
// applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名)
// 优先使用内置 Codex CLI 指令覆盖
func applyOpenCodeInstructions(reqBody map[string]any) bool {
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
existingInstructions, _ := reqBody["instructions"].(string)
@@ -489,85 +426,3 @@ func normalizeCodexTools(reqBody map[string]any) bool {
return modified
}
func codexCachePath(filename string) string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
cacheDir := filepath.Join(home, ".opencode", "cache")
if filename == "" {
return cacheDir
}
return filepath.Join(cacheDir, filename)
}
func readFile(path string) (string, bool) {
if path == "" {
return "", false
}
data, err := os.ReadFile(path)
if err != nil {
return "", false
}
return string(data), true
}
func writeFile(path, content string) error {
if path == "" {
return fmt.Errorf("empty cache path")
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
return os.WriteFile(path, []byte(content), 0o644)
}
func loadJSON(path string, target any) bool {
data, err := os.ReadFile(path)
if err != nil {
return false
}
if err := json.Unmarshal(data, target); err != nil {
return false
}
return true
}
func writeJSON(path string, value any) error {
if path == "" {
return fmt.Errorf("empty json path")
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
data, err := json.Marshal(value)
if err != nil {
return err
}
return os.WriteFile(path, data, 0o644)
}
func fetchWithETag(url, etag string) (string, string, int, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return "", "", 0, err
}
req.Header.Set("User-Agent", "sub2api-codex")
if etag != "" {
req.Header.Set("If-None-Match", etag)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", "", 0, err
}
defer func() {
_ = resp.Body.Close()
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", "", resp.StatusCode, err
}
return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
}

View File

@@ -1,18 +1,13 @@
package service
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
// 续链场景:保留 item_reference 与 id但不再强制 store=true。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.2",
@@ -48,7 +43,6 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
// 续链场景:显式 store=false 不再强制为 true保持 false。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
@@ -68,7 +62,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
// 显式 store=true 也会强制为 false。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
@@ -88,7 +81,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
// 非续链场景:未设置 store 时默认 false并移除 input 中的 id。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
@@ -130,8 +122,6 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
}
func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) {
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"tools": []any{
@@ -162,7 +152,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
// 空 input 应保持为空且不触发异常。
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
@@ -187,88 +176,27 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
for input, expected := range cases {
require.Equal(t, expected, normalizeCodexModel(input))
}
}
func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
// Codex CLI 场景:已有 instructions 时保持不变
setupCodexCache(t)
// Codex CLI 场景:已有 instructions 时不修改
reqBody := map[string]any{
"model": "gpt-5.1",
"instructions": "user custom instructions",
"input": []any{},
"instructions": "existing instructions",
}
result := applyCodexOAuthTransform(reqBody, true)
result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
require.Equal(t, "user custom instructions", instructions)
// instructions 未变,但其他字段(如 store、stream可能被修改
require.True(t, result.Modified)
}
func TestApplyCodexOAuthTransform_CodexCLI_AddsInstructionsWhenEmpty(t *testing.T) {
// Codex CLI 场景:无 instructions 时补充内置指令
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"input": []any{},
}
result := applyCodexOAuthTransform(reqBody, true)
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
require.NotEmpty(t, instructions)
require.True(t, result.Modified)
}
func TestApplyCodexOAuthTransform_NonCodexCLI_UsesOpenCodeInstructions(t *testing.T) {
// 非 Codex CLI 场景:使用 opencode 指令(缓存中有 header
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
"input": []any{},
}
result := applyCodexOAuthTransform(reqBody, false)
instructions, ok := reqBody["instructions"].(string)
require.True(t, ok)
require.Equal(t, "header", instructions) // setupCodexCache 设置的缓存内容
require.True(t, result.Modified)
}
func setupCodexCache(t *testing.T) {
t.Helper()
// 使用临时 HOME 避免触发网络拉取 header。
// Windows 使用 USERPROFILEUnix 使用 HOME。
tempDir := t.TempDir()
t.Setenv("HOME", tempDir)
t.Setenv("USERPROFILE", tempDir)
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
meta := map[string]any{
"etag": "",
"lastFetch": time.Now().UTC().Format(time.RFC3339),
"lastChecked": time.Now().UnixMilli(),
}
data, err := json.Marshal(meta)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
require.Equal(t, "existing instructions", instructions)
// Modified 仍可能为 true因为其他字段被修改但 instructions 应保持不变
_ = result
}
func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) {
// Codex CLI 场景:无 instructions 时补充默认值
setupCodexCache(t)
reqBody := map[string]any{
"model": "gpt-5.1",
@@ -284,8 +212,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T
}
func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) {
// 非 Codex CLI 场景:使用 opencode 指令覆盖
setupCodexCache(t)
// 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖
reqBody := map[string]any{
"model": "gpt-5.1",

View File

@@ -24,6 +24,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
@@ -765,7 +767,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
bodyModified := false
originalModel := reqModel
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
// 对所有请求执行模型映射(包含 Codex CLI
mappedModel := account.GetMappedModel(reqModel)
@@ -969,6 +971,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
}
if usage == nil {
usage = &OpenAIUsage{}
}
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
return &OpenAIForwardResult{
@@ -1053,6 +1059,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
req.Header.Set("user-agent", customUA)
}
// 若开启 ForceCodexCLI则强制将上游 User-Agent 伪装为 Codex CLI。
// 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。
if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
req.Header.Set("user-agent", "codex_cli_rs/0.98.0")
}
// Ensure required headers exist
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
@@ -1233,7 +1245,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
scanBuf := getSSEScannerBuf64K()
scanner.Buffer(scanBuf[:0], maxLineSize)
type scanEvent struct {
line string
@@ -1252,7 +1265,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
go func(scanBuf *sseScannerBuf64K) {
defer putSSEScannerBuf64K(scanBuf)
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
@@ -1263,7 +1277,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
}(scanBuf)
defer close(done)
streamInterval := time.Duration(0)
@@ -1442,31 +1456,22 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
return line
}
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
return line
}
// Replace model in response
if m, ok := event["model"].(string); ok && m == fromModel {
event["model"] = toModel
newData, err := json.Marshal(event)
// 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化
if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel {
newData, err := sjson.Set(data, "model", toModel)
if err != nil {
return line
}
return "data: " + string(newData)
return "data: " + newData
}
// Check nested response
if response, ok := event["response"].(map[string]any); ok {
if m, ok := response["model"].(string); ok && m == fromModel {
response["model"] = toModel
newData, err := json.Marshal(event)
if err != nil {
return line
}
return "data: " + string(newData)
// 检查嵌套的 response.model 字段
if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel {
newData, err := sjson.Set(data, "response.model", toModel)
if err != nil {
return line
}
return "data: " + newData
}
return line
@@ -1686,23 +1691,15 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro
}
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return body
// 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化
if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
newBody, err := sjson.SetBytes(body, "model", toModel)
if err != nil {
return body
}
return newBody
}
model, ok := resp["model"].(string)
if !ok || model != fromModel {
return body
}
resp["model"] = toModel
newBody, err := json.Marshal(resp)
if err != nil {
return body
}
return newBody
return body
}
// OpenAIRecordUsageInput input for recording usage

View File

@@ -14,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type stubOpenAIAccountRepo struct {
@@ -1082,6 +1083,43 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) {
}
}
func TestOpenAIStreamingReuseScannerBufferAndStillWorks(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
Gateway: config.GatewayConfig{
StreamDataIntervalTimeout: 0,
StreamKeepaliveInterval: 0,
MaxLineSize: defaultMaxLineSize,
},
}
svc := &OpenAIGatewayService{cfg: cfg}
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
pr, pw := io.Pipe()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: pr,
Header: http.Header{},
}
go func() {
defer func() { _ = pw.Close() }()
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n"))
}()
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
_ = pr.Close()
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.usage)
require.Equal(t, 1, result.usage.InputTokens)
require.Equal(t, 2, result.usage.OutputTokens)
require.Equal(t, 3, result.usage.CacheReadInputTokens)
}
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -1165,3 +1203,226 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
t.Fatalf("expected non-allowlisted host to fail")
}
}
// ==================== P1-08 修复model 替换性能优化测试 ====================
func TestReplaceModelInSSELine(t *testing.T) {
svc := &OpenAIGatewayService{}
tests := []struct {
name string
line string
from string
to string
expected string
}{
{
name: "顶层 model 字段替换",
line: `data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
from: "gpt-4o",
to: "my-custom-model",
expected: `data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`,
},
{
name: "嵌套 response.model 替换",
line: `data: {"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`,
from: "gpt-4o",
to: "my-model",
expected: `data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`,
},
{
name: "model 不匹配时不替换",
line: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
from: "gpt-4o",
to: "my-model",
expected: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
},
{
name: "无 model 字段时不替换",
line: `data: {"id":"chatcmpl-123","choices":[]}`,
from: "gpt-4o",
to: "my-model",
expected: `data: {"id":"chatcmpl-123","choices":[]}`,
},
{
name: "空 data 行",
line: `data: `,
from: "gpt-4o",
to: "my-model",
expected: `data: `,
},
{
name: "[DONE] 行",
line: `data: [DONE]`,
from: "gpt-4o",
to: "my-model",
expected: `data: [DONE]`,
},
{
name: "非 data: 前缀行",
line: `event: message`,
from: "gpt-4o",
to: "my-model",
expected: `event: message`,
},
{
name: "非法 JSON 不替换",
line: `data: {invalid json}`,
from: "gpt-4o",
to: "my-model",
expected: `data: {invalid json}`,
},
{
name: "无空格 data: 格式",
line: `data:{"id":"x","model":"gpt-4o"}`,
from: "gpt-4o",
to: "my-model",
expected: `data: {"id":"x","model":"my-model"}`,
},
{
name: "model 名含特殊字符",
line: `data: {"model":"org/model-v2.1-beta"}`,
from: "org/model-v2.1-beta",
to: "custom/alias",
expected: `data: {"model":"custom/alias"}`,
},
{
name: "空行",
line: "",
from: "gpt-4o",
to: "my-model",
expected: "",
},
{
name: "保持其他字段不变",
line: `data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
from: "gpt-4o",
to: "alias",
expected: `data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`,
},
{
name: "顶层优先于嵌套:同时存在两个 model",
line: `data: {"model":"gpt-4o","response":{"model":"gpt-4o"}}`,
from: "gpt-4o",
to: "replaced",
expected: `data: {"model":"replaced","response":{"model":"gpt-4o"}}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.replaceModelInSSELine(tt.line, tt.from, tt.to)
require.Equal(t, tt.expected, got)
})
}
}
func TestReplaceModelInSSEBody(t *testing.T) {
svc := &OpenAIGatewayService{}
tests := []struct {
name string
body string
from string
to string
expected string
}{
{
name: "多行 SSE body 替换",
body: "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
from: "gpt-4o",
to: "alias",
expected: "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n",
},
{
name: "无需替换的 body",
body: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
from: "gpt-4o",
to: "alias",
expected: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n",
},
{
name: "混合 event 和 data 行",
body: "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n",
from: "gpt-4o",
to: "alias",
expected: "event: message\ndata: {\"model\":\"alias\"}\n\n",
},
{
name: "空 body",
body: "",
from: "gpt-4o",
to: "alias",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.replaceModelInSSEBody(tt.body, tt.from, tt.to)
require.Equal(t, tt.expected, got)
})
}
}
func TestReplaceModelInResponseBody(t *testing.T) {
svc := &OpenAIGatewayService{}
tests := []struct {
name string
body string
from string
to string
expected string
}{
{
name: "替换顶层 model",
body: `{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`,
from: "gpt-4o",
to: "alias",
expected: `{"id":"chatcmpl-123","model":"alias","choices":[]}`,
},
{
name: "model 不匹配不替换",
body: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
from: "gpt-4o",
to: "alias",
expected: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`,
},
{
name: "无 model 字段不替换",
body: `{"id":"chatcmpl-123","choices":[]}`,
from: "gpt-4o",
to: "alias",
expected: `{"id":"chatcmpl-123","choices":[]}`,
},
{
name: "非法 JSON 返回原值",
body: `not json`,
from: "gpt-4o",
to: "alias",
expected: `not json`,
},
{
name: "空 body 返回原值",
body: ``,
from: "gpt-4o",
to: "alias",
expected: ``,
},
{
name: "保持嵌套结构不变",
body: `{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
from: "gpt-4o",
to: "alias",
expected: `{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.replaceModelInResponseBody([]byte(tt.body), tt.from, tt.to)
require.Equal(t, tt.expected, string(got))
})
}
}

View File

@@ -0,0 +1,24 @@
package service
import "sync"
const sseScannerBuf64KSize = 64 * 1024
type sseScannerBuf64K [sseScannerBuf64KSize]byte
var sseScannerBuf64KPool = sync.Pool{
New: func() any {
return new(sseScannerBuf64K)
},
}
func getSSEScannerBuf64K() *sseScannerBuf64K {
return sseScannerBuf64KPool.Get().(*sseScannerBuf64K)
}
func putSSEScannerBuf64K(buf *sseScannerBuf64K) {
if buf == nil {
return
}
sseScannerBuf64KPool.Put(buf)
}

View File

@@ -0,0 +1,19 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestSSEScannerBuf64KPool_GetPutDoesNotPanic(t *testing.T) {
buf := getSSEScannerBuf64K()
require.NotNil(t, buf)
require.Equal(t, sseScannerBuf64KSize, len(buf[:]))
buf[0] = 1
putSSEScannerBuf64K(buf)
// 允许传入 nil确保不会 panic
putSSEScannerBuf64K(nil)
}

View File

@@ -4,10 +4,15 @@ import (
"context"
"fmt"
"log"
"math/rand/v2"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/dgraph-io/ristretto"
"golang.org/x/sync/singleflight"
)
// MaxExpiresAt is the maximum allowed expiration date (year 2099)
@@ -35,15 +40,76 @@ type SubscriptionService struct {
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService
// L1 缓存:加速中间件热路径的订阅查询
subCacheL1 *ristretto.Cache
subCacheGroup singleflight.Group
subCacheTTL time.Duration
subCacheJitter int // 抖动百分比
}
// NewSubscriptionService 创建订阅服务
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService {
return &SubscriptionService{
func NewSubscriptionService(groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, billingCacheService *BillingCacheService, cfg *config.Config) *SubscriptionService {
svc := &SubscriptionService{
groupRepo: groupRepo,
userSubRepo: userSubRepo,
billingCacheService: billingCacheService,
}
svc.initSubCache(cfg)
return svc
}
// initSubCache 初始化订阅 L1 缓存
func (s *SubscriptionService) initSubCache(cfg *config.Config) {
if cfg == nil {
return
}
sc := cfg.SubscriptionCache
if sc.L1Size <= 0 || sc.L1TTLSeconds <= 0 {
return
}
cache, err := ristretto.NewCache(&ristretto.Config{
NumCounters: int64(sc.L1Size) * 10,
MaxCost: int64(sc.L1Size),
BufferItems: 64,
})
if err != nil {
log.Printf("Warning: failed to init subscription L1 cache: %v", err)
return
}
s.subCacheL1 = cache
s.subCacheTTL = time.Duration(sc.L1TTLSeconds) * time.Second
s.subCacheJitter = sc.JitterPercent
}
// subCacheKey 生成订阅缓存 key热路径避免 fmt.Sprintf 开销)
func subCacheKey(userID, groupID int64) string {
return "sub:" + strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(groupID, 10)
}
// jitteredTTL 为 TTL 添加抖动,避免集中过期
func (s *SubscriptionService) jitteredTTL(ttl time.Duration) time.Duration {
if ttl <= 0 || s.subCacheJitter <= 0 {
return ttl
}
pct := s.subCacheJitter
if pct > 100 {
pct = 100
}
delta := float64(pct) / 100
factor := 1 - delta + rand.Float64()*(2*delta)
if factor <= 0 {
return ttl
}
return time.Duration(float64(ttl) * factor)
}
// InvalidateSubCache 失效指定用户+分组的订阅 L1 缓存
func (s *SubscriptionService) InvalidateSubCache(userID, groupID int64) {
if s.subCacheL1 == nil {
return
}
s.subCacheL1.Del(subCacheKey(userID, groupID))
}
// AssignSubscriptionInput 分配订阅输入
@@ -81,6 +147,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
@@ -167,6 +234,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
@@ -188,6 +256,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 失效订阅缓存
s.InvalidateSubCache(input.UserID, input.GroupID)
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
@@ -297,6 +366,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
}
// 失效订阅缓存
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
@@ -363,6 +433,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
}
// 失效订阅缓存
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
@@ -381,12 +452,39 @@ func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubsc
}
// GetActiveSubscription 获取用户对特定分组的有效订阅
// 使用 L1 缓存 + singleflight 加速中间件热路径。
// 返回缓存对象的浅拷贝,调用方可安全修改字段而不会污染缓存或触发 data race。
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) {
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
key := subCacheKey(userID, groupID)
// L1 缓存命中:返回浅拷贝
if s.subCacheL1 != nil {
if v, ok := s.subCacheL1.Get(key); ok {
if sub, ok := v.(*UserSubscription); ok {
cp := *sub
return &cp, nil
}
}
}
return sub, nil
// singleflight 防止并发击穿
value, err, _ := s.subCacheGroup.Do(key, func() (any, error) {
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
// 写入 L1 缓存
if s.subCacheL1 != nil {
_ = s.subCacheL1.SetWithTTL(key, sub, 1, s.jitteredTTL(s.subCacheTTL))
}
return sub, nil
})
if err != nil {
return nil, err
}
// singleflight 返回的也是缓存指针,需要浅拷贝
cp := *value.(*UserSubscription)
return &cp, nil
}
// ListUserSubscriptions 获取用户的所有订阅
@@ -521,9 +619,12 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use
needsInvalidateCache = true
}
// 如果有窗口被重置,失效 Redis 缓存以保持一致性
if needsInvalidateCache && s.billingCacheService != nil {
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
// 如果有窗口被重置,失效缓存以保持一致性
if needsInvalidateCache {
s.InvalidateSubCache(sub.UserID, sub.GroupID)
if s.billingCacheService != nil {
_ = s.billingCacheService.InvalidateSubscription(ctx, sub.UserID, sub.GroupID)
}
}
return nil
@@ -544,6 +645,78 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSub
return nil
}
// ValidateAndCheckLimits 合并验证+限额检查(中间件热路径专用)
// 仅做内存检查,不触发 DB 写入。窗口重置的 DB 写入由 DoWindowMaintenance 异步完成。
// 返回 needsMaintenance 表示是否需要异步执行窗口维护。
func (s *SubscriptionService) ValidateAndCheckLimits(sub *UserSubscription, group *Group) (needsMaintenance bool, err error) {
// 1. 验证订阅状态
if sub.Status == SubscriptionStatusExpired {
return false, ErrSubscriptionExpired
}
if sub.Status == SubscriptionStatusSuspended {
return false, ErrSubscriptionSuspended
}
if sub.IsExpired() {
return false, ErrSubscriptionExpired
}
// 2. 内存中修正过期窗口的用量,确保 CheckUsageLimits 不会误拒绝用户
// 实际的 DB 窗口重置由 DoWindowMaintenance 异步完成
if sub.NeedsDailyReset() {
sub.DailyUsageUSD = 0
needsMaintenance = true
}
if sub.NeedsWeeklyReset() {
sub.WeeklyUsageUSD = 0
needsMaintenance = true
}
if sub.NeedsMonthlyReset() {
sub.MonthlyUsageUSD = 0
needsMaintenance = true
}
if !sub.IsWindowActivated() {
needsMaintenance = true
}
// 3. 检查用量限额
if !sub.CheckDailyLimit(group, 0) {
return needsMaintenance, ErrDailyLimitExceeded
}
if !sub.CheckWeeklyLimit(group, 0) {
return needsMaintenance, ErrWeeklyLimitExceeded
}
if !sub.CheckMonthlyLimit(group, 0) {
return needsMaintenance, ErrMonthlyLimitExceeded
}
return needsMaintenance, nil
}
// DoWindowMaintenance 异步执行窗口维护(激活+重置)
// 使用独立 context不受请求取消影响。
// 注意:此方法仅在 ValidateAndCheckLimits 返回 needsMaintenance=true 时调用,
// 而 IsExpired()=true 的订阅在 ValidateAndCheckLimits 中已被拦截返回错误,
// 因此进入此方法的订阅一定未过期,无需处理过期状态同步。
func (s *SubscriptionService) DoWindowMaintenance(sub *UserSubscription) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 激活窗口(首次使用时)
if !sub.IsWindowActivated() {
if err := s.CheckAndActivateWindow(ctx, sub); err != nil {
log.Printf("Failed to activate subscription windows: %v", err)
}
}
// 重置过期窗口
if err := s.CheckAndResetWindows(ctx, sub); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
}
// 失效 L1 缓存,确保后续请求拿到更新后的数据
s.InvalidateSubCache(sub.UserID, sub.GroupID)
}
// RecordUsage 记录使用量到订阅
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)

View File

@@ -316,8 +316,8 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
}
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}

View File

@@ -3,6 +3,8 @@ package service
import (
"context"
"fmt"
"log"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -62,13 +64,15 @@ type ChangePasswordRequest struct {
type UserService struct {
userRepo UserRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
billingCache BillingCache
}
// NewUserService 创建用户服务实例
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService {
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
return &UserService{
userRepo: userRepo,
authCacheInvalidator: authCacheInvalidator,
billingCache: billingCache,
}
}
@@ -183,6 +187,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCache != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
}
}()
}
return nil
}

View File

@@ -0,0 +1,186 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// --- mock: UserRepository ---
type mockUserRepo struct {
updateBalanceErr error
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
}
func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
if m.updateBalanceFn != nil {
return m.updateBalanceFn(ctx, id, amount)
}
return m.updateBalanceErr
}
func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error { return nil }
func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
// --- mock: APIKeyAuthCacheInvalidator ---
type mockAuthCacheInvalidator struct {
invalidatedUserIDs []int64
mu sync.Mutex
}
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {}
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {}
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) {
m.mu.Lock()
defer m.mu.Unlock()
m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
}
// --- mock: BillingCache ---
type mockBillingCache struct {
invalidateErr error
invalidateCallCount atomic.Int64
invalidatedUserIDs []int64
mu sync.Mutex
}
func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil }
func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error { return nil }
func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error { return nil }
func (m *mockBillingCache) InvalidateUserBalance(_ context.Context, userID int64) error {
m.invalidateCallCount.Add(1)
m.mu.Lock()
defer m.mu.Unlock()
m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
return m.invalidateErr
}
func (m *mockBillingCache) GetSubscriptionCache(context.Context, int64, int64) (*SubscriptionCacheData, error) {
return nil, nil
}
func (m *mockBillingCache) SetSubscriptionCache(context.Context, int64, int64, *SubscriptionCacheData) error {
return nil
}
func (m *mockBillingCache) UpdateSubscriptionUsage(context.Context, int64, int64, float64) error {
return nil
}
func (m *mockBillingCache) InvalidateSubscriptionCache(context.Context, int64, int64) error {
return nil
}
// --- 测试 ---
func TestUpdateBalance_Success(t *testing.T) {
repo := &mockUserRepo{}
cache := &mockBillingCache{}
svc := NewUserService(repo, nil, cache)
err := svc.UpdateBalance(context.Background(), 42, 100.0)
require.NoError(t, err)
// 等待异步 goroutine 完成
require.Eventually(t, func() bool {
return cache.invalidateCallCount.Load() == 1
}, 2*time.Second, 10*time.Millisecond, "应异步调用 InvalidateUserBalance")
cache.mu.Lock()
defer cache.mu.Unlock()
require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存")
}
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
repo := &mockUserRepo{}
svc := NewUserService(repo, nil, nil) // billingCache = nil
err := svc.UpdateBalance(context.Background(), 1, 50.0)
require.NoError(t, err, "billingCache 为 nil 时不应 panic")
}
func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
repo := &mockUserRepo{}
cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")}
svc := NewUserService(repo, nil, cache)
err := svc.UpdateBalance(context.Background(), 99, 200.0)
require.NoError(t, err, "缓存失效失败不应影响主流程返回值")
// 等待异步 goroutine 完成(即使失败也应调用)
require.Eventually(t, func() bool {
return cache.invalidateCallCount.Load() == 1
}, 2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance")
}
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
cache := &mockBillingCache{}
svc := NewUserService(repo, nil, cache)
err := svc.UpdateBalance(context.Background(), 1, 100.0)
require.Error(t, err, "repo 失败时应返回错误")
require.Contains(t, err.Error(), "update balance")
// repo 失败时不应触发缓存失效
time.Sleep(100 * time.Millisecond)
require.Equal(t, int64(0), cache.invalidateCallCount.Load(),
"repo 失败时不应调用 InvalidateUserBalance")
}
func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) {
repo := &mockUserRepo{}
auth := &mockAuthCacheInvalidator{}
cache := &mockBillingCache{}
svc := NewUserService(repo, auth, cache)
err := svc.UpdateBalance(context.Background(), 77, 300.0)
require.NoError(t, err)
// 验证 auth cache 同步失效
auth.mu.Lock()
require.Equal(t, []int64{77}, auth.invalidatedUserIDs)
auth.mu.Unlock()
// 验证 billing cache 异步失效
require.Eventually(t, func() bool {
return cache.invalidateCallCount.Load() == 1
}, 2*time.Second, 10*time.Millisecond)
}
func TestNewUserService_FieldsAssignment(t *testing.T) {
repo := &mockUserRepo{}
auth := &mockAuthCacheInvalidator{}
cache := &mockBillingCache{}
svc := NewUserService(repo, auth, cache)
require.NotNil(t, svc)
require.Equal(t, repo, svc.userRepo)
require.Equal(t, auth, svc.authCacheInvalidator)
require.Equal(t, cache, svc.billingCache)
}