Merge upstream/main: sync latest updates
- feat(gateway): aggregate all text chunks in non-streaming Gemini responses
- feat(gateway): add SUGGESTION MODE request interception
- feat(oauth): support Anthropic Team accounts with sk authorization
- fix(oauth): update Anthropic OAuth parameters to sync with latest client
- feat: add PromoCodeEnabled setting (default: true)
- resolved conflict: keep TianShuAPI site name while adding PromoCode feature

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCredentialAsInt64 解析凭证中的 int64 字段
|
||||
// 用于读取 _token_version 等内部字段
|
||||
func (a *Account) GetCredentialAsInt64(key string) int64 {
|
||||
if a == nil || a.Credentials == nil {
|
||||
return 0
|
||||
}
|
||||
val, ok := a.Credentials[key]
|
||||
if !ok || val == nil {
|
||||
return 0
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case int64:
|
||||
return v
|
||||
case float64:
|
||||
return int64(v)
|
||||
case int:
|
||||
return int64(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return i
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (a *Account) IsTempUnschedulableEnabled() bool {
|
||||
if a.Credentials == nil {
|
||||
return false
|
||||
|
||||
@@ -1305,6 +1305,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 清理 Schema
|
||||
if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil {
|
||||
injectedBody = cleanedBody
|
||||
log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
|
||||
} else {
|
||||
log.Printf("[Antigravity] Failed to clean schema: %v", err)
|
||||
}
|
||||
|
||||
// 包装请求
|
||||
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
|
||||
if err != nil {
|
||||
@@ -1705,6 +1713,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
||||
if u := extractGeminiUsage(parsed); u != nil {
|
||||
usage = u
|
||||
}
|
||||
// Check for MALFORMED_FUNCTION_CALL
|
||||
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if cand, ok := candidates[0].(map[string]any); ok {
|
||||
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
|
||||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
|
||||
if content, ok := cand["content"]; ok {
|
||||
if b, err := json.Marshal(content); err == nil {
|
||||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if firstTokenMs == nil {
|
||||
@@ -1854,6 +1875,20 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
|
||||
usage = u
|
||||
}
|
||||
|
||||
// Check for MALFORMED_FUNCTION_CALL
|
||||
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if cand, ok := candidates[0].(map[string]any); ok {
|
||||
if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
|
||||
log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
|
||||
if content, ok := cand["content"]; ok {
|
||||
if b, err := json.Marshal(content); err == nil {
|
||||
log.Printf("[Antigravity] Malformed content: %s", string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保留最后一个有 parts 的响应
|
||||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||||
lastWithParts = parsed
|
||||
@@ -1950,6 +1985,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi
|
||||
return result, existingParts, setParts
|
||||
}
|
||||
|
||||
// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
|
||||
// 这个函数会合并所有类型的 parts:text、thinking、functionCall、inlineData 等
|
||||
// 保持原始顺序,只合并连续的普通 text parts
|
||||
func mergeCollectedPartsToResponse(response map[string]any, collectedParts []map[string]any) map[string]any {
|
||||
if len(collectedParts) == 0 {
|
||||
return response
|
||||
}
|
||||
|
||||
result, _, setParts := getOrCreateGeminiParts(response)
|
||||
|
||||
// 合并策略:
|
||||
// 1. 保持原始顺序
|
||||
// 2. 连续的普通 text parts 合并为一个
|
||||
// 3. thinking、functionCall、inlineData 等保持原样
|
||||
var mergedParts []any
|
||||
var textBuffer strings.Builder
|
||||
|
||||
flushTextBuffer := func() {
|
||||
if textBuffer.Len() > 0 {
|
||||
mergedParts = append(mergedParts, map[string]any{
|
||||
"text": textBuffer.String(),
|
||||
})
|
||||
textBuffer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
for _, part := range collectedParts {
|
||||
// 检查是否是普通 text part
|
||||
if text, ok := part["text"].(string); ok {
|
||||
// 检查是否有 thought 标记
|
||||
if thought, _ := part["thought"].(bool); thought {
|
||||
// thinking part,先刷新 text buffer,然后保留原样
|
||||
flushTextBuffer()
|
||||
mergedParts = append(mergedParts, part)
|
||||
} else {
|
||||
// 普通 text,累积到 buffer
|
||||
_, _ = textBuffer.WriteString(text)
|
||||
}
|
||||
} else {
|
||||
// 非 text part(functionCall、inlineData 等),先刷新 text buffer,然后保留原样
|
||||
flushTextBuffer()
|
||||
mergedParts = append(mergedParts, part)
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新剩余的 text
|
||||
flushTextBuffer()
|
||||
|
||||
setParts(mergedParts)
|
||||
return result
|
||||
}
|
||||
|
||||
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
|
||||
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
|
||||
if len(imageParts) == 0 {
|
||||
@@ -2133,6 +2220,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
var firstTokenMs *int
|
||||
var last map[string]any
|
||||
var lastWithParts map[string]any
|
||||
var collectedParts []map[string]any // 收集所有 parts(包括 text、thinking、functionCall、inlineData 等)
|
||||
|
||||
type scanEvent struct {
|
||||
line string
|
||||
@@ -2227,9 +2315,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
|
||||
|
||||
last = parsed
|
||||
|
||||
// 保留最后一个有 parts 的响应
|
||||
// 保留最后一个有 parts 的响应,并收集所有 parts
|
||||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||||
lastWithParts = parsed
|
||||
|
||||
// 收集所有 parts(text、thinking、functionCall、inlineData 等)
|
||||
collectedParts = append(collectedParts, parts...)
|
||||
}
|
||||
|
||||
case <-intervalCh:
|
||||
@@ -2252,6 +2343,11 @@ returnResponse:
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
|
||||
}
|
||||
|
||||
// 将收集的所有 parts 合并到最终响应中
|
||||
if len(collectedParts) > 0 {
|
||||
finalResponse = mergeCollectedPartsToResponse(finalResponse, collectedParts)
|
||||
}
|
||||
|
||||
// 序列化为 JSON(Gemini 格式)
|
||||
geminiBody, err := json.Marshal(finalResponse)
|
||||
if err != nil {
|
||||
@@ -2459,3 +2555,55 @@ func isImageGenerationModel(model string) bool {
|
||||
modelLower == "gemini-2.5-flash-image-preview" ||
|
||||
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
|
||||
}
|
||||
|
||||
// cleanGeminiRequest 清理 Gemini 请求体中的 Schema
|
||||
func cleanGeminiRequest(body []byte) ([]byte, error) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modified := false
|
||||
|
||||
// 1. 清理 Tools
|
||||
if tools, ok := payload["tools"].([]any); ok && len(tools) > 0 {
|
||||
for _, t := range tools {
|
||||
toolMap, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// function_declarations (snake_case) or functionDeclarations (camelCase)
|
||||
var funcs []any
|
||||
if f, ok := toolMap["functionDeclarations"].([]any); ok {
|
||||
funcs = f
|
||||
} else if f, ok := toolMap["function_declarations"].([]any); ok {
|
||||
funcs = f
|
||||
}
|
||||
|
||||
if len(funcs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, f := range funcs {
|
||||
funcMap, ok := f.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if params, ok := funcMap["parameters"].(map[string]any); ok {
|
||||
antigravity.DeepCleanUndefined(params)
|
||||
cleaned := antigravity.CleanJSONSchema(params)
|
||||
funcMap["parameters"] = cleaned
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
@@ -94,14 +94,14 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
||||
|
||||
var handleErrorCalled bool
|
||||
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
prefix: "[test]",
|
||||
ctx: context.Background(),
|
||||
account: account,
|
||||
proxyURL: "",
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
quotaScope: AntigravityQuotaScopeClaude,
|
||||
prefix: "[test]",
|
||||
ctx: context.Background(),
|
||||
account: account,
|
||||
proxyURL: "",
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
quotaScope: AntigravityQuotaScopeClaude,
|
||||
httpUpstream: upstream,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||
handleErrorCalled = true
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -101,21 +102,32 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > antigravityTokenCacheSkew:
|
||||
ttl = until - antigravityTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
// 版本过时,使用 DB 中的最新 token
|
||||
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
// 不写入缓存,让下次请求重新处理
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > antigravityTokenCacheSkew:
|
||||
ttl = until - antigravityTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
|
||||
@@ -153,8 +153,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 应用优惠码(如果提供)
|
||||
if promoCode != "" && s.promoService != nil {
|
||||
// 应用优惠码(如果提供且功能已启用)
|
||||
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
|
||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||
// 优惠码应用失败不影响注册,只记录日志
|
||||
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
|
||||
|
||||
@@ -181,26 +181,37 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > claudeTokenCacheSkew:
|
||||
ttl = until - claudeTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
// 版本过时,使用 DB 中的最新 token
|
||||
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
// 不写入缓存,让下次请求重新处理
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > claudeTokenCacheSkew:
|
||||
ttl = until - claudeTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -71,6 +71,7 @@ const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -99,11 +99,24 @@ var allowedHeaders = map[string]bool{
|
||||
"content-type": true,
|
||||
}
|
||||
|
||||
// GatewayCache defines cache operations for gateway service
|
||||
// GatewayCache 定义网关服务的缓存操作接口。
|
||||
// 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。
|
||||
//
|
||||
// GatewayCache defines cache operations for gateway service.
|
||||
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
|
||||
type GatewayCache interface {
|
||||
// GetSessionAccountID returns the account ID bound to the sticky session
|
||||
// identified by groupID and sessionHash.
|
||||
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
|
||||
// SetSessionAccountID binds the sticky session identified by groupID and
|
||||
// sessionHash to accountID, with the given ttl.
|
||||
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
|
||||
// RefreshSessionTTL refreshes the expiration time of the sticky session
|
||||
// identified by groupID and sessionHash using ttl.
|
||||
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
|
||||
// DeleteSessionAccountID removes the sticky session binding; it is used to
|
||||
// proactively clean up when the bound account becomes unavailable.
|
||||
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
|
||||
}
|
||||
|
||||
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
||||
@@ -114,6 +127,28 @@ func derefGroupID(groupID *int64) int64 {
|
||||
return *groupID
|
||||
}
|
||||
|
||||
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
|
||||
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
|
||||
// 这确保后续请求不会继续使用不可用的账号。
|
||||
//
|
||||
// shouldClearStickySession checks if an account is in an unschedulable state
|
||||
// and the sticky session binding should be cleared.
|
||||
// Returns true when account status is error/disabled, schedulable is false,
|
||||
// or within temporary unschedulable period.
|
||||
// This ensures subsequent requests won't continue using unavailable accounts.
|
||||
func shouldClearStickySession(account *Account) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
|
||||
return true
|
||||
}
|
||||
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type AccountWaitPlan struct {
|
||||
AccountID int64
|
||||
MaxConcurrency int
|
||||
@@ -658,6 +693,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||
}
|
||||
} else {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -764,41 +801,52 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, ok := accountByID[accountID]
|
||||
if ok && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
account.IsSchedulableForModel(requestedModel) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
if ok {
|
||||
// 检查账户是否需要清理粘性会话绑定
|
||||
// Check if the account needs sticky session cleanup
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
account.IsSchedulableForModel(requestedModel) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
// Session count limit check
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
// Session count limit check (wait plan also requires session quota)
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
// Session limit full, continue to Layer 2
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1418,14 +1466,20 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1515,11 +1569,17 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1619,15 +1679,21 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1718,12 +1784,18 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,70 +82,23 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||
var platform string
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
var group *Group
|
||||
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
|
||||
group = ctxGroup
|
||||
} else {
|
||||
var err error
|
||||
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
}
|
||||
platform = group.Platform
|
||||
} else {
|
||||
// 无分组时只使用原生 gemini 平台
|
||||
platform = PlatformGemini
|
||||
// 1. 确定目标平台和调度模式
|
||||
// Determine target platform and scheduling mode
|
||||
platform, useMixedScheduling, hasForcePlatform, err := s.resolvePlatformAndSchedulingMode(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
// 注意:强制平台模式不走混合调度
|
||||
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||||
|
||||
cacheKey := "gemini:" + sessionHash
|
||||
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
valid := false
|
||||
if account.Platform == platform {
|
||||
valid = true
|
||||
} else if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
|
||||
valid = true
|
||||
}
|
||||
if valid {
|
||||
usable := true
|
||||
if s.rateLimitService != nil && requestedModel != "" {
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
||||
if err != nil {
|
||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||||
}
|
||||
if !ok {
|
||||
usable = false
|
||||
}
|
||||
}
|
||||
if usable {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 2. 尝试粘性会话命中
|
||||
// Try sticky session hit
|
||||
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs, platform, useMixedScheduling); account != nil {
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||||
// 3. 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||||
// Query schedulable accounts (force platform mode: try group first, fallback to all)
|
||||
accounts, err := s.listSchedulableAccountsOnce(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -158,56 +111,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
}
|
||||
}
|
||||
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// 混合调度模式下:原生平台直接通过,antigravity 需要启用 mixed_scheduling
|
||||
// 非混合调度模式(antigravity 分组):不需要过滤
|
||||
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
}
|
||||
if !acc.IsSchedulableForModel(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if s.rateLimitService != nil && requestedModel != "" {
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
|
||||
if err != nil {
|
||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
// Prefer OAuth accounts when both are unused (more compatible for Code Assist flows).
|
||||
if acc.Type == AccountTypeOAuth && selected.Type != AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
default:
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 4. 按优先级 + LRU 选择最佳账号
|
||||
// Select best account by priority + LRU
|
||||
selected := s.selectBestGeminiAccount(ctx, accounts, requestedModel, excludedIDs, platform, useMixedScheduling)
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
@@ -216,6 +122,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
return nil, errors.New("no available Gemini accounts")
|
||||
}
|
||||
|
||||
// 5. 设置粘性会话绑定
|
||||
// Set sticky session binding
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
|
||||
}
|
||||
@@ -223,6 +131,229 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// resolvePlatformAndSchedulingMode 解析目标平台和调度模式。
|
||||
// 返回:平台名称、是否使用混合调度、是否强制平台、错误。
|
||||
//
|
||||
// resolvePlatformAndSchedulingMode resolves target platform and scheduling mode.
|
||||
// Returns: platform name, whether to use mixed scheduling, whether force platform, error.
|
||||
func (s *GeminiMessagesCompatService) resolvePlatformAndSchedulingMode(ctx context.Context, groupID *int64) (platform string, useMixedScheduling bool, hasForcePlatform bool, err error) {
|
||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
return forcePlatform, false, true, nil
|
||||
}
|
||||
|
||||
if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
var group *Group
|
||||
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
|
||||
group = ctxGroup
|
||||
} else {
|
||||
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
|
||||
if err != nil {
|
||||
return "", false, false, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
}
|
||||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||
return group.Platform, group.Platform == PlatformGemini, false, nil
|
||||
}
|
||||
|
||||
// 无分组时只使用原生 gemini 平台
|
||||
return PlatformGemini, true, false, nil
|
||||
}
|
||||
|
||||
// tryStickySessionHit 尝试从粘性会话获取账号。
|
||||
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
|
||||
//
|
||||
// tryStickySessionHit attempts to get account from sticky session.
|
||||
// Returns account if hit and usable; clears session and returns nil if account unavailable.
|
||||
func (s *GeminiMessagesCompatService) tryStickySessionHit(
|
||||
ctx context.Context,
|
||||
groupID *int64,
|
||||
sessionHash, cacheKey, requestedModel string,
|
||||
excludedIDs map[int64]struct{},
|
||||
platform string,
|
||||
useMixedScheduling bool,
|
||||
) *Account {
|
||||
if sessionHash == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
if err != nil || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, excluded := excludedIDs[accountID]; excluded {
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查账号是否需要清理粘性会话
|
||||
// Check if sticky session should be cleared
|
||||
if shouldClearStickySession(account) {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 验证账号是否可用于当前请求
|
||||
// Verify account is usable for current request
|
||||
if !s.isAccountUsableForRequest(ctx, account, requestedModel, platform, useMixedScheduling) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 刷新会话 TTL 并返回账号
|
||||
// Refresh session TTL and return account
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
|
||||
return account
|
||||
}
|
||||
|
||||
// isAccountUsableForRequest 检查账号是否可用于当前请求。
|
||||
// 验证:模型调度、模型支持、平台匹配、速率限制预检。
|
||||
//
|
||||
// isAccountUsableForRequest checks if account is usable for current request.
|
||||
// Validates: model scheduling, model support, platform matching, rate limit precheck.
|
||||
func (s *GeminiMessagesCompatService) isAccountUsableForRequest(
|
||||
ctx context.Context,
|
||||
account *Account,
|
||||
requestedModel, platform string,
|
||||
useMixedScheduling bool,
|
||||
) bool {
|
||||
// 检查模型调度能力
|
||||
// Check model scheduling capability
|
||||
if !account.IsSchedulableForModel(requestedModel) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查模型支持
|
||||
// Check model support
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查平台匹配
|
||||
// Check platform matching
|
||||
if !s.isAccountValidForPlatform(account, platform, useMixedScheduling) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 速率限制预检
|
||||
// Rate limit precheck
|
||||
if !s.passesRateLimitPreCheck(ctx, account, requestedModel) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isAccountValidForPlatform 检查账号是否匹配目标平台。
|
||||
// 原生平台直接匹配;混合调度模式下 antigravity 需要启用 mixed_scheduling。
|
||||
//
|
||||
// isAccountValidForPlatform checks if account matches target platform.
|
||||
// Native platform matches directly; mixed scheduling mode requires antigravity to enable mixed_scheduling.
|
||||
func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account, platform string, useMixedScheduling bool) bool {
|
||||
if account.Platform == platform {
|
||||
return true
|
||||
}
|
||||
if useMixedScheduling && account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// passesRateLimitPreCheck 执行速率限制预检。
|
||||
// 返回 true 表示通过预检或无需预检。
|
||||
//
|
||||
// passesRateLimitPreCheck performs rate limit precheck.
|
||||
// Returns true if passed or precheck not required.
|
||||
func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool {
|
||||
if s.rateLimitService == nil || requestedModel == "" {
|
||||
return true
|
||||
}
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
||||
if err != nil {
|
||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// selectBestGeminiAccount 从候选账号中选择最佳账号(优先级 + LRU + OAuth 优先)。
|
||||
// 返回 nil 表示无可用账号。
|
||||
//
|
||||
// selectBestGeminiAccount selects best account from candidates (priority + LRU + OAuth preferred).
|
||||
// Returns nil if no available account.
|
||||
func (s *GeminiMessagesCompatService) selectBestGeminiAccount(
|
||||
ctx context.Context,
|
||||
accounts []Account,
|
||||
requestedModel string,
|
||||
excludedIDs map[int64]struct{},
|
||||
platform string,
|
||||
useMixedScheduling bool,
|
||||
) *Account {
|
||||
var selected *Account
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
|
||||
// 跳过被排除的账号
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查账号是否可用于当前请求
|
||||
if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 选择最佳账号
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
|
||||
if s.isBetterGeminiAccount(acc, selected) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
// isBetterGeminiAccount 判断 candidate 是否比 current 更优。
|
||||
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。
|
||||
//
|
||||
// isBetterGeminiAccount checks if candidate is better than current.
|
||||
// Rules: higher priority (lower value) wins; same priority: never used (OAuth > non-OAuth) > least recently used.
|
||||
func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *Account) bool {
|
||||
// 优先级更高(数值更小)
|
||||
if candidate.Priority < current.Priority {
|
||||
return true
|
||||
}
|
||||
if candidate.Priority > current.Priority {
|
||||
return false
|
||||
}
|
||||
|
||||
// 同优先级,比较最后使用时间
|
||||
switch {
|
||||
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
|
||||
// candidate 从未使用,优先
|
||||
return true
|
||||
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
|
||||
// current 从未使用,保持
|
||||
return false
|
||||
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
|
||||
// 都未使用,优先选择 OAuth 账号(更兼容 Code Assist 流程)
|
||||
return candidate.Type == AccountTypeOAuth && current.Type != AccountTypeOAuth
|
||||
default:
|
||||
// 都使用过,选择最久未使用的
|
||||
return candidate.LastUsedAt.Before(*current.LastUsedAt)
|
||||
}
|
||||
}
|
||||
|
||||
// isModelSupportedByAccount 根据账户平台检查模型支持
|
||||
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
|
||||
if account.Platform == PlatformAntigravity {
|
||||
@@ -1841,6 +1972,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
|
||||
|
||||
var last map[string]any
|
||||
var lastWithParts map[string]any
|
||||
var collectedTextParts []string // Collect all text parts for aggregation
|
||||
usage := &ClaudeUsage{}
|
||||
|
||||
for {
|
||||
@@ -1852,7 +1984,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
|
||||
switch payload {
|
||||
case "", "[DONE]":
|
||||
if payload == "[DONE]" {
|
||||
return pickGeminiCollectResult(last, lastWithParts), usage, nil
|
||||
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
|
||||
}
|
||||
default:
|
||||
var parsed map[string]any
|
||||
@@ -1871,6 +2003,12 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
|
||||
}
|
||||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||||
lastWithParts = parsed
|
||||
// Collect text from each part for aggregation
|
||||
for _, part := range parts {
|
||||
if text, ok := part["text"].(string); ok && text != "" {
|
||||
collectedTextParts = append(collectedTextParts, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1885,7 +2023,7 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
|
||||
}
|
||||
}
|
||||
|
||||
return pickGeminiCollectResult(last, lastWithParts), usage, nil
|
||||
return mergeCollectedTextParts(pickGeminiCollectResult(last, lastWithParts), collectedTextParts), usage, nil
|
||||
}
|
||||
|
||||
func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any) map[string]any {
|
||||
@@ -1898,6 +2036,83 @@ func pickGeminiCollectResult(last map[string]any, lastWithParts map[string]any)
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
// mergeCollectedTextParts returns a copy of response in which the first
// candidate carries the concatenation of textParts as a single text part.
//
// It exists so that a non-streaming reply assembled from a stream contains the
// full text rather than only the text of the last chunk.
//
// Behavior:
//   - With no textParts, response is returned unchanged (same map).
//   - Missing or malformed "candidates", first candidate, or "content" are
//     created on the fly; a created content gets role "model".
//   - The first part that has a "text" key keeps its other keys and gets the
//     merged text. If no part has a "text" key, a new text part is prepended.
//   - Later parts that consist of nothing but a non-empty string "text" are
//     dropped. This assumes textParts already includes the text of every
//     non-empty text part in response (true when response is one of the
//     collected chunks), so keeping them would repeat text that is already in
//     the merged string. Parts with additional keys, and non-text parts, are
//     kept untouched.
//
// The input is never mutated: the top-level map, the candidates slice, the
// first candidate and its content are all copied before being modified.
func mergeCollectedTextParts(response map[string]any, textParts []string) map[string]any {
	if len(textParts) == 0 {
		return response
	}

	mergedText := strings.Join(textParts, "")

	// Copy the top-level map; nested containers on the path to "parts" are
	// copied below so the caller's response stays intact.
	result := make(map[string]any, len(response)+1)
	for k, v := range response {
		result[k] = v
	}

	// Copy the candidates slice, or start with a single empty slot.
	var candidates []any
	if src, ok := response["candidates"].([]any); ok && len(src) > 0 {
		candidates = append(make([]any, 0, len(src)), src...)
	} else {
		candidates = []any{nil}
	}

	// Copy the first candidate (or replace a malformed one with an empty map).
	candidate := make(map[string]any)
	if src, ok := candidates[0].(map[string]any); ok {
		for k, v := range src {
			candidate[k] = v
		}
	}
	candidates[0] = candidate

	// Copy the content, or create a model-authored one.
	content := map[string]any{"role": "model"}
	if src, ok := candidate["content"].(map[string]any); ok {
		content = make(map[string]any, len(src)+1)
		for k, v := range src {
			content[k] = v
		}
	}
	candidate["content"] = content

	existingParts, _ := content["parts"].([]any)
	newParts := make([]any, 0, len(existingParts)+1)
	textUpdated := false

	for _, p := range existingParts {
		pm, ok := p.(map[string]any)
		if !ok {
			newParts = append(newParts, p)
			continue
		}

		text, hasText := pm["text"]
		switch {
		case !hasText:
			newParts = append(newParts, pm)
		case !textUpdated:
			// First text part: keep its metadata, swap in the merged text.
			replaced := make(map[string]any, len(pm))
			for k, v := range pm {
				replaced[k] = v
			}
			replaced["text"] = mergedText
			newParts = append(newParts, replaced)
			textUpdated = true
		default:
			// A later plain text part is already contained in mergedText;
			// keeping it would duplicate its text in the final response.
			if str, isStr := text.(string); isStr && str != "" && len(pm) == 1 {
				continue
			}
			newParts = append(newParts, pm)
		}
	}

	if !textUpdated {
		newParts = append([]any{map[string]any{"text": mergedText}}, newParts...)
	}

	content["parts"] = newParts
	result["candidates"] = candidates

	return result
}
|
||||
|
||||
type geminiNativeStreamResult struct {
|
||||
usage *ClaudeUsage
|
||||
firstTokenMs *int
|
||||
|
||||
@@ -15,8 +15,10 @@ import (
|
||||
|
||||
// mockAccountRepoForGemini Gemini 测试用的 mock
|
||||
type mockAccountRepoForGemini struct {
|
||||
accounts []Account
|
||||
accountsByID map[int64]*Account
|
||||
accounts []Account
|
||||
accountsByID map[int64]*Account
|
||||
listByGroupFunc func(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||
listByPlatformFunc func(ctx context.Context, platforms []string) ([]Account, error)
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
@@ -107,6 +109,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context,
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
if m.listByPlatformFunc != nil {
|
||||
return m.listByPlatformFunc(ctx, platforms)
|
||||
}
|
||||
var result []Account
|
||||
platformSet := make(map[string]bool)
|
||||
for _, p := range platforms {
|
||||
@@ -120,6 +125,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Contex
|
||||
return result, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
if m.listByGroupFunc != nil {
|
||||
return m.listByGroupFunc(ctx, groupID, platforms)
|
||||
}
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
@@ -215,6 +223,7 @@ var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
|
||||
// mockGatewayCacheForGemini is an in-memory gateway cache used by the Gemini
// account-selection tests.
type mockGatewayCacheForGemini struct {
	// sessionBindings maps a session cache key to the bound account ID.
	sessionBindings map[string]int64
	// deletedSessions counts, per session cache key, how many times the
	// binding was deleted through DeleteSessionAccountID.
	deletedSessions map[string]int
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
@@ -236,6 +245,18 @@ func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, group
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
if m.sessionBindings == nil {
|
||||
return nil
|
||||
}
|
||||
if m.deletedSessions == nil {
|
||||
m.deletedSessions = make(map[string]int)
|
||||
}
|
||||
m.deletedSessions[sessionHash]++
|
||||
delete(m.sessionBindings, sessionHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
@@ -526,6 +547,274 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyS
|
||||
// 粘性会话未命中,按优先级选择
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择")
|
||||
})
|
||||
|
||||
t.Run("粘性会话不可调度-清理并回退选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusDisabled, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"gemini:session-123": 1},
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
require.Equal(t, 1, cache.deletedSessions["gemini:session-123"])
|
||||
require.Equal(t, int64(2), cache.sessionBindings["gemini:session-123"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ForcePlatformFallback(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(9)
|
||||
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, PlatformAntigravity)
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
listByGroupFunc: func(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
return nil, nil
|
||||
},
|
||||
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
return []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
}, nil
|
||||
},
|
||||
accountsByID: map[int64]*Account{
|
||||
1: {ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformGemini,
|
||||
Priority: 1,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.0-pro": "gemini-1.0-pro"}},
|
||||
},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "supporting model")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickyMixedScheduling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true, Extra: map[string]any{"mixed_scheduling": true}},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"gemini:session-999": 1},
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-999", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_SkipDisabledMixedScheduling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ExcludedAccount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
excluded := map[int64]struct{}{1: {}}
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", excluded)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_ListError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
listByPlatformFunc: func(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
return nil, errors.New("query failed")
|
||||
},
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "query accounts failed")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferOAuth(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_PreferLeastRecentlyUsed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
oldTime := time.Now().Add(-2 * time.Hour)
|
||||
newTime := time.Now().Add(-1 * time.Hour)
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &newTime},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: &oldTime},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-pro", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
}
|
||||
|
||||
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -131,21 +132,32 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
}
|
||||
|
||||
// 3) Populate cache with TTL.
|
||||
// 3) Populate cache with TTL(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > geminiTokenCacheSkew:
|
||||
ttl = until - geminiTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
// 版本过时,使用 DB 中的最新 token
|
||||
slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
// 不写入缓存,让下次请求重新处理
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > geminiTokenCacheSkew:
|
||||
ttl = until - geminiTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
|
||||
@@ -122,6 +122,7 @@ type TokenInfo struct {
|
||||
Scope string `json:"scope,omitempty"`
|
||||
OrgUUID string `json:"org_uuid,omitempty"`
|
||||
AccountUUID string `json:"account_uuid,omitempty"`
|
||||
EmailAddress string `json:"email_address,omitempty"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens
|
||||
@@ -252,9 +253,15 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
|
||||
tokenInfo.OrgUUID = tokenResp.Organization.UUID
|
||||
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
|
||||
}
|
||||
if tokenResp.Account != nil && tokenResp.Account.UUID != "" {
|
||||
tokenInfo.AccountUUID = tokenResp.Account.UUID
|
||||
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
|
||||
if tokenResp.Account != nil {
|
||||
if tokenResp.Account.UUID != "" {
|
||||
tokenInfo.AccountUUID = tokenResp.Account.UUID
|
||||
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
|
||||
}
|
||||
if tokenResp.Account.EmailAddress != "" {
|
||||
tokenInfo.EmailAddress = tokenResp.Account.EmailAddress
|
||||
log.Printf("[OAuth] Got email_address: %s", tokenInfo.EmailAddress)
|
||||
}
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
|
||||
@@ -180,67 +180,26 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
|
||||
}
|
||||
|
||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||
// SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。
|
||||
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 1. Check sticky session
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// Refresh sticky session TTL
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
cacheKey := "openai:" + sessionHash
|
||||
|
||||
// 1. 尝试粘性会话命中
|
||||
// Try sticky session hit
|
||||
if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil {
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// 2. Get schedulable OpenAI accounts
|
||||
// 2. 获取可调度的 OpenAI 账号
|
||||
// Get schedulable OpenAI accounts
|
||||
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
// 3. Select by priority + LRU
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
|
||||
// avoid selecting accounts that were recently rate-limited/overloaded.
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
// Lower priority value means higher priority
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
// keep selected (both never used)
|
||||
default:
|
||||
// Same priority, select least recently used
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 3. 按优先级 + LRU 选择最佳账号
|
||||
// Select by priority + LRU
|
||||
selected := s.selectBestAccount(accounts, requestedModel, excludedIDs)
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
@@ -249,14 +208,138 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
return nil, errors.New("no available OpenAI accounts")
|
||||
}
|
||||
|
||||
// 4. Set sticky session
|
||||
// 4. 设置粘性会话绑定
|
||||
// Set sticky session binding
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// tryStickySessionHit 尝试从粘性会话获取账号。
|
||||
// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。
|
||||
//
|
||||
// tryStickySessionHit attempts to get account from sticky session.
|
||||
// Returns account if hit and usable; clears session and returns nil if account is unavailable.
|
||||
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
if sessionHash == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
if err != nil || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, excluded := excludedIDs[accountID]; excluded {
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查账号是否需要清理粘性会话
|
||||
// Check if sticky session should be cleared
|
||||
if shouldClearStickySession(account) {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 验证账号是否可用于当前请求
|
||||
// Verify account is usable for current request
|
||||
if !account.IsSchedulable() || !account.IsOpenAI() {
|
||||
return nil
|
||||
}
|
||||
if requestedModel != "" && !account.IsModelSupported(requestedModel) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 刷新会话 TTL 并返回账号
|
||||
// Refresh session TTL and return account
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL)
|
||||
return account
|
||||
}
|
||||
|
||||
// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU)。
|
||||
// 返回 nil 表示无可用账号。
|
||||
//
|
||||
// selectBestAccount selects the best account from candidates (priority + LRU).
|
||||
// Returns nil if no available account.
|
||||
func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
|
||||
var selected *Account
|
||||
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
|
||||
// 跳过被排除的账号
|
||||
// Skip excluded accounts
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
|
||||
// 调度器快照可能暂时过时,这里重新检查可调度性和平台
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability and platform
|
||||
if !acc.IsSchedulable() || !acc.IsOpenAI() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查模型支持
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 选择优先级最高且最久未使用的账号
|
||||
// Select highest priority and least recently used
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
|
||||
if s.isBetterAccount(acc, selected) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
// isBetterAccount 判断 candidate 是否比 current 更优。
|
||||
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。
|
||||
//
|
||||
// isBetterAccount checks if candidate is better than current.
|
||||
// Rules: higher priority (lower value) wins; same priority: never used > least recently used.
|
||||
func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool {
|
||||
// 优先级更高(数值更小)
|
||||
// Higher priority (lower value)
|
||||
if candidate.Priority < current.Priority {
|
||||
return true
|
||||
}
|
||||
if candidate.Priority > current.Priority {
|
||||
return false
|
||||
}
|
||||
|
||||
// 同优先级,比较最后使用时间
|
||||
// Same priority, compare last used time
|
||||
switch {
|
||||
case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
|
||||
// candidate 从未使用,优先
|
||||
return true
|
||||
case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
|
||||
// current 从未使用,保持
|
||||
return false
|
||||
case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
|
||||
// 都未使用,保持
|
||||
return false
|
||||
default:
|
||||
// 都使用过,选择最久未使用的
|
||||
return candidate.LastUsedAt.Before(*current.LastUsedAt)
|
||||
}
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
||||
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||
cfg := s.schedulingConfig()
|
||||
@@ -325,29 +408,35 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
if err == nil {
|
||||
clearSticky := shouldClearStickySession(account)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||
}
|
||||
if !clearSticky && account.IsSchedulable() && account.IsOpenAI() &&
|
||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,19 +21,50 @@ type stubOpenAIAccountRepo struct {
|
||||
accounts []Account
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
for i := range r.accounts {
|
||||
if r.accounts[i].ID == id {
|
||||
return &r.accounts[i], nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
return append([]Account(nil), r.accounts...), nil
|
||||
var result []Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.Platform == platform {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return append([]Account(nil), r.accounts...), nil
|
||||
var result []Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.Platform == platform {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type stubConcurrencyCache struct {
|
||||
ConcurrencyCache
|
||||
loadBatchErr error
|
||||
loadMap map[int64]*AccountLoadInfo
|
||||
acquireResults map[int64]bool
|
||||
waitCounts map[int64]int
|
||||
skipDefaultLoad bool
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
if c.acquireResults != nil {
|
||||
if result, ok := c.acquireResults[accountID]; ok {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -42,8 +73,25 @@ func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
if c.loadBatchErr != nil {
|
||||
return nil, c.loadBatchErr
|
||||
}
|
||||
out := make(map[int64]*AccountLoadInfo, len(accounts))
|
||||
if c.skipDefaultLoad && c.loadMap != nil {
|
||||
for _, acc := range accounts {
|
||||
if load, ok := c.loadMap[acc.ID]; ok {
|
||||
out[acc.ID] = load
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
for _, acc := range accounts {
|
||||
if c.loadMap != nil {
|
||||
if load, ok := c.loadMap[acc.ID]; ok {
|
||||
out[acc.ID] = load
|
||||
continue
|
||||
}
|
||||
}
|
||||
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
|
||||
}
|
||||
return out, nil
|
||||
@@ -92,6 +140,51 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
if c.waitCounts != nil {
|
||||
if count, ok := c.waitCounts[accountID]; ok {
|
||||
return count, nil
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// stubGatewayCache is an in-memory test double for the sticky-session cache.
type stubGatewayCache struct {
	// sessionBindings maps a session cache key to the bound account ID.
	sessionBindings map[string]int64

	// deletedSessions counts DeleteSessionAccountID calls per session cache key.
	deletedSessions map[string]int
}
|
||||
|
||||
func (c *stubGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
if id, ok := c.sessionBindings[sessionHash]; ok {
|
||||
return id, nil
|
||||
}
|
||||
return 0, errors.New("not found")
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
if c.sessionBindings == nil {
|
||||
c.sessionBindings = make(map[string]int64)
|
||||
}
|
||||
c.sessionBindings[sessionHash] = accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
if c.sessionBindings == nil {
|
||||
return nil
|
||||
}
|
||||
if c.deletedSessions == nil {
|
||||
c.deletedSessions = make(map[string]int)
|
||||
}
|
||||
c.deletedSessions[sessionHash]++
|
||||
delete(c.sessionBindings, sessionHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
||||
now := time.Now()
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
@@ -182,6 +275,515 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) {
|
||||
sessionHash := "session-1"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2, got %+v", acc)
|
||||
}
|
||||
if cache.deletedSessions["openai:"+sessionHash] != 1 {
|
||||
t.Fatalf("expected sticky session to be deleted")
|
||||
}
|
||||
if cache.sessionBindings["openai:"+sessionHash] != 2 {
|
||||
t.Fatalf("expected sticky session to bind to account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_StickyUnschedulableClearsSession(t *testing.T) {
|
||||
sessionHash := "session-2"
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2, got %+v", selection)
|
||||
}
|
||||
if cache.deletedSessions["openai:"+sessionHash] != 1 {
|
||||
t.Fatalf("expected sticky session to be deleted")
|
||||
}
|
||||
if cache.sessionBindings["openai:"+sessionHash] != 2 {
|
||||
t.Fatalf("expected sticky session to bind to account 2")
|
||||
}
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gpt-3.5-turbo": "gpt-3.5-turbo"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for unsupported model")
|
||||
}
|
||||
if acc != nil {
|
||||
t.Fatalf("expected nil account for unsupported model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "supporting model") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorFallback(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadBatchErr: errors.New("load batch failed"),
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "fallback", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
t.Fatalf("expected selection")
|
||||
}
|
||||
if selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2, got %d", selection.Account.ID)
|
||||
}
|
||||
if cache.sessionBindings["openai:fallback"] != 2 {
|
||||
t.Fatalf("expected sticky session updated")
|
||||
}
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_NoSlotFallbackWait(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
acquireResults: map[int64]bool{1: false},
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 10},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected wait plan fallback")
|
||||
}
|
||||
if selection.Account == nil || selection.Account.ID != 1 {
|
||||
t.Fatalf("expected account 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_SetsStickyBinding(t *testing.T) {
|
||||
sessionHash := "bind"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 1 {
|
||||
t.Fatalf("expected account 1")
|
||||
}
|
||||
if cache.sessionBindings["openai:"+sessionHash] != 1 {
|
||||
t.Fatalf("expected sticky session binding")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_StickyWaitPlan(t *testing.T) {
|
||||
sessionHash := "sticky-wait"
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
acquireResults: map[int64]bool{1: false},
|
||||
waitCounts: map[int64]int{1: 0},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected sticky wait plan")
|
||||
}
|
||||
if selection.Account == nil || selection.Account.ID != 1 {
|
||||
t.Fatalf("expected account 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_PrefersLowerLoad(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 80},
|
||||
2: {AccountID: 2, LoadRate: 10},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "load", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
if cache.sessionBindings["openai:load"] != 2 {
|
||||
t.Fatalf("expected sticky session updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_StickyExcludedFallback(t *testing.T) {
|
||||
sessionHash := "excluded"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
excluded := map[int64]struct{}{1: {}}
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", excluded)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_StickyNonOpenAI(t *testing.T) {
|
||||
sessionHash := "non-openai"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_NoAccounts(t *testing.T) {
|
||||
repo := stubOpenAIAccountRepo{accounts: []Account{}}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for no accounts")
|
||||
}
|
||||
if acc != nil {
|
||||
t.Fatalf("expected nil account")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no available OpenAI accounts") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_NoCandidates(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
resetAt := time.Now().Add(1 * time.Hour)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, RateLimitResetAt: &resetAt},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for no candidates")
|
||||
}
|
||||
if selection != nil {
|
||||
t.Fatalf("expected nil selection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 100},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected wait plan")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadBatchErr: errors.New("load batch failed"),
|
||||
acquireResults: map[int64]bool{1: false},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected wait plan")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 50},
|
||||
},
|
||||
skipDefaultLoad: true,
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing.T) {
|
||||
oldTime := time.Now().Add(-2 * time.Hour)
|
||||
newTime := time.Now().Add(-1 * time.Hour)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &newTime},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &oldTime},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
lastUsed := time.Now().Add(-1 * time.Hour)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, LastUsedAt: &lastUsed},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 10},
|
||||
2: {AccountID: 2, LoadRate: 10},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
|
||||
@@ -162,26 +162,37 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
// 3. 存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件)
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > openAITokenCacheSkew:
|
||||
ttl = until - openAITokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
// 版本过时,使用 DB 中的最新 token
|
||||
slog.Debug("openai_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetOpenAIAccessToken()
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
// 不写入缓存,让下次请求重新处理
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
// 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
|
||||
ttl = time.Minute
|
||||
slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > openAITokenCacheSkew:
|
||||
ttl = until - openAITokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
keys := []string{
|
||||
SettingKeyRegistrationEnabled,
|
||||
SettingKeyEmailVerifyEnabled,
|
||||
SettingKeyPromoCodeEnabled,
|
||||
SettingKeyTurnstileEnabled,
|
||||
SettingKeyTurnstileSiteKey,
|
||||
SettingKeySiteName,
|
||||
@@ -88,6 +89,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
return &PublicSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "TianShuAPI"),
|
||||
@@ -125,6 +127,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
return &struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
|
||||
SiteName string `json:"site_name"`
|
||||
@@ -140,6 +143,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
}{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
@@ -162,6 +166,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
// 注册设置
|
||||
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
|
||||
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
||||
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
|
||||
|
||||
// 邮件服务设置(只有非空才更新密码)
|
||||
updates[SettingKeySMTPHost] = settings.SMTPHost
|
||||
@@ -248,6 +253,15 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// IsPromoCodeEnabled 检查是否启用优惠码功能
|
||||
func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyPromoCodeEnabled)
|
||||
if err != nil {
|
||||
return true // 默认启用
|
||||
}
|
||||
return value != "false"
|
||||
}
|
||||
|
||||
// GetSiteName 获取网站名称
|
||||
func (s *SettingService) GetSiteName(ctx context.Context) string {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
|
||||
@@ -297,6 +311,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
defaults := map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "false",
|
||||
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
|
||||
SettingKeySiteName: "TianShuAPI",
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
@@ -328,6 +343,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result := &SystemSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
SMTPHost: settings[SettingKeySMTPHost],
|
||||
SMTPUsername: settings[SettingKeySMTPUsername],
|
||||
SMTPFrom: settings[SettingKeySMTPFrom],
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
PromoCodeEnabled bool
|
||||
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
@@ -58,6 +59,7 @@ type SystemSettings struct {
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
PromoCodeEnabled bool
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
SiteName string
|
||||
|
||||
54
backend/internal/service/sticky_session_test.go
Normal file
54
backend/internal/service/sticky_session_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
//go:build unit
|
||||
|
||||
// Package service 提供 API 网关核心服务。
|
||||
// 本文件包含 shouldClearStickySession 函数的单元测试,
|
||||
// 验证粘性会话清理逻辑在各种账号状态下的正确行为。
|
||||
//
|
||||
// This file contains unit tests for the shouldClearStickySession function,
|
||||
// verifying correct sticky session clearing behavior under various account states.
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestShouldClearStickySession 测试粘性会话清理判断逻辑。
|
||||
// 验证在以下情况下是否正确判断需要清理粘性会话:
|
||||
// - nil 账号:不清理(返回 false)
|
||||
// - 状态为错误或禁用:清理
|
||||
// - 不可调度:清理
|
||||
// - 临时不可调度且未过期:清理
|
||||
// - 临时不可调度已过期:不清理
|
||||
// - 正常可调度状态:不清理
|
||||
//
|
||||
// TestShouldClearStickySession tests the sticky session clearing logic.
|
||||
// Verifies correct behavior for various account states including:
|
||||
// nil account, error/disabled status, unschedulable, temporary unschedulable.
|
||||
func TestShouldClearStickySession(t *testing.T) {
|
||||
now := time.Now()
|
||||
future := now.Add(1 * time.Hour)
|
||||
past := now.Add(-1 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
want bool
|
||||
}{
|
||||
{name: "nil account", account: nil, want: false},
|
||||
{name: "status error", account: &Account{Status: StatusError, Schedulable: true}, want: true},
|
||||
{name: "status disabled", account: &Account{Status: StatusDisabled, Schedulable: true}, want: true},
|
||||
{name: "schedulable false", account: &Account{Status: StatusActive, Schedulable: false}, want: true},
|
||||
{name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, want: true},
|
||||
{name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, want: false},
|
||||
{name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, shouldClearStickySession(tt.account))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,10 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type TokenCacheInvalidator interface {
|
||||
InvalidateToken(ctx context.Context, account *Account) error
|
||||
@@ -24,18 +28,87 @@ func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, ac
|
||||
return nil
|
||||
}
|
||||
|
||||
var cacheKey string
|
||||
var keysToDelete []string
|
||||
accountIDKey := "account:" + strconv.FormatInt(account.ID, 10)
|
||||
|
||||
switch account.Platform {
|
||||
case PlatformGemini:
|
||||
cacheKey = GeminiTokenCacheKey(account)
|
||||
// Gemini 可能有两种缓存键:project_id 或 account_id
|
||||
// 首次获取 token 时可能没有 project_id,之后自动检测到 project_id 后会使用新 key
|
||||
// 刷新时需要同时删除两种可能的 key,确保不会遗留旧缓存
|
||||
keysToDelete = append(keysToDelete, GeminiTokenCacheKey(account))
|
||||
keysToDelete = append(keysToDelete, "gemini:"+accountIDKey)
|
||||
case PlatformAntigravity:
|
||||
cacheKey = AntigravityTokenCacheKey(account)
|
||||
// Antigravity 同样可能有两种缓存键
|
||||
keysToDelete = append(keysToDelete, AntigravityTokenCacheKey(account))
|
||||
keysToDelete = append(keysToDelete, "ag:"+accountIDKey)
|
||||
case PlatformOpenAI:
|
||||
cacheKey = OpenAITokenCacheKey(account)
|
||||
keysToDelete = append(keysToDelete, OpenAITokenCacheKey(account))
|
||||
case PlatformAnthropic:
|
||||
cacheKey = ClaudeTokenCacheKey(account)
|
||||
keysToDelete = append(keysToDelete, ClaudeTokenCacheKey(account))
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return c.cache.DeleteAccessToken(ctx, cacheKey)
|
||||
|
||||
// 删除所有可能的缓存键(去重后)
|
||||
seen := make(map[string]bool)
|
||||
for _, key := range keysToDelete {
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
if err := c.cache.DeleteAccessToken(ctx, key); err != nil {
|
||||
slog.Warn("token_cache_delete_failed", "key", key, "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckTokenVersion 检查 account 的 token 版本是否已过时,并返回最新的 account
|
||||
// 用于解决异步刷新任务与请求线程的竞态条件:
|
||||
// 如果刷新任务已更新 token 并删除缓存,此时请求线程的旧 account 对象不应写入缓存
|
||||
//
|
||||
// 返回值:
|
||||
// - latestAccount: 从 DB 获取的最新 account(如果查询失败则返回 nil)
|
||||
// - isStale: true 表示 token 已过时(应使用 latestAccount),false 表示可以使用当前 account
|
||||
func CheckTokenVersion(ctx context.Context, account *Account, repo AccountRepository) (latestAccount *Account, isStale bool) {
|
||||
if account == nil || repo == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
currentVersion := account.GetCredentialAsInt64("_token_version")
|
||||
|
||||
latestAccount, err := repo.GetByID(ctx, account.ID)
|
||||
if err != nil || latestAccount == nil {
|
||||
// 查询失败,默认允许缓存,不返回 latestAccount
|
||||
return nil, false
|
||||
}
|
||||
|
||||
latestVersion := latestAccount.GetCredentialAsInt64("_token_version")
|
||||
|
||||
// 情况1: 当前 account 没有版本号,但 DB 中已有版本号
|
||||
// 说明异步刷新任务已更新 token,当前 account 已过时
|
||||
if currentVersion == 0 && latestVersion > 0 {
|
||||
slog.Debug("token_version_stale_no_current_version",
|
||||
"account_id", account.ID,
|
||||
"latest_version", latestVersion)
|
||||
return latestAccount, true
|
||||
}
|
||||
|
||||
// 情况2: 两边都没有版本号,说明从未被异步刷新过,允许缓存
|
||||
if currentVersion == 0 && latestVersion == 0 {
|
||||
return latestAccount, false
|
||||
}
|
||||
|
||||
// 情况3: 比较版本号,如果 DB 中的版本更新,当前 account 已过时
|
||||
if latestVersion > currentVersion {
|
||||
slog.Debug("token_version_stale",
|
||||
"account_id", account.ID,
|
||||
"current_version", currentVersion,
|
||||
"latest_version", latestVersion)
|
||||
return latestAccount, true
|
||||
}
|
||||
|
||||
return latestAccount, false
|
||||
}
|
||||
|
||||
@@ -51,7 +51,27 @@ func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"gemini:project-x"}, cache.deletedKeys)
|
||||
// 新行为:同时删除基于 project_id 和 account_id 的缓存键
|
||||
// 这是为了处理:首次获取 token 时可能没有 project_id,之后自动检测到后会使用新 key
|
||||
require.Equal(t, []string{"gemini:project-x", "gemini:account:10"}, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_GeminiWithoutProjectID(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "gemini-token",
|
||||
},
|
||||
}
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
// 没有 project_id 时,两个 key 相同,去重后只删除一个
|
||||
require.Equal(t, []string{"gemini:account:10"}, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
|
||||
@@ -68,7 +88,26 @@ func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
|
||||
// 新行为:同时删除基于 project_id 和 account_id 的缓存键
|
||||
require.Equal(t, []string{"ag:ag-project", "ag:account:99"}, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_AntigravityWithoutProjectID(t *testing.T) {
|
||||
cache := &geminiTokenCacheStub{}
|
||||
invalidator := NewCompositeTokenCacheInvalidator(cache)
|
||||
account := &Account{
|
||||
ID: 99,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "ag-token",
|
||||
},
|
||||
}
|
||||
|
||||
err := invalidator.InvalidateToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
// 没有 project_id 时,两个 key 相同,去重后只删除一个
|
||||
require.Equal(t, []string{"ag:account:99"}, cache.deletedKeys)
|
||||
}
|
||||
|
||||
func TestCompositeTokenCacheInvalidator_OpenAI(t *testing.T) {
|
||||
@@ -233,9 +272,10 @@ func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 新行为:删除失败只记录日志,不返回错误
|
||||
// 这是因为缓存失效失败不应影响主业务流程
|
||||
err := invalidator.InvalidateToken(context.Background(), tt.account)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, expectedErr, err)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -252,9 +292,12 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
|
||||
{ID: 4, Platform: PlatformAnthropic, Type: AccountTypeOAuth},
|
||||
}
|
||||
|
||||
// 新行为:Gemini 和 Antigravity 会同时删除基于 project_id 和 account_id 的键
|
||||
expectedKeys := []string{
|
||||
"gemini:gemini-proj",
|
||||
"gemini:account:1",
|
||||
"ag:ag-proj",
|
||||
"ag:account:2",
|
||||
"openai:account:3",
|
||||
"claude:account:4",
|
||||
}
|
||||
@@ -266,3 +309,239 @@ func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
|
||||
|
||||
require.Equal(t, expectedKeys, cache.deletedKeys)
|
||||
}
|
||||
|
||||
// ========== GetCredentialAsInt64 测试 ==========
|
||||
|
||||
func TestAccount_GetCredentialAsInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
credentials map[string]any
|
||||
key string
|
||||
expected int64
|
||||
}{
|
||||
{
|
||||
name: "int64_value",
|
||||
credentials: map[string]any{"_token_version": int64(1737654321000)},
|
||||
key: "_token_version",
|
||||
expected: 1737654321000,
|
||||
},
|
||||
{
|
||||
name: "float64_value",
|
||||
credentials: map[string]any{"_token_version": float64(1737654321000)},
|
||||
key: "_token_version",
|
||||
expected: 1737654321000,
|
||||
},
|
||||
{
|
||||
name: "int_value",
|
||||
credentials: map[string]any{"_token_version": 12345},
|
||||
key: "_token_version",
|
||||
expected: 12345,
|
||||
},
|
||||
{
|
||||
name: "string_value",
|
||||
credentials: map[string]any{"_token_version": "1737654321000"},
|
||||
key: "_token_version",
|
||||
expected: 1737654321000,
|
||||
},
|
||||
{
|
||||
name: "string_with_spaces",
|
||||
credentials: map[string]any{"_token_version": " 1737654321000 "},
|
||||
key: "_token_version",
|
||||
expected: 1737654321000,
|
||||
},
|
||||
{
|
||||
name: "nil_credentials",
|
||||
credentials: nil,
|
||||
key: "_token_version",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "missing_key",
|
||||
credentials: map[string]any{"other_key": 123},
|
||||
key: "_token_version",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "nil_value",
|
||||
credentials: map[string]any{"_token_version": nil},
|
||||
key: "_token_version",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "invalid_string",
|
||||
credentials: map[string]any{"_token_version": "not_a_number"},
|
||||
key: "_token_version",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "empty_string",
|
||||
credentials: map[string]any{"_token_version": ""},
|
||||
key: "_token_version",
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{Credentials: tt.credentials}
|
||||
result := account.GetCredentialAsInt64(tt.key)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccount_GetCredentialAsInt64_NilAccount(t *testing.T) {
|
||||
var account *Account
|
||||
result := account.GetCredentialAsInt64("_token_version")
|
||||
require.Equal(t, int64(0), result)
|
||||
}
|
||||
|
||||
// ========== CheckTokenVersion 测试 ==========
|
||||
|
||||
func TestCheckTokenVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
latestAccount *Account
|
||||
repoErr error
|
||||
expectedStale bool
|
||||
}{
|
||||
{
|
||||
name: "nil_account",
|
||||
account: nil,
|
||||
latestAccount: nil,
|
||||
expectedStale: false,
|
||||
},
|
||||
{
|
||||
name: "no_version_in_account_but_db_has_version",
|
||||
account: &Account{
|
||||
ID: 1,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
latestAccount: &Account{
|
||||
ID: 1,
|
||||
Credentials: map[string]any{"_token_version": int64(100)},
|
||||
},
|
||||
expectedStale: true, // 当前 account 无版本但 DB 有,说明已被异步刷新,当前已过时
|
||||
},
|
||||
{
|
||||
name: "both_no_version",
|
||||
account: &Account{
|
||||
ID: 1,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
latestAccount: &Account{
|
||||
ID: 1,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expectedStale: false, // 两边都没有版本号,说明从未被异步刷新过,允许缓存
|
||||
},
|
||||
{
|
||||
name: "same_version",
|
||||
account: &Account{
|
||||
ID: 1,
|
||||
Credentials: map[string]any{"_token_version": int64(100)},
|
||||
},
|
||||
latestAccount: &Account{
|
||||
ID: 1,
|
||||
Credentials: map[string]any{"_token_version": int64(100)},
|
||||
},
|
||||
expectedStale: false,
|
||||
},
|
||||
{
|
||||
name: "current_version_newer",
|
||||
account: &Account{
|
||||
ID: 1,
|
||||
Credentials: map[string]any{"_token_version": int64(200)},
|
||||
},
|
||||
latestAccount: &Account{
|
||||
ID: 1,
|
||||
Credentials: map[string]any{"_token_version": int64(100)},
|
||||
},
|
||||
expectedStale: false,
|
||||
},
|
||||
{
|
||||
name: "current_version_older_stale",
|
||||
account: &Account{
|
||||
ID: 1,
|
||||
Credentials: map[string]any{"_token_version": int64(100)},
|
||||
},
|
||||
latestAccount: &Account{
|
||||
ID: 1,
|
||||
Credentials: map[string]any{"_token_version": int64(200)},
|
||||
},
|
||||
expectedStale: true, // 当前版本过时
|
||||
},
|
||||
{
|
||||
name: "repo_error",
|
||||
account: &Account{
|
||||
ID: 1,
|
||||
Credentials: map[string]any{"_token_version": int64(100)},
|
||||
},
|
||||
latestAccount: nil,
|
||||
repoErr: errors.New("db error"),
|
||||
expectedStale: false, // 查询失败,默认允许缓存
|
||||
},
|
||||
{
|
||||
name: "repo_returns_nil",
|
||||
account: &Account{
|
||||
ID: 1,
|
||||
Credentials: map[string]any{"_token_version": int64(100)},
|
||||
},
|
||||
latestAccount: nil,
|
||||
repoErr: nil,
|
||||
expectedStale: false, // 查询返回 nil,默认允许缓存
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 由于 CheckTokenVersion 接受 AccountRepository 接口,而创建完整的 mock 很繁琐
|
||||
// 这里我们直接测试函数的核心逻辑来验证行为
|
||||
|
||||
if tt.name == "nil_account" {
|
||||
_, isStale := CheckTokenVersion(context.Background(), nil, nil)
|
||||
require.Equal(t, tt.expectedStale, isStale)
|
||||
return
|
||||
}
|
||||
|
||||
// 模拟 CheckTokenVersion 的核心逻辑
|
||||
account := tt.account
|
||||
currentVersion := account.GetCredentialAsInt64("_token_version")
|
||||
|
||||
// 模拟 repo 查询
|
||||
latestAccount := tt.latestAccount
|
||||
if tt.repoErr != nil || latestAccount == nil {
|
||||
require.Equal(t, tt.expectedStale, false)
|
||||
return
|
||||
}
|
||||
|
||||
latestVersion := latestAccount.GetCredentialAsInt64("_token_version")
|
||||
|
||||
// 情况1: 当前 account 没有版本号,但 DB 中已有版本号
|
||||
if currentVersion == 0 && latestVersion > 0 {
|
||||
require.Equal(t, tt.expectedStale, true)
|
||||
return
|
||||
}
|
||||
|
||||
// 情况2: 两边都没有版本号
|
||||
if currentVersion == 0 && latestVersion == 0 {
|
||||
require.Equal(t, tt.expectedStale, false)
|
||||
return
|
||||
}
|
||||
|
||||
// 情况3: 比较版本号
|
||||
isStale := latestVersion > currentVersion
|
||||
require.Equal(t, tt.expectedStale, isStale)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckTokenVersion_NilRepo(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Credentials: map[string]any{"_token_version": int64(100)},
|
||||
}
|
||||
_, isStale := CheckTokenVersion(context.Background(), account, nil)
|
||||
require.False(t, isStale) // nil repo,默认允许缓存
|
||||
}
|
||||
|
||||
@@ -169,6 +169,10 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
|
||||
// 如果有新凭证,先更新(即使有错误也要保存 token)
|
||||
if newCredentials != nil {
|
||||
// 记录刷新版本时间戳,用于解决缓存一致性问题
|
||||
// TokenProvider 写入缓存前会检查此版本,如果版本已更新则跳过写入
|
||||
newCredentials["_token_version"] = time.Now().UnixMilli()
|
||||
|
||||
account.Credentials = newCredentials
|
||||
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
||||
|
||||
@@ -345,6 +345,9 @@ func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.deleteCalls, 3)
|
||||
require.Equal(t, 2, repo.deleteCalls[0].limit)
|
||||
require.True(t, repo.deleteCalls[0].filters.StartTime.Equal(start))
|
||||
require.True(t, repo.deleteCalls[0].filters.EndTime.Equal(end))
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
require.Empty(t, repo.markFailed)
|
||||
require.Equal(t, int64(5), repo.markSucceeded[0].taskID)
|
||||
|
||||
Reference in New Issue
Block a user