Merge remote-tracking branch 'upstream/main'
This commit is contained in:
@@ -576,6 +576,25 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
|
||||
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
|
||||
}
|
||||
|
||||
// IsTLSFingerprintEnabled 检查是否启用 TLS 指纹伪装
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||
// 启用后将模拟 Claude Code (Node.js) 客户端的 TLS 握手特征
|
||||
func (a *Account) IsTLSFingerprintEnabled() bool {
|
||||
// 仅支持 Anthropic OAuth/SetupToken 账号
|
||||
if !a.IsAnthropicOAuthOrSetupToken() {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["enable_tls_fingerprint"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetWindowCostLimit() float64 {
|
||||
|
||||
@@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) {
|
||||
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
|
||||
|
||||
// Gemini 前缀透传
|
||||
{"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true},
|
||||
{"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true},
|
||||
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
|
||||
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
|
||||
|
||||
@@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
||||
expected: "gemini-2.5-flash",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-1.5-pro",
|
||||
requestedModel: "gemini-1.5-pro",
|
||||
name: "Gemini透传 - gemini-2.5-pro",
|
||||
requestedModel: "gemini-2.5-pro",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-1.5-pro",
|
||||
expected: "gemini-2.5-pro",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-future-model",
|
||||
|
||||
@@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) {
|
||||
s.authCacheL1 = cache
|
||||
}
|
||||
|
||||
// StartAuthCacheInvalidationSubscriber starts the Pub/Sub subscriber for L1 cache invalidation.
|
||||
// This should be called after the service is fully initialized.
|
||||
func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) {
|
||||
if s.cache == nil || s.authCacheL1 == nil {
|
||||
return
|
||||
}
|
||||
if err := s.cache.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) {
|
||||
s.authCacheL1.Del(cacheKey)
|
||||
}); err != nil {
|
||||
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
|
||||
println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyService) authCacheKey(key string) string {
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(sum[:])
|
||||
@@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
|
||||
return
|
||||
}
|
||||
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
|
||||
// Publish invalidation message to other instances
|
||||
_ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
|
||||
|
||||
@@ -65,6 +65,10 @@ type APIKeyCache interface {
|
||||
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
|
||||
DeleteAuthCache(ctx context.Context, key string) error
|
||||
|
||||
// Pub/Sub for L1 cache invalidation across instances
|
||||
PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error
|
||||
SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
|
||||
|
||||
@@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
|
||||
@@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
||||
// 预期行为:
|
||||
// - GetKeyAndOwnerID 返回所有者 ID 为 1
|
||||
|
||||
@@ -44,6 +44,13 @@ func (s *GatewayService) debugModelRoutingEnabled() bool {
|
||||
return v == "1" || v == "true" || v == "yes" || v == "on"
|
||||
}
|
||||
|
||||
// debugLog prints log only in non-release mode.
|
||||
func debugLog(format string, v ...any) {
|
||||
if gin.Mode() != gin.ReleaseMode {
|
||||
log.Printf(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func shortSessionHash(sessionHash string) string {
|
||||
if sessionHash == "" {
|
||||
return ""
|
||||
@@ -410,11 +417,17 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
||||
// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
|
||||
// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
|
||||
// 调试日志:记录调度入口参数
|
||||
excludedIDsList := make([]int64, 0, len(excludedIDs))
|
||||
for id := range excludedIDs {
|
||||
excludedIDsList = append(excludedIDsList, id)
|
||||
}
|
||||
debugLog("[AccountScheduling] Starting account selection: groupID=%v model=%s session=%s excludedIDs=%v",
|
||||
derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), excludedIDsList)
|
||||
|
||||
cfg := s.schedulingConfig()
|
||||
// 提取会话 UUID(用于会话数量限制)
|
||||
sessionUUID := extractSessionUUID(metadataUserID)
|
||||
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
@@ -440,41 +453,63 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// 复制排除列表,用于会话限制拒绝时的重试
|
||||
localExcluded := make(map[int64]struct{})
|
||||
for k, v := range excludedIDs {
|
||||
localExcluded[k] = v
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
|
||||
for {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符)
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位
|
||||
localExcluded[account.ID] = struct{}{} // 排除此账号
|
||||
continue // 重新选择
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 对于等待计划的情况,也需要先检查会话限制
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
localExcluded[account.ID] = struct{}{}
|
||||
continue
|
||||
}
|
||||
|
||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
|
||||
@@ -590,7 +625,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
@@ -608,15 +643,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: stickyAccountID,
|
||||
MaxConcurrency: stickyAccount.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
// 会话限制已满,继续到负载感知选择
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: stickyAccountID,
|
||||
MaxConcurrency: stickyAccount.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||
}
|
||||
@@ -677,7 +717,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
@@ -695,20 +735,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
|
||||
acc := routingAvailable[0].account
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID)
|
||||
// 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的)
|
||||
// 遍历找到第一个满足会话限制的账号
|
||||
for _, item := range routingAvailable {
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
|
||||
continue // 会话限制已满,尝试下一个
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: item.account.ID,
|
||||
MaxConcurrency: item.account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
|
||||
}
|
||||
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
|
||||
log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
|
||||
@@ -728,7 +774,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
@@ -742,15 +788,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -799,7 +850,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
|
||||
return result, nil
|
||||
}
|
||||
} else {
|
||||
@@ -849,7 +900,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
@@ -869,6 +920,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
// ============ Layer 3: 兜底排队 ============
|
||||
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
|
||||
for _, acc := range candidates {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
|
||||
continue // 会话限制已满,尝试下一个账号
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
@@ -882,7 +937,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) {
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||
|
||||
@@ -890,7 +945,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, acc, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
@@ -1047,7 +1102,16 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
|
||||
|
||||
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err == nil {
|
||||
debugLog("[AccountScheduling] listSchedulableAccounts (snapshot): groupID=%v platform=%s useMixed=%v count=%d",
|
||||
derefGroupID(groupID), platform, useMixed, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
debugLog("[AccountScheduling] - Account ID=%d Name=%s Platform=%s Type=%s Status=%s TLSFingerprint=%v",
|
||||
acc.ID, acc.Name, acc.Platform, acc.Type, acc.Status, acc.IsTLSFingerprintEnabled())
|
||||
}
|
||||
}
|
||||
return accounts, useMixed, err
|
||||
}
|
||||
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
|
||||
if useMixed {
|
||||
@@ -1060,6 +1124,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
if err != nil {
|
||||
debugLog("[AccountScheduling] listSchedulableAccounts FAILED: groupID=%v platform=%s err=%v", derefGroupID(groupID), platform, err)
|
||||
return nil, useMixed, err
|
||||
}
|
||||
filtered := make([]Account, 0, len(accounts))
|
||||
@@ -1069,6 +1134,12 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
}
|
||||
filtered = append(filtered, acc)
|
||||
}
|
||||
debugLog("[AccountScheduling] listSchedulableAccounts (mixed): groupID=%v platform=%s rawCount=%d filteredCount=%d",
|
||||
derefGroupID(groupID), platform, len(accounts), len(filtered))
|
||||
for _, acc := range filtered {
|
||||
debugLog("[AccountScheduling] - Account ID=%d Name=%s Platform=%s Type=%s Status=%s TLSFingerprint=%v",
|
||||
acc.ID, acc.Name, acc.Platform, acc.Type, acc.Status, acc.IsTLSFingerprintEnabled())
|
||||
}
|
||||
return filtered, useMixed, nil
|
||||
}
|
||||
|
||||
@@ -1083,8 +1154,15 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
if err != nil {
|
||||
debugLog("[AccountScheduling] listSchedulableAccounts FAILED: groupID=%v platform=%s err=%v", derefGroupID(groupID), platform, err)
|
||||
return nil, useMixed, err
|
||||
}
|
||||
debugLog("[AccountScheduling] listSchedulableAccounts (single): groupID=%v platform=%s count=%d",
|
||||
derefGroupID(groupID), platform, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
debugLog("[AccountScheduling] - Account ID=%d Name=%s Platform=%s Type=%s Status=%s TLSFingerprint=%v",
|
||||
acc.ID, acc.Name, acc.Platform, acc.Type, acc.Status, acc.IsTLSFingerprintEnabled())
|
||||
}
|
||||
return accounts, useMixed, nil
|
||||
}
|
||||
|
||||
@@ -1188,15 +1266,16 @@ checkSchedulability:
|
||||
|
||||
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
// sessionID: 会话标识符(使用粘性会话的 hash)
|
||||
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
|
||||
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool {
|
||||
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool {
|
||||
// 只检查 Anthropic OAuth/SetupToken 账号
|
||||
if !account.IsAnthropicOAuthOrSetupToken() {
|
||||
return true
|
||||
}
|
||||
|
||||
maxSessions := account.GetMaxSessions()
|
||||
if maxSessions <= 0 || sessionUUID == "" {
|
||||
if maxSessions <= 0 || sessionID == "" {
|
||||
return true // 未启用会话限制或无会话ID
|
||||
}
|
||||
|
||||
@@ -1206,7 +1285,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
|
||||
|
||||
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
|
||||
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout)
|
||||
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout)
|
||||
if err != nil {
|
||||
// 失败开放:缓存错误时允许通过
|
||||
return true
|
||||
@@ -1214,18 +1293,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
|
||||
return allowed
|
||||
}
|
||||
|
||||
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
|
||||
// 格式: user_{64位hex}_account__session_{uuid}
|
||||
func extractSessionUUID(metadataUserID string) string {
|
||||
if metadataUserID == "" {
|
||||
return ""
|
||||
}
|
||||
if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
|
||||
return match[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
@@ -2088,6 +2155,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 调试日志:记录即将转发的账号信息
|
||||
log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
|
||||
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
|
||||
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
retryStart := time.Now()
|
||||
@@ -2102,7 +2173,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
@@ -2176,7 +2247,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
if retryResp.StatusCode < 400 {
|
||||
log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
|
||||
@@ -2208,7 +2279,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
|
||||
if buildErr2 == nil {
|
||||
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
|
||||
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr2 == nil {
|
||||
resp = retryResp2
|
||||
break
|
||||
@@ -2323,6 +2394,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
// 调试日志:打印重试耗尽后的错误响应
|
||||
log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
|
||||
|
||||
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
@@ -2350,6 +2425,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
// 调试日志:打印上游错误响应
|
||||
log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
@@ -2700,6 +2779,10 @@ func extractUpstreamErrorMessage(body []byte) string {
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
// 调试日志:打印上游错误响应
|
||||
log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
|
||||
@@ -3408,7 +3491,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
@@ -3430,7 +3513,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
resp = retryResp
|
||||
respBody, err = io.ReadAll(resp.Body)
|
||||
|
||||
@@ -599,7 +599,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
|
||||
name: "Gemini平台-有映射配置-只支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}},
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "x"}},
|
||||
},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: false,
|
||||
|
||||
@@ -10,6 +10,7 @@ import "net/http"
|
||||
// - 支持可选代理配置
|
||||
// - 支持账户级连接池隔离
|
||||
// - 实现类负责连接池管理和复用
|
||||
// - 支持可选的 TLS 指纹伪装
|
||||
type HTTPUpstream interface {
|
||||
// Do 执行 HTTP 请求
|
||||
//
|
||||
@@ -27,4 +28,28 @@ type HTTPUpstream interface {
|
||||
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
|
||||
// - 响应体可能已被包装以跟踪请求生命周期
|
||||
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
|
||||
|
||||
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
|
||||
//
|
||||
// 参数:
|
||||
// - req: HTTP 请求对象,由调用方构建
|
||||
// - proxyURL: 代理服务器地址,空字符串表示直连
|
||||
// - accountID: 账户 ID,用于连接池隔离和 TLS 指纹模板选择
|
||||
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
|
||||
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Response: HTTP 响应,调用方必须关闭 Body
|
||||
// - error: 请求错误(网络错误、超时等)
|
||||
//
|
||||
// TLS 指纹说明:
|
||||
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
|
||||
// - TLS 指纹模板根据 accountID % len(profiles) 自动选择
|
||||
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
|
||||
// - 如果 enableTLSFingerprint=false,行为与 Do 方法相同
|
||||
//
|
||||
// 注意:
|
||||
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
|
||||
// - TLS 指纹客户端与普通客户端使用不同的缓存键,互不影响
|
||||
DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error)
|
||||
}
|
||||
|
||||
@@ -394,19 +394,35 @@ func normalizeCodexTools(reqBody map[string]any) bool {
|
||||
}
|
||||
|
||||
modified := false
|
||||
for idx, tool := range tools {
|
||||
validTools := make([]any, 0, len(tools))
|
||||
|
||||
for _, tool := range tools {
|
||||
toolMap, ok := tool.(map[string]any)
|
||||
if !ok {
|
||||
// Keep unknown structure as-is to avoid breaking upstream behavior.
|
||||
validTools = append(validTools, tool)
|
||||
continue
|
||||
}
|
||||
|
||||
toolType, _ := toolMap["type"].(string)
|
||||
if strings.TrimSpace(toolType) != "function" {
|
||||
toolType = strings.TrimSpace(toolType)
|
||||
if toolType != "function" {
|
||||
validTools = append(validTools, toolMap)
|
||||
continue
|
||||
}
|
||||
|
||||
function, ok := toolMap["function"].(map[string]any)
|
||||
if !ok {
|
||||
// OpenAI Responses-style tools use top-level name/parameters.
|
||||
if name, ok := toolMap["name"].(string); ok && strings.TrimSpace(name) != "" {
|
||||
validTools = append(validTools, toolMap)
|
||||
continue
|
||||
}
|
||||
|
||||
// ChatCompletions-style tools use {type:"function", function:{...}}.
|
||||
functionValue, hasFunction := toolMap["function"]
|
||||
function, ok := functionValue.(map[string]any)
|
||||
if !hasFunction || functionValue == nil || !ok || function == nil {
|
||||
// Drop invalid function tools.
|
||||
modified = true
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -435,11 +451,11 @@ func normalizeCodexTools(reqBody map[string]any) bool {
|
||||
}
|
||||
}
|
||||
|
||||
tools[idx] = toolMap
|
||||
validTools = append(validTools, toolMap)
|
||||
}
|
||||
|
||||
if modified {
|
||||
reqBody["tools"] = tools
|
||||
reqBody["tools"] = validTools
|
||||
}
|
||||
|
||||
return modified
|
||||
|
||||
@@ -129,6 +129,37 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
|
||||
require.False(t, hasID)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) {
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"name": "bash",
|
||||
"description": "desc",
|
||||
"parameters": map[string]any{"type": "object"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
|
||||
tools, ok := reqBody["tools"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
first, ok := tools[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "function", first["type"])
|
||||
require.Equal(t, "bash", first["name"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
// 空 input 应保持为空且不触发异常。
|
||||
setupCodexCache(t)
|
||||
|
||||
@@ -133,12 +133,30 @@ func NewOpenAIGatewayService(
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSessionHash generates session hash from header (OpenAI uses session_id header)
|
||||
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
|
||||
sessionID := c.GetHeader("session_id")
|
||||
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
|
||||
//
|
||||
// Priority:
|
||||
// 1. Header: session_id
|
||||
// 2. Header: conversation_id
|
||||
// 3. Body: prompt_cache_key (opencode)
|
||||
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
||||
if sessionID == "" {
|
||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
||||
}
|
||||
if sessionID == "" && reqBody != nil {
|
||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||
sessionID = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
if sessionID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(sessionID))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
@@ -49,6 +49,49 @@ func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
|
||||
// 1) session_id header wins
|
||||
c.Request.Header.Set("session_id", "sess-123")
|
||||
c.Request.Header.Set("conversation_id", "conv-456")
|
||||
h1 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
|
||||
if h1 == "" {
|
||||
t.Fatalf("expected non-empty hash")
|
||||
}
|
||||
|
||||
// 2) conversation_id used when session_id absent
|
||||
c.Request.Header.Del("session_id")
|
||||
h2 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
|
||||
if h2 == "" {
|
||||
t.Fatalf("expected non-empty hash")
|
||||
}
|
||||
if h1 == h2 {
|
||||
t.Fatalf("expected different hashes for different keys")
|
||||
}
|
||||
|
||||
// 3) prompt_cache_key used when both headers absent
|
||||
c.Request.Header.Del("conversation_id")
|
||||
h3 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
|
||||
if h3 == "" {
|
||||
t.Fatalf("expected non-empty hash")
|
||||
}
|
||||
if h2 == h3 {
|
||||
t.Fatalf("expected different hashes for different keys")
|
||||
}
|
||||
|
||||
// 4) empty when no signals
|
||||
h4 := svc.GenerateSessionHash(c, map[string]any{})
|
||||
if h4 != "" {
|
||||
t.Fatalf("expected empty hash when no signals")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
||||
now := time.Now()
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
|
||||
@@ -27,6 +27,11 @@ var codexToolNameMapping = map[string]string{
|
||||
"executeBash": "bash",
|
||||
"exec_bash": "bash",
|
||||
"execBash": "bash",
|
||||
|
||||
// Some clients output generic fetch names.
|
||||
"fetch": "webfetch",
|
||||
"web_fetch": "webfetch",
|
||||
"webFetch": "webfetch",
|
||||
}
|
||||
|
||||
// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化)
|
||||
@@ -208,27 +213,67 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
|
||||
// 根据工具名称应用特定的参数修正规则
|
||||
switch toolName {
|
||||
case "bash":
|
||||
// 移除 workdir 参数(OpenCode 不支持)
|
||||
if _, exists := argsMap["workdir"]; exists {
|
||||
delete(argsMap, "workdir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool")
|
||||
}
|
||||
if _, exists := argsMap["work_dir"]; exists {
|
||||
delete(argsMap, "work_dir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool")
|
||||
// OpenCode bash 支持 workdir;有些来源会输出 work_dir。
|
||||
if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir {
|
||||
if workDir, exists := argsMap["work_dir"]; exists {
|
||||
argsMap["workdir"] = workDir
|
||||
delete(argsMap, "work_dir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool")
|
||||
}
|
||||
} else {
|
||||
if _, exists := argsMap["work_dir"]; exists {
|
||||
delete(argsMap, "work_dir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool")
|
||||
}
|
||||
}
|
||||
|
||||
case "edit":
|
||||
// OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称
|
||||
// 这里可以添加参数名称的映射逻辑
|
||||
if _, exists := argsMap["file_path"]; !exists {
|
||||
if path, exists := argsMap["path"]; exists {
|
||||
argsMap["file_path"] = path
|
||||
// OpenCode edit 参数为 filePath/oldString/newString(camelCase)。
|
||||
if _, exists := argsMap["filePath"]; !exists {
|
||||
if filePath, exists := argsMap["file_path"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "file_path")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool")
|
||||
} else if filePath, exists := argsMap["path"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "path")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
|
||||
log.Printf("[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool")
|
||||
} else if filePath, exists := argsMap["file"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "file")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool")
|
||||
}
|
||||
}
|
||||
|
||||
if _, exists := argsMap["oldString"]; !exists {
|
||||
if oldString, exists := argsMap["old_string"]; exists {
|
||||
argsMap["oldString"] = oldString
|
||||
delete(argsMap, "old_string")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
|
||||
}
|
||||
}
|
||||
|
||||
if _, exists := argsMap["newString"]; !exists {
|
||||
if newString, exists := argsMap["new_string"]; exists {
|
||||
argsMap["newString"] = newString
|
||||
delete(argsMap, "new_string")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
|
||||
}
|
||||
}
|
||||
|
||||
if _, exists := argsMap["replaceAll"]; !exists {
|
||||
if replaceAll, exists := argsMap["replace_all"]; exists {
|
||||
argsMap["replaceAll"] = replaceAll
|
||||
delete(argsMap, "replace_all")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -416,22 +416,23 @@ func TestCorrectToolParameters(t *testing.T) {
|
||||
expected map[string]bool // key: 期待存在的参数, value: true表示应该存在
|
||||
}{
|
||||
{
|
||||
name: "remove workdir from bash tool",
|
||||
name: "rename work_dir to workdir in bash tool",
|
||||
input: `{
|
||||
"tool_calls": [{
|
||||
"function": {
|
||||
"name": "bash",
|
||||
"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
|
||||
"arguments": "{\"command\":\"ls\",\"work_dir\":\"/tmp\"}"
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
expected: map[string]bool{
|
||||
"command": true,
|
||||
"workdir": false,
|
||||
"command": true,
|
||||
"workdir": true,
|
||||
"work_dir": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "rename path to file_path in edit tool",
|
||||
name: "rename snake_case edit params to camelCase",
|
||||
input: `{
|
||||
"tool_calls": [{
|
||||
"function": {
|
||||
@@ -441,10 +442,12 @@ func TestCorrectToolParameters(t *testing.T) {
|
||||
}]
|
||||
}`,
|
||||
expected: map[string]bool{
|
||||
"file_path": true,
|
||||
"filePath": true,
|
||||
"path": false,
|
||||
"old_string": true,
|
||||
"new_string": true,
|
||||
"oldString": true,
|
||||
"old_string": false,
|
||||
"newString": true,
|
||||
"new_string": false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -531,8 +531,8 @@ func (s *PricingService) buildModelLookupCandidates(modelLower string) []string
|
||||
func normalizeModelNameForPricing(model string) string {
|
||||
// Common Gemini/VertexAI forms:
|
||||
// - models/gemini-2.0-flash-exp
|
||||
// - publishers/google/models/gemini-1.5-pro
|
||||
// - projects/.../locations/.../publishers/google/models/gemini-1.5-pro
|
||||
// - publishers/google/models/gemini-2.5-pro
|
||||
// - projects/.../locations/.../publishers/google/models/gemini-2.5-pro
|
||||
model = strings.TrimSpace(model)
|
||||
model = strings.TrimLeft(model, "/")
|
||||
model = strings.TrimPrefix(model, "models/")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
@@ -189,6 +190,8 @@ func ProvideOpsScheduledReportService(
|
||||
|
||||
// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力
|
||||
func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator {
|
||||
// Start Pub/Sub subscriber for L1 cache invalidation across instances
|
||||
apiKeyService.StartAuthCacheInvalidationSubscriber(context.Background())
|
||||
return apiKeyService
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user