Merge remote-tracking branch 'upstream/main'
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled
Security Scan / backend-security (push) Has been cancelled
Security Scan / frontend-security (push) Has been cancelled

This commit is contained in:
huangzhenpc
2026-01-18 23:46:19 +08:00
52 changed files with 2922 additions and 199 deletions

View File

@@ -576,6 +576,25 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
}
// IsTLSFingerprintEnabled reports whether TLS fingerprint spoofing is enabled
// for this account.
//
// Only Anthropic OAuth/SetupToken accounts are eligible. When enabled, the TLS
// handshake is meant to mimic the Claude Code (Node.js) client.
// The flag is read from Extra["enable_tls_fingerprint"]; a missing key or a
// non-bool value counts as disabled.
func (a *Account) IsTLSFingerprintEnabled() bool {
	if !a.IsAnthropicOAuthOrSetupToken() {
		return false
	}
	// Reading from a nil map yields the zero value and a failed type assertion
	// yields false, so a nil Extra, an absent key and a wrong-typed value all
	// fall through to "disabled".
	flag, isBool := a.Extra["enable_tls_fingerprint"].(bool)
	return isBool && flag
}
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func (a *Account) GetWindowCostLimit() float64 {

View File

@@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}

View File

@@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) {
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
// Gemini 前缀透传
{"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true},
{"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true},
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
@@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "gemini-2.5-flash",
},
{
name: "Gemini透传 - gemini-1.5-pro",
requestedModel: "gemini-1.5-pro",
name: "Gemini透传 - gemini-2.5-pro",
requestedModel: "gemini-2.5-pro",
accountMapping: nil,
expected: "gemini-1.5-pro",
expected: "gemini-2.5-pro",
},
{
name: "Gemini透传 - gemini-future-model",

View File

@@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) {
s.authCacheL1 = cache
}
// StartAuthCacheInvalidationSubscriber registers a Pub/Sub handler that evicts
// entries from the L1 auth cache when an invalidation message arrives.
//
// It should be called once the service is fully initialized. It does nothing
// when either the shared cache or the L1 cache is not configured.
func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) {
	if s.cache == nil || s.authCacheL1 == nil {
		return
	}

	// Each invalidation message carries the cache key to drop from L1.
	evict := func(cacheKey string) {
		s.authCacheL1.Del(cacheKey)
	}

	err := s.cache.SubscribeAuthCacheInvalidation(ctx, evict)
	if err == nil {
		return
	}
	// Best effort: warn and carry on. L1 keeps working, only cross-instance
	// invalidation is unavailable.
	println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
}
func (s *APIKeyService) authCacheKey(key string) string {
sum := sha256.Sum256([]byte(key))
return hex.EncodeToString(sum[:])
@@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
return
}
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
// Publish invalidation message to other instances
_ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey)
}
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {

View File

@@ -65,6 +65,10 @@ type APIKeyCache interface {
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
DeleteAuthCache(ctx context.Context, key string) error
// Pub/Sub for L1 cache invalidation across instances
PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error
SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error
}
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力

View File

@@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
return nil
}
// PublishAuthCacheInvalidation is a no-op on this stub; it ignores its
// arguments and always returns nil.
func (*authCacheStub) PublishAuthCacheInvalidation(_ context.Context, _ string) error {
	return nil
}
// SubscribeAuthCacheInvalidation is a no-op on this stub; the handler is never
// invoked and the call always returns nil.
func (*authCacheStub) SubscribeAuthCacheInvalidation(_ context.Context, _ func(cacheKey string)) error {
	return nil
}
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{

View File

@@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error
return nil
}
// PublishAuthCacheInvalidation is a no-op on this stub; it ignores its
// arguments and always returns nil.
func (*apiKeyCacheStub) PublishAuthCacheInvalidation(_ context.Context, _ string) error {
	return nil
}
// SubscribeAuthCacheInvalidation is a no-op on this stub; the handler is never
// invoked and the call always returns nil.
func (*apiKeyCacheStub) SubscribeAuthCacheInvalidation(_ context.Context, _ func(cacheKey string)) error {
	return nil
}
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为:
// - GetKeyAndOwnerID 返回所有者 ID 为 1

View File

@@ -44,6 +44,13 @@ func (s *GatewayService) debugModelRoutingEnabled() bool {
return v == "1" || v == "true" || v == "yes" || v == "on"
}
// debugLog forwards its arguments to log.Printf unless gin is running in
// release mode, in which case the message is dropped.
func debugLog(format string, v ...any) {
	if gin.Mode() == gin.ReleaseMode {
		return
	}
	log.Printf(format, v...)
}
func shortSessionHash(sessionHash string) string {
if sessionHash == "" {
return ""
@@ -410,11 +417,17 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
// 调试日志:记录调度入口参数
excludedIDsList := make([]int64, 0, len(excludedIDs))
for id := range excludedIDs {
excludedIDsList = append(excludedIDsList, id)
}
debugLog("[AccountScheduling] Starting account selection: groupID=%v model=%s session=%s excludedIDs=%v",
derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), excludedIDsList)
cfg := s.schedulingConfig()
// 提取会话 UUID用于会话数量限制
sessionUUID := extractSessionUUID(metadataUserID)
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
@@ -440,41 +453,63 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil {
return nil, err
// 复制排除列表,用于会话限制拒绝时的重试
localExcluded := make(map[int64]struct{})
for k, v := range excludedIDs {
localExcluded[k] = v
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
for {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded)
if err != nil {
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
// 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位
localExcluded[account.ID] = struct{}{} // 排除此账号
continue // 重新选择
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
// 对于等待计划的情况,也需要先检查会话限制
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
localExcluded[account.ID] = struct{}{}
continue
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
@@ -590,7 +625,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) {
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
result.ReleaseFunc() // 释放槽位
// 继续到负载感知选择
} else {
@@ -608,15 +643,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: stickyAccount,
WaitPlan: &AccountWaitPlan{
AccountID: stickyAccountID,
MaxConcurrency: stickyAccount.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
// 会话限制已满,继续到负载感知选择
} else {
return &AccountSelectionResult{
Account: stickyAccount,
WaitPlan: &AccountWaitPlan{
AccountID: stickyAccountID,
MaxConcurrency: stickyAccount.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
}
@@ -677,7 +717,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
@@ -695,20 +735,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
}
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
acc := routingAvailable[0].account
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID)
// 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的)
// 遍历找到第一个满足会话限制的账号
for _, item := range routingAvailable {
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
continue // 会话限制已满,尝试下一个
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
return &AccountSelectionResult{
Account: item.account,
WaitPlan: &AccountWaitPlan{
AccountID: item.account.ID,
MaxConcurrency: item.account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
}
// 路由列表中的账号都不可用（负载率 >= 100），继续到 Layer 2 回退
log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
@@ -728,7 +774,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, account, sessionUUID) {
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
@@ -742,15 +788,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
} else {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
@@ -799,7 +850,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
return result, nil
}
} else {
@@ -849,7 +900,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
@@ -869,6 +920,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// ============ Layer 3: 兜底排队 ============
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
for _, acc := range candidates {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
continue // 会话限制已满,尝试下一个账号
}
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
@@ -882,7 +937,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts")
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) {
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
@@ -890,7 +945,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, acc, sessionUUID) {
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
@@ -1047,7 +1102,16 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err == nil {
debugLog("[AccountScheduling] listSchedulableAccounts (snapshot): groupID=%v platform=%s useMixed=%v count=%d",
derefGroupID(groupID), platform, useMixed, len(accounts))
for _, acc := range accounts {
debugLog("[AccountScheduling] - Account ID=%d Name=%s Platform=%s Type=%s Status=%s TLSFingerprint=%v",
acc.ID, acc.Name, acc.Platform, acc.Type, acc.Status, acc.IsTLSFingerprintEnabled())
}
}
return accounts, useMixed, err
}
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
if useMixed {
@@ -1060,6 +1124,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
debugLog("[AccountScheduling] listSchedulableAccounts FAILED: groupID=%v platform=%s err=%v", derefGroupID(groupID), platform, err)
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
@@ -1069,6 +1134,12 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
}
filtered = append(filtered, acc)
}
debugLog("[AccountScheduling] listSchedulableAccounts (mixed): groupID=%v platform=%s rawCount=%d filteredCount=%d",
derefGroupID(groupID), platform, len(accounts), len(filtered))
for _, acc := range filtered {
debugLog("[AccountScheduling] - Account ID=%d Name=%s Platform=%s Type=%s Status=%s TLSFingerprint=%v",
acc.ID, acc.Name, acc.Platform, acc.Type, acc.Status, acc.IsTLSFingerprintEnabled())
}
return filtered, useMixed, nil
}
@@ -1083,8 +1154,15 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
if err != nil {
debugLog("[AccountScheduling] listSchedulableAccounts FAILED: groupID=%v platform=%s err=%v", derefGroupID(groupID), platform, err)
return nil, useMixed, err
}
debugLog("[AccountScheduling] listSchedulableAccounts (single): groupID=%v platform=%s count=%d",
derefGroupID(groupID), platform, len(accounts))
for _, acc := range accounts {
debugLog("[AccountScheduling] - Account ID=%d Name=%s Platform=%s Type=%s Status=%s TLSFingerprint=%v",
acc.ID, acc.Name, acc.Platform, acc.Type, acc.Status, acc.IsTLSFingerprintEnabled())
}
return accounts, useMixed, nil
}
@@ -1188,15 +1266,16 @@ checkSchedulability:
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符（使用粘性会话的 hash）
// 返回 true 表示允许（在限制内或会话已存在），false 表示拒绝（超出限制且是新会话）
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool {
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool {
// 只检查 Anthropic OAuth/SetupToken 账号
if !account.IsAnthropicOAuthOrSetupToken() {
return true
}
maxSessions := account.GetMaxSessions()
if maxSessions <= 0 || sessionUUID == "" {
if maxSessions <= 0 || sessionID == "" {
return true // 未启用会话限制或无会话ID
}
@@ -1206,7 +1285,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout)
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout)
if err != nil {
// 失败开放:缓存错误时允许通过
return true
@@ -1214,18 +1293,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
return allowed
}
// extractSessionUUID pulls the session UUID out of a metadata.user_id value.
// Expected format: user_{64-digit hex}_account__session_{uuid}
// It returns the first capture group of sessionIDRegex, or "" when the input
// is empty or does not match.
func extractSessionUUID(metadataUserID string) string {
	if metadataUserID == "" {
		return ""
	}
	groups := sessionIDRegex.FindStringSubmatch(metadataUserID)
	if len(groups) < 2 {
		return ""
	}
	return groups[1]
}
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.GetAccount(ctx, accountID)
@@ -2088,6 +2155,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
proxyURL = account.Proxy.URL()
}
// 调试日志:记录即将转发的账号信息
log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
// 重试循环
var resp *http.Response
retryStart := time.Now()
@@ -2102,7 +2173,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 发送请求
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
@@ -2176,7 +2247,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil {
if retryResp.StatusCode < 400 {
log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
@@ -2208,7 +2279,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr2 == nil {
resp = retryResp2
break
@@ -2323,6 +2394,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 调试日志:打印重试耗尽后的错误响应
log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleRetryExhaustedSideEffects(ctx, resp, account)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
@@ -2350,6 +2425,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 调试日志:打印上游错误响应
log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleFailoverSideEffects(ctx, resp, account)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
@@ -2700,6 +2779,10 @@ func extractUpstreamErrorMessage(body []byte) string {
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 调试日志:打印上游错误响应
log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
@@ -3408,7 +3491,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 发送请求
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "")
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
@@ -3430,7 +3513,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil {
resp = retryResp
respBody, err = io.ReadAll(resp.Body)

View File

@@ -599,7 +599,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
name: "Gemini平台-有映射配置-只支持配置的模型",
account: &Account{
Platform: PlatformGemini,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}},
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "x"}},
},
model: "gemini-2.5-flash",
expected: false,

View File

@@ -10,6 +10,7 @@ import "net/http"
// - 支持可选代理配置
// - 支持账户级连接池隔离
// - 实现类负责连接池管理和复用
// - 支持可选的 TLS 指纹伪装
type HTTPUpstream interface {
// Do 执行 HTTP 请求
//
@@ -27,4 +28,28 @@ type HTTPUpstream interface {
// - 调用方必须关闭 resp.Body否则会导致连接泄漏
// - 响应体可能已被包装以跟踪请求生命周期
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
//
// 参数:
// - req: HTTP 请求对象,由调用方构建
// - proxyURL: 代理服务器地址,空字符串表示直连
// - accountID: 账户 ID，用于连接池隔离和 TLS 指纹模板选择
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
//
// 返回:
// - *http.Response: HTTP 响应,调用方必须关闭 Body
// - error: 请求错误(网络错误、超时等)
//
// TLS 指纹说明:
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
// - TLS 指纹模板根据 accountID % len(profiles) 自动选择
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
// - 如果 enableTLSFingerprint=false，行为与 Do 方法相同
//
// 注意:
// - 调用方必须关闭 resp.Body否则会导致连接泄漏
// - TLS 指纹客户端与普通客户端使用不同的缓存键,互不影响
DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error)
}

View File

@@ -394,19 +394,35 @@ func normalizeCodexTools(reqBody map[string]any) bool {
}
modified := false
for idx, tool := range tools {
validTools := make([]any, 0, len(tools))
for _, tool := range tools {
toolMap, ok := tool.(map[string]any)
if !ok {
// Keep unknown structure as-is to avoid breaking upstream behavior.
validTools = append(validTools, tool)
continue
}
toolType, _ := toolMap["type"].(string)
if strings.TrimSpace(toolType) != "function" {
toolType = strings.TrimSpace(toolType)
if toolType != "function" {
validTools = append(validTools, toolMap)
continue
}
function, ok := toolMap["function"].(map[string]any)
if !ok {
// OpenAI Responses-style tools use top-level name/parameters.
if name, ok := toolMap["name"].(string); ok && strings.TrimSpace(name) != "" {
validTools = append(validTools, toolMap)
continue
}
// ChatCompletions-style tools use {type:"function", function:{...}}.
functionValue, hasFunction := toolMap["function"]
function, ok := functionValue.(map[string]any)
if !hasFunction || functionValue == nil || !ok || function == nil {
// Drop invalid function tools.
modified = true
continue
}
@@ -435,11 +451,11 @@ func normalizeCodexTools(reqBody map[string]any) bool {
}
}
tools[idx] = toolMap
validTools = append(validTools, toolMap)
}
if modified {
reqBody["tools"] = tools
reqBody["tools"] = validTools
}
return modified

View File

@@ -129,6 +129,37 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
require.False(t, hasID)
}
// TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools
// checks that applyCodexOAuthTransform keeps a function tool that carries
// top-level name/parameters and drops a function tool whose "function" field
// is nil, leaving exactly one tool in the request body.
func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) {
	setupCodexCache(t)

	// Tool with top-level name/parameters: expected to survive.
	responsesTool := map[string]any{
		"type":        "function",
		"name":        "bash",
		"description": "desc",
		"parameters":  map[string]any{"type": "object"},
	}
	// Function tool with a nil "function" payload: expected to be removed.
	invalidTool := map[string]any{
		"type":     "function",
		"function": nil,
	}
	reqBody := map[string]any{
		"model": "gpt-5.1",
		"tools": []any{responsesTool, invalidTool},
	}

	applyCodexOAuthTransform(reqBody)

	tools, ok := reqBody["tools"].([]any)
	require.True(t, ok)
	require.Len(t, tools, 1)

	kept, ok := tools[0].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "function", kept["type"])
	require.Equal(t, "bash", kept["name"])
}
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
// 空 input 应保持为空且不触发异常。
setupCodexCache(t)

View File

@@ -133,12 +133,30 @@ func NewOpenAIGatewayService(
}
}
// GenerateSessionHash generates session hash from header (OpenAI uses session_id header)
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
sessionID := c.GetHeader("session_id")
// GenerateSessionHash derives the sticky-session hash for an OpenAI request.
//
// The session identifier is the first non-blank value found, in this order:
//  1. Header: session_id
//  2. Header: conversation_id
//  3. Body: prompt_cache_key (opencode)
//
// The identifier is whitespace-trimmed and returned as a hex-encoded SHA-256
// digest. An empty string is returned when c is nil or no identifier is found.
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string {
	if c == nil {
		return ""
	}

	sessionID := ""
	// Headers take precedence over the body; stop at the first non-blank one.
	for _, header := range []string{"session_id", "conversation_id"} {
		if sessionID = strings.TrimSpace(c.GetHeader(header)); sessionID != "" {
			break
		}
	}

	if sessionID == "" {
		// Looking up a key in a nil map is safe, so reqBody needs no nil check.
		if key, isString := reqBody["prompt_cache_key"].(string); isString {
			sessionID = strings.TrimSpace(key)
		}
	}

	if sessionID == "" {
		return ""
	}

	digest := sha256.Sum256([]byte(sessionID))
	return hex.EncodeToString(digest[:])
}

View File

@@ -49,6 +49,49 @@ func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts
return out, nil
}
// TestOpenAIGatewayService_GenerateSessionHash_Priority walks through the
// identifier sources used by GenerateSessionHash (session_id header, then
// conversation_id header, then prompt_cache_key in the body) by removing them
// one at a time and checking that the hash changes, and that it is empty once
// no source remains.
func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)

	svc := &OpenAIGatewayService{}
	headers := ginCtx.Request.Header

	// hashWithCacheKey hashes the current request with a fresh body that
	// carries prompt_cache_key, failing the test if the hash is empty.
	hashWithCacheKey := func() string {
		got := svc.GenerateSessionHash(ginCtx, map[string]any{"prompt_cache_key": "ses_aaa"})
		if got == "" {
			t.Fatalf("expected non-empty hash")
		}
		return got
	}
	// requireDistinct fails the test when two consecutive stages hash alike.
	requireDistinct := func(prev, next string) {
		if prev == next {
			t.Fatalf("expected different hashes for different keys")
		}
	}

	// 1) session_id header wins
	headers.Set("session_id", "sess-123")
	headers.Set("conversation_id", "conv-456")
	fromSessionID := hashWithCacheKey()

	// 2) conversation_id used when session_id absent
	headers.Del("session_id")
	fromConversationID := hashWithCacheKey()
	requireDistinct(fromSessionID, fromConversationID)

	// 3) prompt_cache_key used when both headers absent
	headers.Del("conversation_id")
	fromCacheKey := hashWithCacheKey()
	requireDistinct(fromConversationID, fromCacheKey)

	// 4) empty when no signals
	if none := svc.GenerateSessionHash(ginCtx, map[string]any{}); none != "" {
		t.Fatalf("expected empty hash when no signals")
	}
}
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)

View File

@@ -27,6 +27,11 @@ var codexToolNameMapping = map[string]string{
"executeBash": "bash",
"exec_bash": "bash",
"execBash": "bash",
// Some clients output generic fetch names.
"fetch": "webfetch",
"web_fetch": "webfetch",
"webFetch": "webfetch",
}
// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化)
@@ -208,27 +213,67 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
// 根据工具名称应用特定的参数修正规则
switch toolName {
case "bash":
// 移除 workdir 参数OpenCode 不支持)
if _, exists := argsMap["workdir"]; exists {
delete(argsMap, "workdir")
corrected = true
log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool")
}
if _, exists := argsMap["work_dir"]; exists {
delete(argsMap, "work_dir")
corrected = true
log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool")
// OpenCode bash 支持 workdir有些来源会输出 work_dir。
if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir {
if workDir, exists := argsMap["work_dir"]; exists {
argsMap["workdir"] = workDir
delete(argsMap, "work_dir")
corrected = true
log.Printf("[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool")
}
} else {
if _, exists := argsMap["work_dir"]; exists {
delete(argsMap, "work_dir")
corrected = true
log.Printf("[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool")
}
}
case "edit":
// OpenCode edit 使用 old_string/new_stringCodex 可能使用其他名称
// 这里可以添加参数名称的映射逻辑
if _, exists := argsMap["file_path"]; !exists {
if path, exists := argsMap["path"]; exists {
argsMap["file_path"] = path
// OpenCode edit 参数为 filePath/oldString/newStringcamelCase
if _, exists := argsMap["filePath"]; !exists {
if filePath, exists := argsMap["file_path"]; exists {
argsMap["filePath"] = filePath
delete(argsMap, "file_path")
corrected = true
log.Printf("[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool")
} else if filePath, exists := argsMap["path"]; exists {
argsMap["filePath"] = filePath
delete(argsMap, "path")
corrected = true
log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
log.Printf("[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool")
} else if filePath, exists := argsMap["file"]; exists {
argsMap["filePath"] = filePath
delete(argsMap, "file")
corrected = true
log.Printf("[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool")
}
}
if _, exists := argsMap["oldString"]; !exists {
if oldString, exists := argsMap["old_string"]; exists {
argsMap["oldString"] = oldString
delete(argsMap, "old_string")
corrected = true
log.Printf("[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
}
}
if _, exists := argsMap["newString"]; !exists {
if newString, exists := argsMap["new_string"]; exists {
argsMap["newString"] = newString
delete(argsMap, "new_string")
corrected = true
log.Printf("[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
}
}
if _, exists := argsMap["replaceAll"]; !exists {
if replaceAll, exists := argsMap["replace_all"]; exists {
argsMap["replaceAll"] = replaceAll
delete(argsMap, "replace_all")
corrected = true
log.Printf("[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
}
}
}

View File

@@ -416,22 +416,23 @@ func TestCorrectToolParameters(t *testing.T) {
expected map[string]bool // key: 期待存在的参数, value: true表示应该存在
}{
{
name: "remove workdir from bash tool",
name: "rename work_dir to workdir in bash tool",
input: `{
"tool_calls": [{
"function": {
"name": "bash",
"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
"arguments": "{\"command\":\"ls\",\"work_dir\":\"/tmp\"}"
}
}]
}`,
expected: map[string]bool{
"command": true,
"workdir": false,
"command": true,
"workdir": true,
"work_dir": false,
},
},
{
name: "rename path to file_path in edit tool",
name: "rename snake_case edit params to camelCase",
input: `{
"tool_calls": [{
"function": {
@@ -441,10 +442,12 @@ func TestCorrectToolParameters(t *testing.T) {
}]
}`,
expected: map[string]bool{
"file_path": true,
"filePath": true,
"path": false,
"old_string": true,
"new_string": true,
"oldString": true,
"old_string": false,
"newString": true,
"new_string": false,
},
},
}

View File

@@ -531,8 +531,8 @@ func (s *PricingService) buildModelLookupCandidates(modelLower string) []string
func normalizeModelNameForPricing(model string) string {
// Common Gemini/VertexAI forms:
// - models/gemini-2.0-flash-exp
// - publishers/google/models/gemini-1.5-pro
// - projects/.../locations/.../publishers/google/models/gemini-1.5-pro
// - publishers/google/models/gemini-2.5-pro
// - projects/.../locations/.../publishers/google/models/gemini-2.5-pro
model = strings.TrimSpace(model)
model = strings.TrimLeft(model, "/")
model = strings.TrimPrefix(model, "models/")

View File

@@ -1,6 +1,7 @@
package service
import (
"context"
"database/sql"
"time"
@@ -189,6 +190,8 @@ func ProvideOpsScheduledReportService(
// ProvideAPIKeyAuthCacheInvalidator exposes the API key service as the
// provider of API key auth-cache invalidation.
//
// As a side effect it starts the service's auth-cache invalidation subscriber
// (Pub/Sub, used for L1 cache invalidation across instances) before returning.
func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator {
	// A background context is never cancelled, so nothing in this function
	// stops the subscriber once it has been started.
	subscriberCtx := context.Background()
	apiKeyService.StartAuthCacheInvalidationSubscriber(subscriberCtx)

	return apiKeyService
}