perf(gateway): 优化热点路径并补齐高覆盖测试
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -121,7 +121,6 @@ AGENTS.md
|
|||||||
scripts
|
scripts
|
||||||
.code-review-state
|
.code-review-state
|
||||||
openspec/
|
openspec/
|
||||||
docs/
|
|
||||||
code-reviews/
|
code-reviews/
|
||||||
AGENTS.md
|
AGENTS.md
|
||||||
backend/cmd/server/server
|
backend/cmd/server/server
|
||||||
|
|||||||
@@ -423,6 +423,11 @@ type GatewayConfig struct {
|
|||||||
|
|
||||||
// UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker)
|
// UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker)
|
||||||
UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"`
|
UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"`
|
||||||
|
|
||||||
|
// UserGroupRateCacheTTLSeconds: 用户分组倍率热路径缓存 TTL(秒)
|
||||||
|
UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"`
|
||||||
|
// ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL(秒)
|
||||||
|
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GatewayUsageRecordConfig 使用量记录异步队列配置
|
// GatewayUsageRecordConfig 使用量记录异步队列配置
|
||||||
@@ -1175,6 +1180,8 @@ func setDefaults() {
|
|||||||
viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16)
|
viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16)
|
||||||
viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3)
|
viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3)
|
||||||
viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10)
|
viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10)
|
||||||
|
viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30)
|
||||||
|
viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15)
|
||||||
// TLS指纹伪装配置(默认关闭,需要账号级别单独启用)
|
// TLS指纹伪装配置(默认关闭,需要账号级别单独启用)
|
||||||
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
|
||||||
viper.SetDefault("concurrency.ping_interval", 10)
|
viper.SetDefault("concurrency.ping_interval", 10)
|
||||||
@@ -1751,6 +1758,12 @@ func (c *Config) Validate() error {
|
|||||||
return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative")
|
return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if c.Gateway.UserGroupRateCacheTTLSeconds <= 0 {
|
||||||
|
return fmt.Errorf("gateway.user_group_rate_cache_ttl_seconds must be positive")
|
||||||
|
}
|
||||||
|
if c.Gateway.ModelsListCacheTTLSeconds < 10 || c.Gateway.ModelsListCacheTTLSeconds > 30 {
|
||||||
|
return fmt.Errorf("gateway.models_list_cache_ttl_seconds must be between 10-30")
|
||||||
|
}
|
||||||
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
|
if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 {
|
||||||
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
|
return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1010,6 +1010,16 @@ func TestValidateConfigErrors(t *testing.T) {
|
|||||||
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 },
|
mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 },
|
||||||
wantErr: "gateway.usage_record.auto_scale_check_interval_seconds",
|
wantErr: "gateway.usage_record.auto_scale_check_interval_seconds",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "gateway user group rate cache ttl",
|
||||||
|
mutate: func(c *Config) { c.Gateway.UserGroupRateCacheTTLSeconds = 0 },
|
||||||
|
wantErr: "gateway.user_group_rate_cache_ttl_seconds",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gateway models list cache ttl range",
|
||||||
|
mutate: func(c *Config) { c.Gateway.ModelsListCacheTTLSeconds = 31 },
|
||||||
|
wantErr: "gateway.models_list_cache_ttl_seconds",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "gateway scheduling sticky waiting",
|
name: "gateway scheduling sticky waiting",
|
||||||
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
|
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
|
||||||
|
|||||||
@@ -64,4 +64,3 @@ func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testi
|
|||||||
require.NotNil(t, created.Extra)
|
require.NotNil(t, created.Extra)
|
||||||
require.Equal(t, true, created.Extra["anthropic_passthrough"])
|
require.Equal(t, true, created.Extra["anthropic_passthrough"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -243,6 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
var sessionBoundAccountID int64
|
var sessionBoundAccountID int64
|
||||||
if sessionKey != "" {
|
if sessionKey != "" {
|
||||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||||
|
if sessionBoundAccountID > 0 {
|
||||||
|
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
|
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
|
||||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -20,14 +21,28 @@ var claudeCodeValidator = service.NewClaudeCodeValidator()
|
|||||||
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
|
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
|
||||||
// 返回更新后的 context
|
// 返回更新后的 context
|
||||||
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
||||||
// 解析请求体为 map
|
if c == nil || c.Request == nil {
|
||||||
var bodyMap map[string]any
|
return
|
||||||
if len(body) > 0 {
|
}
|
||||||
_ = json.Unmarshal(body, &bodyMap)
|
// Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。
|
||||||
|
if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) {
|
||||||
|
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证是否为 Claude Code 客户端
|
isClaudeCode := false
|
||||||
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap)
|
if !strings.Contains(c.Request.URL.Path, "messages") {
|
||||||
|
// 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。
|
||||||
|
isClaudeCode = true
|
||||||
|
} else {
|
||||||
|
// 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。
|
||||||
|
var bodyMap map[string]any
|
||||||
|
if len(body) > 0 {
|
||||||
|
_ = json.Unmarshal(body, &bodyMap)
|
||||||
|
}
|
||||||
|
isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap)
|
||||||
|
}
|
||||||
|
|
||||||
// 更新 request context
|
// 更新 request context
|
||||||
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
|
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
|
||||||
@@ -223,21 +238,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
|||||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// Try immediate acquire first (avoid unnecessary wait)
|
|
||||||
var result *service.AcquireResult
|
|
||||||
var err error
|
|
||||||
if slotType == "user" {
|
|
||||||
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
|
||||||
} else {
|
|
||||||
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if result.Acquired {
|
|
||||||
return result.ReleaseFunc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine if ping is needed (streaming + ping format defined)
|
// Determine if ping is needed (streaming + ping format defined)
|
||||||
needPing := isStream && h.pingFormat != ""
|
needPing := isStream && h.pingFormat != ""
|
||||||
|
|
||||||
|
|||||||
252
backend/internal/handler/gateway_helper_hotpath_test.go
Normal file
252
backend/internal/handler/gateway_helper_hotpath_test.go
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type helperConcurrencyCacheStub struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
accountSeq []bool
|
||||||
|
userSeq []bool
|
||||||
|
|
||||||
|
accountAcquireCalls int
|
||||||
|
userAcquireCalls int
|
||||||
|
accountReleaseCalls int
|
||||||
|
userReleaseCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.accountAcquireCalls++
|
||||||
|
if len(s.accountSeq) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
v := s.accountSeq[0]
|
||||||
|
s.accountSeq = s.accountSeq[1:]
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.accountReleaseCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.userAcquireCalls++
|
||||||
|
if len(s.userSeq) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
v := s.userSeq[0]
|
||||||
|
s.userSeq = s.userSeq[1:]
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.userReleaseCalls++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||||
|
out := make(map[int64]*service.AccountLoadInfo, len(accounts))
|
||||||
|
for _, acc := range accounts {
|
||||||
|
out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||||
|
out := make(map[int64]*service.UserLoadInfo, len(users))
|
||||||
|
for _, user := range users {
|
||||||
|
out[user.ID] = &service.UserLoadInfo{UserID: user.ID}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(method, path, nil)
|
||||||
|
return c, rec
|
||||||
|
}
|
||||||
|
|
||||||
|
func validClaudeCodeBodyJSON() []byte {
|
||||||
|
return []byte(`{
|
||||||
|
"model":"claude-3-5-sonnet-20241022",
|
||||||
|
"system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}],
|
||||||
|
"metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}
|
||||||
|
}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) {
|
||||||
|
t.Run("non_cli_user_agent_sets_false", func(t *testing.T) {
|
||||||
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||||
|
c.Request.Header.Set("User-Agent", "curl/8.6.0")
|
||||||
|
|
||||||
|
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
||||||
|
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cli_non_messages_path_sets_true", func(t *testing.T) {
|
||||||
|
c, _ := newHelperTestContext(http.MethodGet, "/v1/models")
|
||||||
|
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||||
|
|
||||||
|
SetClaudeCodeClientContext(c, nil)
|
||||||
|
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) {
|
||||||
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||||
|
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||||
|
c.Request.Header.Set("X-App", "claude-code")
|
||||||
|
c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24")
|
||||||
|
c.Request.Header.Set("anthropic-version", "2023-06-01")
|
||||||
|
|
||||||
|
SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON())
|
||||||
|
require.True(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) {
|
||||||
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||||
|
c.Request.Header.Set("User-Agent", "claude-cli/1.0.1")
|
||||||
|
// 缺少严格校验所需 header + body 字段
|
||||||
|
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`))
|
||||||
|
require.False(t, service.IsClaudeCodeClient(c.Request.Context()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) {
|
||||||
|
cache := &helperConcurrencyCacheStub{
|
||||||
|
accountSeq: []bool{false, true},
|
||||||
|
userSeq: []bool{false, true},
|
||||||
|
}
|
||||||
|
concurrency := service.NewConcurrencyService(cache)
|
||||||
|
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||||
|
|
||||||
|
t.Run("account_slot_acquired_after_retry", func(t *testing.T) {
|
||||||
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||||
|
streamStarted := false
|
||||||
|
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, release)
|
||||||
|
require.False(t, streamStarted)
|
||||||
|
release()
|
||||||
|
require.GreaterOrEqual(t, cache.accountAcquireCalls, 2)
|
||||||
|
require.GreaterOrEqual(t, cache.accountReleaseCalls, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("user_slot_acquired_after_retry", func(t *testing.T) {
|
||||||
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||||
|
streamStarted := false
|
||||||
|
release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, release)
|
||||||
|
release()
|
||||||
|
require.GreaterOrEqual(t, cache.userAcquireCalls, 2)
|
||||||
|
require.GreaterOrEqual(t, cache.userReleaseCalls, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) {
|
||||||
|
cache := &helperConcurrencyCacheStub{
|
||||||
|
accountSeq: []bool{false, false, false},
|
||||||
|
}
|
||||||
|
concurrency := service.NewConcurrencyService(cache)
|
||||||
|
|
||||||
|
t.Run("timeout_returns_concurrency_error", func(t *testing.T) {
|
||||||
|
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||||
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||||
|
streamStarted := false
|
||||||
|
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted)
|
||||||
|
require.Nil(t, release)
|
||||||
|
var cErr *ConcurrencyError
|
||||||
|
require.ErrorAs(t, err, &cErr)
|
||||||
|
require.True(t, cErr.IsTimeout)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) {
|
||||||
|
helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond)
|
||||||
|
c, rec := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||||
|
streamStarted := false
|
||||||
|
release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted)
|
||||||
|
require.Nil(t, release)
|
||||||
|
var cErr *ConcurrencyError
|
||||||
|
require.ErrorAs(t, err, &cErr)
|
||||||
|
require.True(t, cErr.IsTimeout)
|
||||||
|
require.True(t, streamStarted)
|
||||||
|
require.Contains(t, rec.Body.String(), ":\n\n")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) {
|
||||||
|
errCache := &helperConcurrencyCacheStubWithError{
|
||||||
|
err: errors.New("redis unavailable"),
|
||||||
|
}
|
||||||
|
concurrency := service.NewConcurrencyService(errCache)
|
||||||
|
helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond)
|
||||||
|
c, _ := newHelperTestContext(http.MethodPost, "/v1/messages")
|
||||||
|
streamStarted := false
|
||||||
|
release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted)
|
||||||
|
require.Nil(t, release)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "redis unavailable")
|
||||||
|
}
|
||||||
|
|
||||||
|
type helperConcurrencyCacheStubWithError struct {
|
||||||
|
helperConcurrencyCacheStub
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return false, s.err
|
||||||
|
}
|
||||||
@@ -263,6 +263,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
var sessionBoundAccountID int64
|
var sessionBoundAccountID int64
|
||||||
if sessionKey != "" {
|
if sessionKey != "" {
|
||||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||||
|
if sessionBoundAccountID > 0 {
|
||||||
|
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// === Gemini 内容摘要会话 Fallback 逻辑 ===
|
// === Gemini 内容摘要会话 Fallback 逻辑 ===
|
||||||
|
|||||||
@@ -41,9 +41,8 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type opsErrorLogJob struct {
|
type opsErrorLogJob struct {
|
||||||
ops *service.OpsService
|
ops *service.OpsService
|
||||||
entry *service.OpsInsertErrorLogInput
|
entry *service.OpsInsertErrorLogInput
|
||||||
requestBody []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -58,6 +57,7 @@ var (
|
|||||||
opsErrorLogEnqueued atomic.Int64
|
opsErrorLogEnqueued atomic.Int64
|
||||||
opsErrorLogDropped atomic.Int64
|
opsErrorLogDropped atomic.Int64
|
||||||
opsErrorLogProcessed atomic.Int64
|
opsErrorLogProcessed atomic.Int64
|
||||||
|
opsErrorLogSanitized atomic.Int64
|
||||||
|
|
||||||
opsErrorLogLastDropLogAt atomic.Int64
|
opsErrorLogLastDropLogAt atomic.Int64
|
||||||
|
|
||||||
@@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout)
|
||||||
_ = job.ops.RecordError(ctx, job.entry, job.requestBody)
|
_ = job.ops.RecordError(ctx, job.entry, nil)
|
||||||
cancel()
|
cancel()
|
||||||
opsErrorLogProcessed.Add(1)
|
opsErrorLogProcessed.Add(1)
|
||||||
}()
|
}()
|
||||||
@@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) {
|
func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) {
|
||||||
if ops == nil || entry == nil {
|
if ops == nil || entry == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo
|
|||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}:
|
case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry}:
|
||||||
opsErrorLogQueueLen.Add(1)
|
opsErrorLogQueueLen.Add(1)
|
||||||
opsErrorLogEnqueued.Add(1)
|
opsErrorLogEnqueued.Add(1)
|
||||||
default:
|
default:
|
||||||
@@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 {
|
|||||||
return opsErrorLogProcessed.Load()
|
return opsErrorLogProcessed.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func OpsErrorLogSanitizedTotal() int64 {
|
||||||
|
return opsErrorLogSanitized.Load()
|
||||||
|
}
|
||||||
|
|
||||||
func maybeLogOpsErrorLogDrop() {
|
func maybeLogOpsErrorLogDrop() {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
|
|
||||||
@@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() {
|
|||||||
queueCap := OpsErrorLogQueueCapacity()
|
queueCap := OpsErrorLogQueueCapacity()
|
||||||
|
|
||||||
log.Printf(
|
log.Printf(
|
||||||
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)",
|
"[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d sanitized_total=%d)",
|
||||||
queued,
|
queued,
|
||||||
queueCap,
|
queueCap,
|
||||||
opsErrorLogEnqueued.Load(),
|
opsErrorLogEnqueued.Load(),
|
||||||
opsErrorLogDropped.Load(),
|
opsErrorLogDropped.Load(),
|
||||||
opsErrorLogProcessed.Load(),
|
opsErrorLogProcessed.Load(),
|
||||||
|
opsErrorLogSanitized.Load(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -267,6 +272,22 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
|
||||||
|
if c == nil || entry == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
v, ok := c.Get(opsRequestBodyKey)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
raw, ok := v.([]byte)
|
||||||
|
if !ok || len(raw) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw)
|
||||||
|
opsErrorLogSanitized.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) {
|
func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) {
|
||||||
if c == nil || accountID <= 0 {
|
if c == nil || accountID <= 0 {
|
||||||
return
|
return
|
||||||
@@ -544,14 +565,9 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
|||||||
entry.ClientIP = &clientIP
|
entry.ClientIP = &clientIP
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestBody []byte
|
|
||||||
if v, ok := c.Get(opsRequestBodyKey); ok {
|
|
||||||
if b, ok := v.([]byte); ok && len(b) > 0 {
|
|
||||||
requestBody = b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
|
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
|
||||||
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||||
|
attachOpsRequestBodyToEntry(c, entry)
|
||||||
|
|
||||||
// Skip logging if a passthrough rule with skip_monitoring=true matched.
|
// Skip logging if a passthrough rule with skip_monitoring=true matched.
|
||||||
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
|
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
|
||||||
@@ -560,7 +576,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
enqueueOpsErrorLog(ops, entry, requestBody)
|
enqueueOpsErrorLog(ops, entry)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -724,17 +740,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
|||||||
entry.ClientIP = &clientIP
|
entry.ClientIP = &clientIP
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestBody []byte
|
|
||||||
if v, ok := c.Get(opsRequestBodyKey); ok {
|
|
||||||
if b, ok := v.([]byte); ok && len(b) > 0 {
|
|
||||||
requestBody = b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
|
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
|
||||||
// Do NOT store Authorization/Cookie/etc.
|
// Do NOT store Authorization/Cookie/etc.
|
||||||
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||||
|
attachOpsRequestBodyToEntry(c, entry)
|
||||||
|
|
||||||
enqueueOpsErrorLog(ops, entry, requestBody)
|
enqueueOpsErrorLog(ops, entry)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
175
backend/internal/handler/ops_error_logger_test.go
Normal file
175
backend/internal/handler/ops_error_logger_test.go
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func resetOpsErrorLoggerStateForTest(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
opsErrorLogMu.Lock()
|
||||||
|
ch := opsErrorLogQueue
|
||||||
|
opsErrorLogQueue = nil
|
||||||
|
opsErrorLogStopping = true
|
||||||
|
opsErrorLogMu.Unlock()
|
||||||
|
|
||||||
|
if ch != nil {
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
opsErrorLogWorkersWg.Wait()
|
||||||
|
|
||||||
|
opsErrorLogOnce = sync.Once{}
|
||||||
|
opsErrorLogStopOnce = sync.Once{}
|
||||||
|
opsErrorLogWorkersWg = sync.WaitGroup{}
|
||||||
|
opsErrorLogMu = sync.RWMutex{}
|
||||||
|
opsErrorLogStopping = false
|
||||||
|
|
||||||
|
opsErrorLogQueueLen.Store(0)
|
||||||
|
opsErrorLogEnqueued.Store(0)
|
||||||
|
opsErrorLogDropped.Store(0)
|
||||||
|
opsErrorLogProcessed.Store(0)
|
||||||
|
opsErrorLogSanitized.Store(0)
|
||||||
|
opsErrorLogLastDropLogAt.Store(0)
|
||||||
|
|
||||||
|
opsErrorLogShutdownCh = make(chan struct{})
|
||||||
|
opsErrorLogShutdownOnce = sync.Once{}
|
||||||
|
opsErrorLogDrained.Store(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) {
|
||||||
|
resetOpsErrorLoggerStateForTest(t)
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
raw := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`)
|
||||||
|
setOpsRequestContext(c, "claude-3", false, raw)
|
||||||
|
|
||||||
|
entry := &service.OpsInsertErrorLogInput{}
|
||||||
|
attachOpsRequestBodyToEntry(c, entry)
|
||||||
|
|
||||||
|
require.NotNil(t, entry.RequestBodyBytes)
|
||||||
|
require.Equal(t, len(raw), *entry.RequestBodyBytes)
|
||||||
|
require.NotNil(t, entry.RequestBodyJSON)
|
||||||
|
require.NotContains(t, *entry.RequestBodyJSON, "secret-token")
|
||||||
|
require.Contains(t, *entry.RequestBodyJSON, "[REDACTED]")
|
||||||
|
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) {
|
||||||
|
resetOpsErrorLoggerStateForTest(t)
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
|
||||||
|
raw := []byte("not-json")
|
||||||
|
setOpsRequestContext(c, "claude-3", false, raw)
|
||||||
|
|
||||||
|
entry := &service.OpsInsertErrorLogInput{}
|
||||||
|
attachOpsRequestBodyToEntry(c, entry)
|
||||||
|
|
||||||
|
require.Nil(t, entry.RequestBodyJSON)
|
||||||
|
require.NotNil(t, entry.RequestBodyBytes)
|
||||||
|
require.Equal(t, len(raw), *entry.RequestBodyBytes)
|
||||||
|
require.False(t, entry.RequestBodyTruncated)
|
||||||
|
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) {
|
||||||
|
resetOpsErrorLoggerStateForTest(t)
|
||||||
|
|
||||||
|
// 禁止 enqueueOpsErrorLog 触发 workers,使用测试队列验证满队列降级。
|
||||||
|
opsErrorLogOnce.Do(func() {})
|
||||||
|
|
||||||
|
opsErrorLogMu.Lock()
|
||||||
|
opsErrorLogQueue = make(chan opsErrorLogJob, 1)
|
||||||
|
opsErrorLogMu.Unlock()
|
||||||
|
|
||||||
|
ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"}
|
||||||
|
|
||||||
|
enqueueOpsErrorLog(ops, entry)
|
||||||
|
enqueueOpsErrorLog(ops, entry)
|
||||||
|
|
||||||
|
require.Equal(t, int64(1), OpsErrorLogEnqueuedTotal())
|
||||||
|
require.Equal(t, int64(1), OpsErrorLogDroppedTotal())
|
||||||
|
require.Equal(t, int64(1), OpsErrorLogQueueLength())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAttachOpsRequestBodyToEntry_EarlyReturnBranches covers the no-op paths of
// attachOpsRequestBodyToEntry: nil context, nil entry, missing body key, wrong
// value type, and empty body. None of these must populate the entry or bump the
// sanitize counter.
func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) {
	resetOpsErrorLoggerStateForTest(t)
	gin.SetMode(gin.TestMode)

	entry := &service.OpsInsertErrorLogInput{}
	attachOpsRequestBodyToEntry(nil, entry)
	attachOpsRequestBodyToEntry(&gin.Context{}, nil)

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)

	// No request-body key set on the context.
	attachOpsRequestBodyToEntry(c, entry)
	require.Nil(t, entry.RequestBodyJSON)
	require.Nil(t, entry.RequestBodyBytes)
	require.False(t, entry.RequestBodyTruncated)

	// Wrong value type stored under the key.
	c.Set(opsRequestBodyKey, "not-bytes")
	attachOpsRequestBodyToEntry(c, entry)
	require.Nil(t, entry.RequestBodyJSON)
	require.Nil(t, entry.RequestBodyBytes)

	// Empty byte slice.
	c.Set(opsRequestBodyKey, []byte{})
	attachOpsRequestBodyToEntry(c, entry)
	require.Nil(t, entry.RequestBodyJSON)
	require.Nil(t, entry.RequestBodyBytes)

	require.Equal(t, int64(0), OpsErrorLogSanitizedTotal())
}
||||||
|
// TestEnqueueOpsErrorLog_EarlyReturnBranches exercises every early-return path
// of enqueueOpsErrorLog (nil args, closed shutdown channel, stopping flag,
// nil queue) and asserts that none of them increments the enqueued counter.
func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
	resetOpsErrorLoggerStateForTest(t)

	ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
	entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"}

	// nil-argument branches.
	enqueueOpsErrorLog(nil, entry)
	enqueueOpsErrorLog(ops, nil)
	require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())

	// Shutdown branch.
	close(opsErrorLogShutdownCh)
	enqueueOpsErrorLog(ops, entry)
	require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())

	// Stopping branch.
	resetOpsErrorLoggerStateForTest(t)
	opsErrorLogMu.Lock()
	opsErrorLogStopping = true
	opsErrorLogMu.Unlock()
	enqueueOpsErrorLog(ops, entry)
	require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())

	// Nil-queue branch (consume the once so no workers get started and interfere).
	resetOpsErrorLoggerStateForTest(t)
	opsErrorLogOnce.Do(func() {})
	opsErrorLogMu.Lock()
	opsErrorLogQueue = nil
	opsErrorLogMu.Unlock()
	enqueueOpsErrorLog(ops, entry)
	require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal())
}
||||||
@@ -44,4 +44,8 @@ const (
|
|||||||
// SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。
|
// SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。
|
||||||
// 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。
|
// 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。
|
||||||
SingleAccountRetry Key = "ctx_single_account_retry"
|
SingleAccountRetry Key = "ctx_single_account_retry"
|
||||||
|
|
||||||
|
// PrefetchedStickyAccountID 标识上游(通常 handler)预取到的 sticky session 账号 ID。
|
||||||
|
// Service 层可复用该值,避免同请求链路重复读取 Redis。
|
||||||
|
PrefetchedStickyAccountID Key = "ctx_prefetched_sticky_account_id"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -915,6 +915,59 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
|||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccountWindowStatsBatch 批量获取同一窗口起点下多个账号的统计数据。
|
||||||
|
// 返回 map[accountID]*AccountStats,未命中的账号会返回零值统计,便于上层直接复用。
|
||||||
|
func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) {
|
||||||
|
result := make(map[int64]*usagestats.AccountStats, len(accountIDs))
|
||||||
|
if len(accountIDs) == 0 {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
account_id,
|
||||||
|
COUNT(*) as requests,
|
||||||
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
|
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||||
|
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||||
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
|
FROM usage_logs
|
||||||
|
WHERE account_id = ANY($1) AND created_at >= $2
|
||||||
|
GROUP BY account_id
|
||||||
|
`
|
||||||
|
rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var accountID int64
|
||||||
|
stats := &usagestats.AccountStats{}
|
||||||
|
if err := rows.Scan(
|
||||||
|
&accountID,
|
||||||
|
&stats.Requests,
|
||||||
|
&stats.Tokens,
|
||||||
|
&stats.Cost,
|
||||||
|
&stats.StandardCost,
|
||||||
|
&stats.UserCost,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[accountID] = stats
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, accountID := range accountIDs {
|
||||||
|
if _, ok := result[accountID]; !ok {
|
||||||
|
result[accountID] = &usagestats.AccountStats{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// TrendDataPoint represents a single point in trend data
|
// TrendDataPoint represents a single point in trend data
|
||||||
type TrendDataPoint = usagestats.TrendDataPoint
|
type TrendDataPoint = usagestats.TrendDataPoint
|
||||||
|
|
||||||
|
|||||||
755
backend/internal/service/gateway_hotpath_optimization_test.go
Normal file
755
backend/internal/service/gateway_hotpath_optimization_test.go
Normal file
@@ -0,0 +1,755 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// userGroupRateRepoHotpathStub is a UserGroupRateRepository test double.
// It counts calls, can block on `wait` (to force singleflight sharing in
// concurrency tests), and returns either `err` or the configured `rate`.
type userGroupRateRepoHotpathStub struct {
	UserGroupRateRepository

	rate  *float64
	err   error
	wait  <-chan struct{} // when non-nil, GetByUserAndGroup blocks until closed
	calls atomic.Int64    // number of GetByUserAndGroup invocations
}

// GetByUserAndGroup records the call, optionally blocks, then returns the
// configured error or rate.
func (s *userGroupRateRepoHotpathStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
	s.calls.Add(1)
	if s.wait != nil {
		<-s.wait
	}
	if s.err != nil {
		return nil, s.err
	}
	return s.rate, nil
}
||||||
|
// usageLogWindowBatchRepoStub is a UsageLogRepository test double for the
// window-stats paths. It tracks batch and single-account query calls
// separately so tests can assert which path was taken.
type usageLogWindowBatchRepoStub struct {
	UsageLogRepository

	batchResult map[int64]*usagestats.AccountStats
	batchErr    error
	batchCalls  atomic.Int64

	singleResult map[int64]*usagestats.AccountStats
	singleErr    error
	singleCalls  atomic.Int64
}

// GetAccountWindowStatsBatch returns batchErr when set; otherwise only the
// accounts present in batchResult (mirroring a partial SQL GROUP BY result).
func (s *usageLogWindowBatchRepoStub) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) {
	s.batchCalls.Add(1)
	if s.batchErr != nil {
		return nil, s.batchErr
	}
	out := make(map[int64]*usagestats.AccountStats, len(accountIDs))
	for _, id := range accountIDs {
		if stats, ok := s.batchResult[id]; ok {
			out[id] = stats
		}
	}
	return out, nil
}

// GetAccountWindowStats returns singleErr when set; otherwise the configured
// per-account stats, falling back to zero-value stats for unknown accounts.
func (s *usageLogWindowBatchRepoStub) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
	s.singleCalls.Add(1)
	if s.singleErr != nil {
		return nil, s.singleErr
	}
	if stats, ok := s.singleResult[accountID]; ok {
		return stats, nil
	}
	return &usagestats.AccountStats{}, nil
}
||||||
|
type sessionLimitCacheHotpathStub struct {
|
||||||
|
SessionLimitCache
|
||||||
|
|
||||||
|
batchData map[int64]float64
|
||||||
|
batchErr error
|
||||||
|
|
||||||
|
setData map[int64]float64
|
||||||
|
setErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sessionLimitCacheHotpathStub) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) {
|
||||||
|
if s.batchErr != nil {
|
||||||
|
return nil, s.batchErr
|
||||||
|
}
|
||||||
|
out := make(map[int64]float64, len(accountIDs))
|
||||||
|
for _, id := range accountIDs {
|
||||||
|
if v, ok := s.batchData[id]; ok {
|
||||||
|
out[id] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sessionLimitCacheHotpathStub) SetWindowCost(ctx context.Context, accountID int64, cost float64) error {
|
||||||
|
if s.setErr != nil {
|
||||||
|
return s.setErr
|
||||||
|
}
|
||||||
|
if s.setData == nil {
|
||||||
|
s.setData = make(map[int64]float64)
|
||||||
|
}
|
||||||
|
s.setData[accountID] = cost
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelsListAccountRepoStub is an AccountRepository test double for the
// models-list cache tests. byGroup serves group-scoped lookups, all serves
// global lookups, and the two call counters let tests assert cache hits
// (i.e. no repository round-trip).
type modelsListAccountRepoStub struct {
	AccountRepository

	byGroup map[int64][]Account
	all     []Account
	err     error

	listByGroupCalls atomic.Int64
	listAllCalls     atomic.Int64
}
||||||
|
// stickyGatewayCacheHotpathStub is a GatewayCache test double for sticky
// session reads. It returns stickyID when positive and counts lookups so
// tests can verify the prefetched value is reused instead of re-read.
type stickyGatewayCacheHotpathStub struct {
	GatewayCache

	stickyID int64        // >0 means a sticky account mapping exists
	getCalls atomic.Int64 // number of GetSessionAccountID lookups
}

// GetSessionAccountID returns the configured sticky account ID, or an error
// when no mapping is configured (stickyID <= 0).
func (s *stickyGatewayCacheHotpathStub) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
	s.getCalls.Add(1)
	if s.stickyID > 0 {
		return s.stickyID, nil
	}
	return 0, errors.New("not found")
}

// SetSessionAccountID is a no-op in this stub.
func (s *stickyGatewayCacheHotpathStub) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
	return nil
}

// RefreshSessionTTL is a no-op in this stub.
func (s *stickyGatewayCacheHotpathStub) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
	return nil
}

// DeleteSessionAccountID is a no-op in this stub.
func (s *stickyGatewayCacheHotpathStub) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
	return nil
}
||||||
|
// ListSchedulableByGroupID returns a defensive copy of the accounts configured
// for the group (or err when set); unknown groups yield nil.
func (s *modelsListAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
	s.listByGroupCalls.Add(1)
	if s.err != nil {
		return nil, s.err
	}
	accounts, ok := s.byGroup[groupID]
	if !ok {
		return nil, nil
	}
	// Copy so callers cannot mutate the stub's backing data.
	out := make([]Account, len(accounts))
	copy(out, accounts)
	return out, nil
}

// ListSchedulable returns a defensive copy of all configured accounts
// (or err when set).
func (s *modelsListAccountRepoStub) ListSchedulable(ctx context.Context) ([]Account, error) {
	s.listAllCalls.Add(1)
	if s.err != nil {
		return nil, s.err
	}
	out := make([]Account, len(s.all))
	copy(out, s.all)
	return out, nil
}
||||||
|
// resetGatewayHotpathStatsForTest zeroes every package-level hot-path metric
// counter (window-cost prefetch, user-group-rate cache, models-list cache) so
// each test starts from a clean slate. Callers must invoke it first since the
// counters are shared package state.
func resetGatewayHotpathStatsForTest() {
	windowCostPrefetchCacheHitTotal.Store(0)
	windowCostPrefetchCacheMissTotal.Store(0)
	windowCostPrefetchBatchSQLTotal.Store(0)
	windowCostPrefetchFallbackTotal.Store(0)
	windowCostPrefetchErrorTotal.Store(0)

	userGroupRateCacheHitTotal.Store(0)
	userGroupRateCacheMissTotal.Store(0)
	userGroupRateCacheLoadTotal.Store(0)
	userGroupRateCacheSFSharedTotal.Store(0)
	userGroupRateCacheFallbackTotal.Store(0)

	modelsListCacheHitTotal.Store(0)
	modelsListCacheMissTotal.Store(0)
	modelsListCacheStoreTotal.Store(0)
}
||||||
|
func TestGetUserGroupRateMultiplier_UsesCacheAndSingleflight(t *testing.T) {
|
||||||
|
resetGatewayHotpathStatsForTest()
|
||||||
|
|
||||||
|
rate := 1.7
|
||||||
|
unblock := make(chan struct{})
|
||||||
|
repo := &userGroupRateRepoHotpathStub{
|
||||||
|
rate: &rate,
|
||||||
|
wait: unblock,
|
||||||
|
}
|
||||||
|
svc := &GatewayService{
|
||||||
|
userGroupRateRepo: repo,
|
||||||
|
userGroupRateCache: gocache.New(time.Minute, time.Minute),
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
UserGroupRateCacheTTLSeconds: 30,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
const concurrent = 12
|
||||||
|
results := make([]float64, concurrent)
|
||||||
|
start := make(chan struct{})
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(concurrent)
|
||||||
|
for i := 0; i < concurrent; i++ {
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
<-start
|
||||||
|
results[idx] = svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
close(start)
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
close(unblock)
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for _, got := range results {
|
||||||
|
require.Equal(t, rate, got)
|
||||||
|
}
|
||||||
|
require.Equal(t, int64(1), repo.calls.Load())
|
||||||
|
|
||||||
|
// 再次读取应命中缓存,不再回源。
|
||||||
|
got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2)
|
||||||
|
require.Equal(t, rate, got)
|
||||||
|
require.Equal(t, int64(1), repo.calls.Load())
|
||||||
|
|
||||||
|
hit, miss, load, sfShared, fallback := GatewayUserGroupRateCacheStats()
|
||||||
|
require.GreaterOrEqual(t, hit, int64(1))
|
||||||
|
require.Equal(t, int64(12), miss)
|
||||||
|
require.Equal(t, int64(1), load)
|
||||||
|
require.GreaterOrEqual(t, sfShared, int64(1))
|
||||||
|
require.Equal(t, int64(0), fallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetUserGroupRateMultiplier_FallbackOnRepoError verifies that a
// repository failure falls back to the provided group default multiplier and
// increments the fallback counter.
func TestGetUserGroupRateMultiplier_FallbackOnRepoError(t *testing.T) {
	resetGatewayHotpathStatsForTest()

	repo := &userGroupRateRepoHotpathStub{
		err: errors.New("db down"),
	}
	svc := &GatewayService{
		userGroupRateRepo:  repo,
		userGroupRateCache: gocache.New(time.Minute, time.Minute),
		cfg: &config.Config{
			Gateway: config.GatewayConfig{
				UserGroupRateCacheTTLSeconds: 30,
			},
		},
	}

	// 1.25 is the group default passed in; on repo error it must be returned as-is.
	got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.25)
	require.Equal(t, 1.25, got)
	require.Equal(t, int64(1), repo.calls.Load())

	_, _, _, _, fallback := GatewayUserGroupRateCacheStats()
	require.Equal(t, int64(1), fallback)
}
||||||
|
// TestGetUserGroupRateMultiplier_CacheHitAndNilRepo verifies that a warm cache
// entry is served without touching the repository, and that a service with no
// repository configured always returns the group default multiplier.
func TestGetUserGroupRateMultiplier_CacheHitAndNilRepo(t *testing.T) {
	resetGatewayHotpathStatsForTest()

	repo := &userGroupRateRepoHotpathStub{
		err: errors.New("should not be called"),
	}
	svc := &GatewayService{
		userGroupRateRepo:  repo,
		userGroupRateCache: gocache.New(time.Minute, time.Minute),
	}
	// NOTE(review): "userID:groupID" mirrors the cache-key format used by
	// getUserGroupRateMultiplier — brittle if that format changes.
	key := "101:202"
	svc.userGroupRateCache.Set(key, 2.3, time.Minute)

	got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.1)
	require.Equal(t, 2.3, got)

	hit, miss, load, _, fallback := GatewayUserGroupRateCacheStats()
	require.Equal(t, int64(1), hit)
	require.Equal(t, int64(0), miss)
	require.Equal(t, int64(0), load)
	require.Equal(t, int64(0), fallback)
	require.Equal(t, int64(0), repo.calls.Load())

	// Without a repo: cached value wins when present, otherwise the group
	// default; userID==0 also short-circuits to the default.
	svc2 := &GatewayService{
		userGroupRateCache: gocache.New(time.Minute, time.Minute),
	}
	svc2.userGroupRateCache.Set(key, 1.9, time.Minute)
	require.Equal(t, 1.9, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4))
	require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 0, 202, 1.4))
	svc2.userGroupRateCache.Delete(key)
	require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4))
}
||||||
|
// TestWithWindowCostPrefetch_BatchReadAndContextReuse verifies the three-way
// merge of window costs: account 1 comes from the Redis-style cache (hit),
// account 2 misses the cache and is loaded via ONE batch SQL call (and
// back-filled into the cache), and account 3 — an API-key account with no
// session window — is excluded from prefetch entirely.
func TestWithWindowCostPrefetch_BatchReadAndContextReuse(t *testing.T) {
	resetGatewayHotpathStatsForTest()

	windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour)
	windowEnd := windowStart.Add(5 * time.Hour)
	accounts := []Account{
		{
			ID:                 1,
			Platform:           PlatformAnthropic,
			Type:               AccountTypeOAuth,
			Extra:              map[string]any{"window_cost_limit": 100.0},
			SessionWindowStart: &windowStart,
			SessionWindowEnd:   &windowEnd,
		},
		{
			ID:                 2,
			Platform:           PlatformAnthropic,
			Type:               AccountTypeSetupToken,
			Extra:              map[string]any{"window_cost_limit": 100.0},
			SessionWindowStart: &windowStart,
			SessionWindowEnd:   &windowEnd,
		},
		{
			// API-key account without a session window: not prefetched.
			ID:       3,
			Platform: PlatformAnthropic,
			Type:     AccountTypeAPIKey,
			Extra:    map[string]any{"window_cost_limit": 100.0},
		},
	}

	cache := &sessionLimitCacheHotpathStub{
		batchData: map[int64]float64{
			1: 11.0,
		},
	}
	repo := &usageLogWindowBatchRepoStub{
		batchResult: map[int64]*usagestats.AccountStats{
			2: {StandardCost: 22.0},
		},
	}
	svc := &GatewayService{
		sessionLimitCache: cache,
		usageLogRepo:      repo,
	}

	outCtx := svc.withWindowCostPrefetch(context.Background(), accounts)
	require.NotNil(t, outCtx)

	cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1)
	require.True(t, ok1)
	require.Equal(t, 11.0, cost1)

	cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2)
	require.True(t, ok2)
	require.Equal(t, 22.0, cost2)

	_, ok3 := windowCostFromPrefetchContext(outCtx, 3)
	require.False(t, ok3)

	// Exactly one batch SQL call, and the SQL result was written back to the cache.
	require.Equal(t, int64(1), repo.batchCalls.Load())
	require.Equal(t, 22.0, cache.setData[2])

	hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats()
	require.Equal(t, int64(1), hit)
	require.Equal(t, int64(1), miss)
	require.Equal(t, int64(1), batchSQL)
	require.Equal(t, int64(0), fallback)
	require.Equal(t, int64(0), errCount)
}
||||||
|
// TestWithWindowCostPrefetch_AllHitNoSQL verifies that when every account's
// window cost is already in the cache, no SQL query (batch or single) is
// issued at all.
func TestWithWindowCostPrefetch_AllHitNoSQL(t *testing.T) {
	resetGatewayHotpathStatsForTest()

	windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour)
	windowEnd := windowStart.Add(5 * time.Hour)
	accounts := []Account{
		{
			ID:                 1,
			Platform:           PlatformAnthropic,
			Type:               AccountTypeOAuth,
			Extra:              map[string]any{"window_cost_limit": 100.0},
			SessionWindowStart: &windowStart,
			SessionWindowEnd:   &windowEnd,
		},
		{
			ID:                 2,
			Platform:           PlatformAnthropic,
			Type:               AccountTypeSetupToken,
			Extra:              map[string]any{"window_cost_limit": 100.0},
			SessionWindowStart: &windowStart,
			SessionWindowEnd:   &windowEnd,
		},
	}

	cache := &sessionLimitCacheHotpathStub{
		batchData: map[int64]float64{
			1: 11.0,
			2: 22.0,
		},
	}
	repo := &usageLogWindowBatchRepoStub{}
	svc := &GatewayService{
		sessionLimitCache: cache,
		usageLogRepo:      repo,
	}

	outCtx := svc.withWindowCostPrefetch(context.Background(), accounts)
	cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1)
	cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2)
	require.True(t, ok1)
	require.True(t, ok2)
	require.Equal(t, 11.0, cost1)
	require.Equal(t, 22.0, cost2)
	require.Equal(t, int64(0), repo.batchCalls.Load())
	require.Equal(t, int64(0), repo.singleCalls.Load())

	hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats()
	require.Equal(t, int64(2), hit)
	require.Equal(t, int64(0), miss)
	require.Equal(t, int64(0), batchSQL)
	require.Equal(t, int64(0), fallback)
	require.Equal(t, int64(0), errCount)
}
||||||
|
// TestWithWindowCostPrefetch_BatchErrorFallbackSingleQuery verifies that a
// failing batch SQL query degrades to per-account single queries, still
// yields the correct cost, and bumps the fallback and error counters.
func TestWithWindowCostPrefetch_BatchErrorFallbackSingleQuery(t *testing.T) {
	resetGatewayHotpathStatsForTest()

	windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour)
	windowEnd := windowStart.Add(5 * time.Hour)
	accounts := []Account{
		{
			ID:                 2,
			Platform:           PlatformAnthropic,
			Type:               AccountTypeSetupToken,
			Extra:              map[string]any{"window_cost_limit": 100.0},
			SessionWindowStart: &windowStart,
			SessionWindowEnd:   &windowEnd,
		},
	}

	cache := &sessionLimitCacheHotpathStub{}
	repo := &usageLogWindowBatchRepoStub{
		batchErr: errors.New("batch failed"),
		singleResult: map[int64]*usagestats.AccountStats{
			2: {StandardCost: 33.0},
		},
	}
	svc := &GatewayService{
		sessionLimitCache: cache,
		usageLogRepo:      repo,
	}

	outCtx := svc.withWindowCostPrefetch(context.Background(), accounts)
	cost, ok := windowCostFromPrefetchContext(outCtx, 2)
	require.True(t, ok)
	require.Equal(t, 33.0, cost)
	// The batch path was attempted once, then the single-query fallback ran.
	require.Equal(t, int64(1), repo.batchCalls.Load())
	require.Equal(t, int64(1), repo.singleCalls.Load())

	_, _, _, fallback, errCount := GatewayWindowCostPrefetchStats()
	require.Equal(t, int64(1), fallback)
	require.Equal(t, int64(1), errCount)
}
||||||
|
// TestGetAvailableModels_UsesShortCacheAndSupportsInvalidation verifies the
// short-TTL models-list cache: repeated reads within the TTL are served from
// cache (stale data included), and explicit invalidation forces a repository
// reload. Also checks the hit/miss/store counters.
func TestGetAvailableModels_UsesShortCacheAndSupportsInvalidation(t *testing.T) {
	resetGatewayHotpathStatsForTest()

	groupID := int64(9)
	repo := &modelsListAccountRepoStub{
		byGroup: map[int64][]Account{
			groupID: {
				{
					ID:       1,
					Platform: PlatformAnthropic,
					Credentials: map[string]any{
						"model_mapping": map[string]any{
							"claude-3-5-sonnet": "claude-3-5-sonnet",
							"claude-3-5-haiku":  "claude-3-5-haiku",
						},
					},
				},
				{
					ID:       2,
					Platform: PlatformGemini,
					Credentials: map[string]any{
						"model_mapping": map[string]any{
							"gemini-2.5-pro": "gemini-2.5-pro",
						},
					},
				},
			},
		},
	}

	svc := &GatewayService{
		accountRepo:        repo,
		modelsListCache:    gocache.New(time.Minute, time.Minute),
		modelsListCacheTTL: time.Minute,
	}

	// First read: repository round-trip; result filtered to the requested
	// platform and sorted.
	models1 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic)
	require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models1)
	require.Equal(t, int64(1), repo.listByGroupCalls.Load())

	// Second read within TTL: cache hit, no repository call.
	models2 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic)
	require.Equal(t, models1, models2)
	require.Equal(t, int64(1), repo.listByGroupCalls.Load())

	// Repository data changes, but the cached (stale) value is still served
	// until invalidation.
	repo.byGroup[groupID] = []Account{
		{
			ID:       3,
			Platform: PlatformAnthropic,
			Credentials: map[string]any{
				"model_mapping": map[string]any{
					"claude-3-7-sonnet": "claude-3-7-sonnet",
				},
			},
		},
	}
	models3 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic)
	require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models3)
	require.Equal(t, int64(1), repo.listByGroupCalls.Load())

	// Invalidation drops the entry; the next read reloads fresh data.
	svc.InvalidateAvailableModelsCache(&groupID, PlatformAnthropic)
	models4 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic)
	require.Equal(t, []string{"claude-3-7-sonnet"}, models4)
	require.Equal(t, int64(2), repo.listByGroupCalls.Load())

	hit, miss, store := GatewayModelsListCacheStats()
	require.Equal(t, int64(2), hit)
	require.Equal(t, int64(2), miss)
	require.Equal(t, int64(2), store)
}
||||||
|
// TestGetAvailableModels_ErrorAndGlobalListBranches covers two branches:
// a repository error yields nil, and a nil groupID with empty platform lists
// all schedulable accounts' models across platforms (sorted).
func TestGetAvailableModels_ErrorAndGlobalListBranches(t *testing.T) {
	resetGatewayHotpathStatsForTest()

	errRepo := &modelsListAccountRepoStub{
		err: errors.New("db error"),
	}
	svcErr := &GatewayService{
		accountRepo:        errRepo,
		modelsListCache:    gocache.New(time.Minute, time.Minute),
		modelsListCacheTTL: time.Minute,
	}
	require.Nil(t, svcErr.GetAvailableModels(context.Background(), nil, ""))

	okRepo := &modelsListAccountRepoStub{
		all: []Account{
			{
				ID:       1,
				Platform: PlatformAnthropic,
				Credentials: map[string]any{
					"model_mapping": map[string]any{
						"claude-3-5-sonnet": "claude-3-5-sonnet",
					},
				},
			},
			{
				ID:       2,
				Platform: PlatformGemini,
				Credentials: map[string]any{
					"model_mapping": map[string]any{
						"gemini-2.5-pro": "gemini-2.5-pro",
					},
				},
			},
		},
	}
	svcOK := &GatewayService{
		accountRepo:        okRepo,
		modelsListCache:    gocache.New(time.Minute, time.Minute),
		modelsListCacheTTL: time.Minute,
	}
	models := svcOK.GetAvailableModels(context.Background(), nil, "")
	require.Equal(t, []string{"claude-3-5-sonnet", "gemini-2.5-pro"}, models)
	require.Equal(t, int64(1), okRepo.listAllCalls.Load())
}
||||||
|
// TestGatewayHotpathHelpers_CacheTTLAndStickyContext covers the small pure
// helpers: TTL resolution from config (with defaults for nil config), the
// prefetched sticky-account-ID context reader (int64, int, and invalid
// values), and the window-cost prefetch context reader.
func TestGatewayHotpathHelpers_CacheTTLAndStickyContext(t *testing.T) {
	t.Run("resolve_user_group_rate_cache_ttl", func(t *testing.T) {
		// nil config falls back to the package default.
		require.Equal(t, defaultUserGroupRateCacheTTL, resolveUserGroupRateCacheTTL(nil))

		cfg := &config.Config{
			Gateway: config.GatewayConfig{
				UserGroupRateCacheTTLSeconds: 45,
			},
		}
		require.Equal(t, 45*time.Second, resolveUserGroupRateCacheTTL(cfg))
	})

	t.Run("resolve_models_list_cache_ttl", func(t *testing.T) {
		require.Equal(t, defaultModelsListCacheTTL, resolveModelsListCacheTTL(nil))

		cfg := &config.Config{
			Gateway: config.GatewayConfig{
				ModelsListCacheTTLSeconds: 20,
			},
		}
		require.Equal(t, 20*time.Second, resolveModelsListCacheTTL(cfg))
	})

	t.Run("prefetched_sticky_account_id_from_context", func(t *testing.T) {
		// Missing or absent value -> 0.
		require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO()))
		require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background()))

		ctx := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(123))
		require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx))

		// Plain int values are accepted and converted to int64.
		ctx2 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, 456)
		require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2))

		// Non-numeric values are ignored.
		ctx3 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, "invalid")
		require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3))
	})

	t.Run("window_cost_from_prefetch_context", func(t *testing.T) {
		// No prefetch map in the context -> not found.
		require.Equal(t, false, func() bool {
			_, ok := windowCostFromPrefetchContext(context.TODO(), 0)
			return ok
		}())
		require.Equal(t, false, func() bool {
			_, ok := windowCostFromPrefetchContext(context.Background(), 1)
			return ok
		}())

		ctx := context.WithValue(context.Background(), windowCostPrefetchContextKey, map[int64]float64{
			9: 12.34,
		})
		cost, ok := windowCostFromPrefetchContext(ctx, 9)
		require.True(t, ok)
		require.Equal(t, 12.34, cost)
	})
}
||||||
|
// TestInvalidateAvailableModelsCache_ByDimensions verifies the three
// invalidation scopes: (group, platform) removes exactly one entry;
// (group, "") removes every platform for that group; (nil, platform)
// removes that platform across all groups. Unrelated entries survive.
func TestInvalidateAvailableModelsCache_ByDimensions(t *testing.T) {
	svc := &GatewayService{
		modelsListCache: gocache.New(time.Minute, time.Minute),
	}
	group9 := int64(9)
	group10 := int64(10)
	svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute)
	svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute)
	svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute)
	// A key outside the expected format must not break prefix invalidation.
	svc.modelsListCache.Set("invalid-key", []string{"d"}, time.Minute)

	t.Run("invalidate_group_and_platform", func(t *testing.T) {
		svc.InvalidateAvailableModelsCache(&group9, PlatformAnthropic)
		_, found := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic))
		require.False(t, found)
		_, stillFound := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini))
		require.True(t, stillFound)
	})

	t.Run("invalidate_group_only", func(t *testing.T) {
		svc.InvalidateAvailableModelsCache(&group9, "")
		_, foundA := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic))
		_, foundB := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini))
		require.False(t, foundA)
		require.False(t, foundB)
		_, foundOtherGroup := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic))
		require.True(t, foundOtherGroup)
	})

	t.Run("invalidate_platform_only", func(t *testing.T) {
		// Rebuild the entries, then invalidate by platform only.
		svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute)
		svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute)
		svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute)

		svc.InvalidateAvailableModelsCache(nil, PlatformAnthropic)
		_, found9Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic))
		_, found10Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic))
		_, found9Gemini := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini))
		require.False(t, found9Anthropic)
		require.False(t, found10Anthropic)
		require.True(t, found9Gemini)
	})
}
||||||
|
func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) {
|
||||||
|
now := time.Now().Add(-time.Minute)
|
||||||
|
account := Account{
|
||||||
|
ID: 88,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
Concurrency: 4,
|
||||||
|
Priority: 1,
|
||||||
|
LastUsedAt: &now,
|
||||||
|
}
|
||||||
|
|
||||||
|
repo := stubOpenAIAccountRepo{accounts: []Account{account}}
|
||||||
|
concurrency := NewConcurrencyService(stubConcurrencyCache{})
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
RunMode: config.RunModeStandard,
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
Scheduling: config.GatewaySchedulingConfig{
|
||||||
|
LoadBatchEnabled: true,
|
||||||
|
StickySessionMaxWaiting: 3,
|
||||||
|
StickySessionWaitTimeout: time.Second,
|
||||||
|
FallbackWaitTimeout: time.Second,
|
||||||
|
FallbackMaxWaiting: 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
baseCtx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAnthropic)
|
||||||
|
|
||||||
|
t.Run("without_prefetch_reads_cache_once", func(t *testing.T) {
|
||||||
|
cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID}
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: cfg,
|
||||||
|
concurrencyService: concurrency,
|
||||||
|
userGroupRateCache: gocache.New(time.Minute, time.Minute),
|
||||||
|
modelsListCache: gocache.New(time.Minute, time.Minute),
|
||||||
|
modelsListCacheTTL: time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.Account)
|
||||||
|
require.Equal(t, account.ID, result.Account.ID)
|
||||||
|
require.Equal(t, int64(1), cache.getCalls.Load())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with_prefetch_skips_cache_read", func(t *testing.T) {
|
||||||
|
cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID}
|
||||||
|
svc := &GatewayService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cache: cache,
|
||||||
|
cfg: cfg,
|
||||||
|
concurrencyService: concurrency,
|
||||||
|
userGroupRateCache: gocache.New(time.Minute, time.Minute),
|
||||||
|
modelsListCache: gocache.New(time.Minute, time.Minute),
|
||||||
|
modelsListCacheTTL: time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID)
|
||||||
|
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.NotNil(t, result.Account)
|
||||||
|
require.Equal(t, account.ID, result.Account.ID)
|
||||||
|
require.Equal(t, int64(0), cache.getCalls.Load())
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -24,12 +24,15 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
"github.com/cespare/xxhash/v2"
|
"github.com/cespare/xxhash/v2"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -44,6 +47,9 @@ const (
|
|||||||
// separator between system blocks, we add "\n\n" at concatenation time.
|
// separator between system blocks, we add "\n\n" at concatenation time.
|
||||||
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||||
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
|
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
|
||||||
|
|
||||||
|
defaultUserGroupRateCacheTTL = 30 * time.Second
|
||||||
|
defaultModelsListCacheTTL = 15 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -62,6 +68,53 @@ type accountWithLoad struct {
|
|||||||
|
|
||||||
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
|
var ForceCacheBillingContextKey = forceCacheBillingKeyType{}
|
||||||
|
|
||||||
|
var (
|
||||||
|
windowCostPrefetchCacheHitTotal atomic.Int64
|
||||||
|
windowCostPrefetchCacheMissTotal atomic.Int64
|
||||||
|
windowCostPrefetchBatchSQLTotal atomic.Int64
|
||||||
|
windowCostPrefetchFallbackTotal atomic.Int64
|
||||||
|
windowCostPrefetchErrorTotal atomic.Int64
|
||||||
|
|
||||||
|
userGroupRateCacheHitTotal atomic.Int64
|
||||||
|
userGroupRateCacheMissTotal atomic.Int64
|
||||||
|
userGroupRateCacheLoadTotal atomic.Int64
|
||||||
|
userGroupRateCacheSFSharedTotal atomic.Int64
|
||||||
|
userGroupRateCacheFallbackTotal atomic.Int64
|
||||||
|
|
||||||
|
modelsListCacheHitTotal atomic.Int64
|
||||||
|
modelsListCacheMissTotal atomic.Int64
|
||||||
|
modelsListCacheStoreTotal atomic.Int64
|
||||||
|
)
|
||||||
|
|
||||||
|
func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) {
|
||||||
|
return windowCostPrefetchCacheHitTotal.Load(),
|
||||||
|
windowCostPrefetchCacheMissTotal.Load(),
|
||||||
|
windowCostPrefetchBatchSQLTotal.Load(),
|
||||||
|
windowCostPrefetchFallbackTotal.Load(),
|
||||||
|
windowCostPrefetchErrorTotal.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func GatewayUserGroupRateCacheStats() (cacheHit, cacheMiss, load, singleflightShared, fallback int64) {
|
||||||
|
return userGroupRateCacheHitTotal.Load(),
|
||||||
|
userGroupRateCacheMissTotal.Load(),
|
||||||
|
userGroupRateCacheLoadTotal.Load(),
|
||||||
|
userGroupRateCacheSFSharedTotal.Load(),
|
||||||
|
userGroupRateCacheFallbackTotal.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) {
|
||||||
|
return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneStringSlice(src []string) []string {
|
||||||
|
if len(src) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
dst := make([]string, len(src))
|
||||||
|
copy(dst, src)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
// IsForceCacheBilling 检查是否启用强制缓存计费
|
// IsForceCacheBilling 检查是否启用强制缓存计费
|
||||||
func IsForceCacheBilling(ctx context.Context) bool {
|
func IsForceCacheBilling(ctx context.Context) bool {
|
||||||
v, _ := ctx.Value(ForceCacheBillingContextKey).(bool)
|
v, _ := ctx.Value(ForceCacheBillingContextKey).(bool)
|
||||||
@@ -302,6 +355,42 @@ func derefGroupID(groupID *int64) int64 {
|
|||||||
return *groupID
|
return *groupID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resolveUserGroupRateCacheTTL(cfg *config.Config) time.Duration {
|
||||||
|
if cfg == nil || cfg.Gateway.UserGroupRateCacheTTLSeconds <= 0 {
|
||||||
|
return defaultUserGroupRateCacheTTL
|
||||||
|
}
|
||||||
|
return time.Duration(cfg.Gateway.UserGroupRateCacheTTLSeconds) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveModelsListCacheTTL(cfg *config.Config) time.Duration {
|
||||||
|
if cfg == nil || cfg.Gateway.ModelsListCacheTTLSeconds <= 0 {
|
||||||
|
return defaultModelsListCacheTTL
|
||||||
|
}
|
||||||
|
return time.Duration(cfg.Gateway.ModelsListCacheTTLSeconds) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
func modelsListCacheKey(groupID *int64, platform string) string {
|
||||||
|
return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform))
|
||||||
|
}
|
||||||
|
|
||||||
|
func prefetchedStickyAccountIDFromContext(ctx context.Context) int64 {
|
||||||
|
if ctx == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
v := ctx.Value(ctxkey.PrefetchedStickyAccountID)
|
||||||
|
switch t := v.(type) {
|
||||||
|
case int64:
|
||||||
|
if t > 0 {
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
case int:
|
||||||
|
if t > 0 {
|
||||||
|
return int64(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
|
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
|
||||||
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
|
// 当账号状态为错误、禁用、不可调度、处于临时不可调度期间,
|
||||||
// 或请求的模型处于限流状态时,返回 true。
|
// 或请求的模型处于限流状态时,返回 true。
|
||||||
@@ -421,6 +510,10 @@ type GatewayService struct {
|
|||||||
concurrencyService *ConcurrencyService
|
concurrencyService *ConcurrencyService
|
||||||
claudeTokenProvider *ClaudeTokenProvider
|
claudeTokenProvider *ClaudeTokenProvider
|
||||||
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
|
||||||
|
userGroupRateCache *gocache.Cache
|
||||||
|
userGroupRateSF singleflight.Group
|
||||||
|
modelsListCache *gocache.Cache
|
||||||
|
modelsListCacheTTL time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGatewayService creates a new GatewayService
|
// NewGatewayService creates a new GatewayService
|
||||||
@@ -445,6 +538,9 @@ func NewGatewayService(
|
|||||||
sessionLimitCache SessionLimitCache,
|
sessionLimitCache SessionLimitCache,
|
||||||
digestStore *DigestSessionStore,
|
digestStore *DigestSessionStore,
|
||||||
) *GatewayService {
|
) *GatewayService {
|
||||||
|
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
|
||||||
|
modelsListTTL := resolveModelsListCacheTTL(cfg)
|
||||||
|
|
||||||
return &GatewayService{
|
return &GatewayService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
groupRepo: groupRepo,
|
groupRepo: groupRepo,
|
||||||
@@ -465,6 +561,9 @@ func NewGatewayService(
|
|||||||
deferredService: deferredService,
|
deferredService: deferredService,
|
||||||
claudeTokenProvider: claudeTokenProvider,
|
claudeTokenProvider: claudeTokenProvider,
|
||||||
sessionLimitCache: sessionLimitCache,
|
sessionLimitCache: sessionLimitCache,
|
||||||
|
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
|
||||||
|
modelsListCache: gocache.New(modelsListTTL, time.Minute),
|
||||||
|
modelsListCacheTTL: modelsListTTL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -937,7 +1036,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
cfg := s.schedulingConfig()
|
cfg := s.schedulingConfig()
|
||||||
|
|
||||||
var stickyAccountID int64
|
var stickyAccountID int64
|
||||||
if sessionHash != "" && s.cache != nil {
|
if prefetch := prefetchedStickyAccountIDFromContext(ctx); prefetch > 0 {
|
||||||
|
stickyAccountID = prefetch
|
||||||
|
} else if sessionHash != "" && s.cache != nil {
|
||||||
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
|
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
|
||||||
stickyAccountID = accountID
|
stickyAccountID = accountID
|
||||||
}
|
}
|
||||||
@@ -1035,6 +1136,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if len(accounts) == 0 {
|
if len(accounts) == 0 {
|
||||||
return nil, errors.New("no available accounts")
|
return nil, errors.New("no available accounts")
|
||||||
}
|
}
|
||||||
|
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||||
|
|
||||||
isExcluded := func(accountID int64) bool {
|
isExcluded := func(accountID int64) bool {
|
||||||
if excludedIDs == nil {
|
if excludedIDs == nil {
|
||||||
@@ -1125,9 +1227,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
|
|
||||||
if len(routingCandidates) > 0 {
|
if len(routingCandidates) > 0 {
|
||||||
// 1.5. 在路由账号范围内检查粘性会话
|
// 1.5. 在路由账号范围内检查粘性会话
|
||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && stickyAccountID > 0 {
|
||||||
stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
||||||
if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
|
||||||
// 粘性账号在路由列表中,优先使用
|
// 粘性账号在路由列表中,优先使用
|
||||||
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
||||||
if stickyAccount.IsSchedulable() &&
|
if stickyAccount.IsSchedulable() &&
|
||||||
@@ -1273,9 +1374,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
|
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
|
||||||
if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil {
|
if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
accountID := stickyAccountID
|
||||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
if accountID > 0 && !isExcluded(accountID) {
|
||||||
account, ok := accountByID[accountID]
|
account, ok := accountByID[accountID]
|
||||||
if ok {
|
if ok {
|
||||||
// 检查账户是否需要清理粘性会话绑定
|
// 检查账户是否需要清理粘性会话绑定
|
||||||
@@ -1760,6 +1861,129 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
|
|||||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type usageLogWindowStatsBatchProvider interface {
|
||||||
|
GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type windowCostPrefetchContextKeyType struct{}
|
||||||
|
|
||||||
|
var windowCostPrefetchContextKey = windowCostPrefetchContextKeyType{}
|
||||||
|
|
||||||
|
func windowCostFromPrefetchContext(ctx context.Context, accountID int64) (float64, bool) {
|
||||||
|
if ctx == nil || accountID <= 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
m, ok := ctx.Value(windowCostPrefetchContextKey).(map[int64]float64)
|
||||||
|
if !ok || len(m) == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
v, exists := m[accountID]
|
||||||
|
return v, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []Account) context.Context {
|
||||||
|
if ctx == nil || len(accounts) == 0 || s.sessionLimitCache == nil || s.usageLogRepo == nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
accountByID := make(map[int64]*Account)
|
||||||
|
accountIDs := make([]int64, 0, len(accounts))
|
||||||
|
for i := range accounts {
|
||||||
|
account := &accounts[i]
|
||||||
|
if account == nil || !account.IsAnthropicOAuthOrSetupToken() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if account.GetWindowCostLimit() <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
accountByID[account.ID] = account
|
||||||
|
accountIDs = append(accountIDs, account.ID)
|
||||||
|
}
|
||||||
|
if len(accountIDs) == 0 {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
costs := make(map[int64]float64, len(accountIDs))
|
||||||
|
cacheValues, err := s.sessionLimitCache.GetWindowCostBatch(ctx, accountIDs)
|
||||||
|
if err == nil {
|
||||||
|
for accountID, cost := range cacheValues {
|
||||||
|
costs[accountID] = cost
|
||||||
|
}
|
||||||
|
windowCostPrefetchCacheHitTotal.Add(int64(len(cacheValues)))
|
||||||
|
} else {
|
||||||
|
windowCostPrefetchErrorTotal.Add(1)
|
||||||
|
logger.LegacyPrintf("service.gateway", "window_cost batch cache read failed: %v", err)
|
||||||
|
}
|
||||||
|
cacheMissCount := len(accountIDs) - len(costs)
|
||||||
|
if cacheMissCount < 0 {
|
||||||
|
cacheMissCount = 0
|
||||||
|
}
|
||||||
|
windowCostPrefetchCacheMissTotal.Add(int64(cacheMissCount))
|
||||||
|
|
||||||
|
missingByStart := make(map[int64][]int64)
|
||||||
|
startTimes := make(map[int64]time.Time)
|
||||||
|
for _, accountID := range accountIDs {
|
||||||
|
if _, ok := costs[accountID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
account := accountByID[accountID]
|
||||||
|
if account == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
startTime := account.GetCurrentWindowStartTime()
|
||||||
|
startKey := startTime.Unix()
|
||||||
|
missingByStart[startKey] = append(missingByStart[startKey], accountID)
|
||||||
|
startTimes[startKey] = startTime
|
||||||
|
}
|
||||||
|
if len(missingByStart) == 0 {
|
||||||
|
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
|
||||||
|
}
|
||||||
|
|
||||||
|
batchReader, hasBatch := s.usageLogRepo.(usageLogWindowStatsBatchProvider)
|
||||||
|
for startKey, ids := range missingByStart {
|
||||||
|
startTime := startTimes[startKey]
|
||||||
|
|
||||||
|
if hasBatch {
|
||||||
|
windowCostPrefetchBatchSQLTotal.Add(1)
|
||||||
|
queryStart := time.Now()
|
||||||
|
statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, ids, startTime)
|
||||||
|
if err == nil {
|
||||||
|
slog.Debug("window_cost_batch_query_ok",
|
||||||
|
"accounts", len(ids),
|
||||||
|
"window_start", startTime.Format(time.RFC3339),
|
||||||
|
"duration_ms", time.Since(queryStart).Milliseconds())
|
||||||
|
for _, accountID := range ids {
|
||||||
|
stats := statsByAccount[accountID]
|
||||||
|
cost := 0.0
|
||||||
|
if stats != nil {
|
||||||
|
cost = stats.StandardCost
|
||||||
|
}
|
||||||
|
costs[accountID] = cost
|
||||||
|
_ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
windowCostPrefetchErrorTotal.Add(1)
|
||||||
|
logger.LegacyPrintf("service.gateway", "window_cost batch db query failed: start=%s err=%v", startTime.Format(time.RFC3339), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 回退路径:缺少批量仓储能力或批量查询失败时,按账号单查(失败开放)。
|
||||||
|
windowCostPrefetchFallbackTotal.Add(int64(len(ids)))
|
||||||
|
for _, accountID := range ids {
|
||||||
|
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime)
|
||||||
|
if err != nil {
|
||||||
|
windowCostPrefetchErrorTotal.Add(1)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cost := stats.StandardCost
|
||||||
|
costs[accountID] = cost
|
||||||
|
_ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return context.WithValue(ctx, windowCostPrefetchContextKey, costs)
|
||||||
|
}
|
||||||
|
|
||||||
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
|
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
|
||||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||||
// 返回 true 表示可调度,false 表示不可调度
|
// 返回 true 表示可调度,false 表示不可调度
|
||||||
@@ -1776,6 +2000,10 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
|
|||||||
|
|
||||||
// 尝试从缓存获取窗口费用
|
// 尝试从缓存获取窗口费用
|
||||||
var currentCost float64
|
var currentCost float64
|
||||||
|
if cost, ok := windowCostFromPrefetchContext(ctx, account.ID); ok {
|
||||||
|
currentCost = cost
|
||||||
|
goto checkSchedulability
|
||||||
|
}
|
||||||
if s.sessionLimitCache != nil {
|
if s.sessionLimitCache != nil {
|
||||||
if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit {
|
if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit {
|
||||||
currentCost = cost
|
currentCost = cost
|
||||||
@@ -5264,6 +5492,66 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
|
||||||
|
if s == nil || userID <= 0 || groupID <= 0 {
|
||||||
|
return groupDefaultMultiplier
|
||||||
|
}
|
||||||
|
|
||||||
|
key := fmt.Sprintf("%d:%d", userID, groupID)
|
||||||
|
if s.userGroupRateCache != nil {
|
||||||
|
if cached, ok := s.userGroupRateCache.Get(key); ok {
|
||||||
|
if multiplier, castOK := cached.(float64); castOK {
|
||||||
|
userGroupRateCacheHitTotal.Add(1)
|
||||||
|
return multiplier
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.userGroupRateRepo == nil {
|
||||||
|
return groupDefaultMultiplier
|
||||||
|
}
|
||||||
|
userGroupRateCacheMissTotal.Add(1)
|
||||||
|
|
||||||
|
value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) {
|
||||||
|
if s.userGroupRateCache != nil {
|
||||||
|
if cached, ok := s.userGroupRateCache.Get(key); ok {
|
||||||
|
if multiplier, castOK := cached.(float64); castOK {
|
||||||
|
userGroupRateCacheHitTotal.Add(1)
|
||||||
|
return multiplier, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
userGroupRateCacheLoadTotal.Add(1)
|
||||||
|
userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID)
|
||||||
|
if repoErr != nil {
|
||||||
|
return nil, repoErr
|
||||||
|
}
|
||||||
|
multiplier := groupDefaultMultiplier
|
||||||
|
if userRate != nil {
|
||||||
|
multiplier = *userRate
|
||||||
|
}
|
||||||
|
if s.userGroupRateCache != nil {
|
||||||
|
s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg))
|
||||||
|
}
|
||||||
|
return multiplier, nil
|
||||||
|
})
|
||||||
|
if shared {
|
||||||
|
userGroupRateCacheSFSharedTotal.Add(1)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
userGroupRateCacheFallbackTotal.Add(1)
|
||||||
|
logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
|
||||||
|
return groupDefaultMultiplier
|
||||||
|
}
|
||||||
|
|
||||||
|
multiplier, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
userGroupRateCacheFallbackTotal.Add(1)
|
||||||
|
return groupDefaultMultiplier
|
||||||
|
}
|
||||||
|
return multiplier
|
||||||
|
}
|
||||||
|
|
||||||
// RecordUsageInput 记录使用量的输入参数
|
// RecordUsageInput 记录使用量的输入参数
|
||||||
type RecordUsageInput struct {
|
type RecordUsageInput struct {
|
||||||
Result *ForwardResult
|
Result *ForwardResult
|
||||||
@@ -5307,16 +5595,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||||
multiplier := s.cfg.Default.RateMultiplier
|
multiplier := 1.0
|
||||||
|
if s.cfg != nil {
|
||||||
|
multiplier = s.cfg.Default.RateMultiplier
|
||||||
|
}
|
||||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||||
multiplier = apiKey.Group.RateMultiplier
|
groupDefault := apiKey.Group.RateMultiplier
|
||||||
|
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
|
||||||
// 检查用户专属倍率
|
|
||||||
if s.userGroupRateRepo != nil {
|
|
||||||
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
|
|
||||||
multiplier = *userRate
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var cost *CostBreakdown
|
var cost *CostBreakdown
|
||||||
@@ -5522,16 +5807,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||||
multiplier := s.cfg.Default.RateMultiplier
|
multiplier := 1.0
|
||||||
|
if s.cfg != nil {
|
||||||
|
multiplier = s.cfg.Default.RateMultiplier
|
||||||
|
}
|
||||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||||
multiplier = apiKey.Group.RateMultiplier
|
groupDefault := apiKey.Group.RateMultiplier
|
||||||
|
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
|
||||||
// 检查用户专属倍率
|
|
||||||
if s.userGroupRateRepo != nil {
|
|
||||||
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
|
|
||||||
multiplier = *userRate
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var cost *CostBreakdown
|
var cost *CostBreakdown
|
||||||
@@ -6145,6 +6427,17 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
|||||||
// GetAvailableModels returns the list of models available for a group
|
// GetAvailableModels returns the list of models available for a group
|
||||||
// It aggregates model_mapping keys from all schedulable accounts in the group
|
// It aggregates model_mapping keys from all schedulable accounts in the group
|
||||||
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
|
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
|
||||||
|
cacheKey := modelsListCacheKey(groupID, platform)
|
||||||
|
if s.modelsListCache != nil {
|
||||||
|
if cached, found := s.modelsListCache.Get(cacheKey); found {
|
||||||
|
if models, ok := cached.([]string); ok {
|
||||||
|
modelsListCacheHitTotal.Add(1)
|
||||||
|
return cloneStringSlice(models)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
modelsListCacheMissTotal.Add(1)
|
||||||
|
|
||||||
var accounts []Account
|
var accounts []Account
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
@@ -6185,6 +6478,10 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
|
|||||||
|
|
||||||
// If no account has model_mapping, return nil (use default)
|
// If no account has model_mapping, return nil (use default)
|
||||||
if !hasAnyMapping {
|
if !hasAnyMapping {
|
||||||
|
if s.modelsListCache != nil {
|
||||||
|
s.modelsListCache.Set(cacheKey, []string(nil), s.modelsListCacheTTL)
|
||||||
|
modelsListCacheStoreTotal.Add(1)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -6193,8 +6490,45 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
|
|||||||
for model := range modelSet {
|
for model := range modelSet {
|
||||||
models = append(models, model)
|
models = append(models, model)
|
||||||
}
|
}
|
||||||
|
sort.Strings(models)
|
||||||
|
|
||||||
return models
|
if s.modelsListCache != nil {
|
||||||
|
s.modelsListCache.Set(cacheKey, cloneStringSlice(models), s.modelsListCacheTTL)
|
||||||
|
modelsListCacheStoreTotal.Add(1)
|
||||||
|
}
|
||||||
|
return cloneStringSlice(models)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) InvalidateAvailableModelsCache(groupID *int64, platform string) {
|
||||||
|
if s == nil || s.modelsListCache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedPlatform := strings.TrimSpace(platform)
|
||||||
|
// 完整匹配时精准失效;否则按维度批量失效。
|
||||||
|
if groupID != nil && normalizedPlatform != "" {
|
||||||
|
s.modelsListCache.Delete(modelsListCacheKey(groupID, normalizedPlatform))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
targetGroup := derefGroupID(groupID)
|
||||||
|
for key := range s.modelsListCache.Items() {
|
||||||
|
parts := strings.SplitN(key, "|", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
groupPart, parseErr := strconv.ParseInt(parts[0], 10, 64)
|
||||||
|
if parseErr != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if groupID != nil && groupPart != targetGroup {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if normalizedPlatform != "" && parts[1] != normalizedPlatform {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.modelsListCache.Delete(key)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// reconcileCachedTokens 兼容 Kimi 等上游:
|
// reconcileCachedTokens 兼容 Kimi 等上游:
|
||||||
|
|||||||
@@ -20,6 +20,22 @@ const (
|
|||||||
opsMaxStoredErrorBodyBytes = 20 * 1024
|
opsMaxStoredErrorBodyBytes = 20 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PrepareOpsRequestBodyForQueue 在入队前对请求体执行脱敏与裁剪,返回可直接写入 OpsInsertErrorLogInput 的字段。
|
||||||
|
// 该方法用于避免异步队列持有大块原始请求体,减少错误风暴下的内存放大风险。
|
||||||
|
func PrepareOpsRequestBodyForQueue(raw []byte) (requestBodyJSON *string, truncated bool, requestBodyBytes *int) {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(raw, opsMaxStoredRequestBodyBytes)
|
||||||
|
if sanitized != "" {
|
||||||
|
out := sanitized
|
||||||
|
requestBodyJSON = &out
|
||||||
|
}
|
||||||
|
n := bytesLen
|
||||||
|
requestBodyBytes = &n
|
||||||
|
return requestBodyJSON, truncated, requestBodyBytes
|
||||||
|
}
|
||||||
|
|
||||||
// OpsService provides ingestion and query APIs for the Ops monitoring module.
|
// OpsService provides ingestion and query APIs for the Ops monitoring module.
|
||||||
type OpsService struct {
|
type OpsService struct {
|
||||||
opsRepo OpsRepository
|
opsRepo OpsRepository
|
||||||
@@ -132,12 +148,7 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
|
|||||||
|
|
||||||
// Sanitize + trim request body (errors only).
|
// Sanitize + trim request body (errors only).
|
||||||
if len(rawRequestBody) > 0 {
|
if len(rawRequestBody) > 0 {
|
||||||
sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(rawRequestBody, opsMaxStoredRequestBodyBytes)
|
entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = PrepareOpsRequestBodyForQueue(rawRequestBody)
|
||||||
if sanitized != "" {
|
|
||||||
entry.RequestBodyJSON = &sanitized
|
|
||||||
}
|
|
||||||
entry.RequestBodyTruncated = truncated
|
|
||||||
entry.RequestBodyBytes = &bytesLen
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sanitize + truncate error_body to avoid storing sensitive data.
|
// Sanitize + truncate error_body to avoid storing sensitive data.
|
||||||
|
|||||||
60
backend/internal/service/ops_service_prepare_queue_test.go
Normal file
60
backend/internal/service/ops_service_prepare_queue_test.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPrepareOpsRequestBodyForQueue_EmptyBody(t *testing.T) {
|
||||||
|
requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(nil)
|
||||||
|
require.Nil(t, requestBodyJSON)
|
||||||
|
require.False(t, truncated)
|
||||||
|
require.Nil(t, requestBodyBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareOpsRequestBodyForQueue_InvalidJSON(t *testing.T) {
|
||||||
|
raw := []byte("{invalid-json")
|
||||||
|
requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw)
|
||||||
|
require.Nil(t, requestBodyJSON)
|
||||||
|
require.False(t, truncated)
|
||||||
|
require.NotNil(t, requestBodyBytes)
|
||||||
|
require.Equal(t, len(raw), *requestBodyBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareOpsRequestBodyForQueue_RedactSensitiveFields(t *testing.T) {
|
||||||
|
raw := []byte(`{
|
||||||
|
"model":"claude-3-5-sonnet-20241022",
|
||||||
|
"api_key":"sk-test-123",
|
||||||
|
"headers":{"authorization":"Bearer secret-token"},
|
||||||
|
"messages":[{"role":"user","content":"hello"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw)
|
||||||
|
require.NotNil(t, requestBodyJSON)
|
||||||
|
require.NotNil(t, requestBodyBytes)
|
||||||
|
require.False(t, truncated)
|
||||||
|
require.Equal(t, len(raw), *requestBodyBytes)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(*requestBodyJSON), &body))
|
||||||
|
require.Equal(t, "[REDACTED]", body["api_key"])
|
||||||
|
headers, ok := body["headers"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "[REDACTED]", headers["authorization"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareOpsRequestBodyForQueue_LargeBodyTruncated(t *testing.T) {
|
||||||
|
largeMsg := strings.Repeat("x", opsMaxStoredRequestBodyBytes*2)
|
||||||
|
raw := []byte(`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"` + largeMsg + `"}]}`)
|
||||||
|
|
||||||
|
requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw)
|
||||||
|
require.NotNil(t, requestBodyJSON)
|
||||||
|
require.NotNil(t, requestBodyBytes)
|
||||||
|
require.True(t, truncated)
|
||||||
|
require.Equal(t, len(raw), *requestBodyBytes)
|
||||||
|
require.LessOrEqual(t, len(*requestBodyJSON), opsMaxStoredRequestBodyBytes)
|
||||||
|
require.Contains(t, *requestBodyJSON, "request_body_truncated")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user