Add invalid-request fallback routing
This commit is contained in:
@@ -35,11 +35,12 @@ type CreateGroupRequest struct {
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
@@ -58,11 +59,12 @@ type UpdateGroupRequest struct {
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||
@@ -155,22 +157,23 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -196,23 +199,24 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@@ -73,27 +73,28 @@ func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
return nil
|
||||
}
|
||||
return &Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
AccountCount: g.AccountCount,
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
AccountCount: g.AccountCount,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -57,6 +57,8 @@ type Group struct {
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
// 无效请求兜底分组
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
|
||||
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
@@ -325,136 +326,186 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
currentAPIKey := apiKey
|
||||
currentSubscription := subscription
|
||||
var fallbackGroupID *int64
|
||||
if apiKey.Group != nil {
|
||||
fallbackGroupID = apiKey.Group.FallbackGroupIDOnInvalidRequest
|
||||
}
|
||||
fallbackUsed := false
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
retryWithFallback := false
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}()
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
} else {
|
||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}()
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
|
||||
} else {
|
||||
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
|
||||
}
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if err != nil {
|
||||
var promptTooLongErr *service.PromptTooLongError
|
||||
if errors.As(err, &promptTooLongErr) {
|
||||
log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed)
|
||||
if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 {
|
||||
fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID)
|
||||
if err != nil {
|
||||
log.Printf("Resolve fallback group failed: %v", err)
|
||||
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
|
||||
return
|
||||
}
|
||||
if fallbackGroup.Platform != service.PlatformAnthropic ||
|
||||
fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription ||
|
||||
fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
|
||||
log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType)
|
||||
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
|
||||
return
|
||||
}
|
||||
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
// 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "")
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
currentAPIKey = fallbackAPIKey
|
||||
currentSubscription = nil
|
||||
fallbackUsed = true
|
||||
retryWithFallback = true
|
||||
break
|
||||
}
|
||||
_ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body)
|
||||
return
|
||||
}
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: clientIP,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent, clientIP)
|
||||
return
|
||||
}
|
||||
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: clientIP,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent, clientIP)
|
||||
return
|
||||
if !retryWithFallback {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -518,6 +569,17 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service.APIKey {
|
||||
if apiKey == nil || group == nil {
|
||||
return apiKey
|
||||
}
|
||||
cloned := *apiKey
|
||||
groupID := group.ID
|
||||
cloned.GroupID = &groupID
|
||||
cloned.Group = group
|
||||
return &cloned
|
||||
}
|
||||
|
||||
// Usage handles getting account balance for CC Switch integration
|
||||
// GET /v1/usage
|
||||
func (h *GatewayHandler) Usage(c *gin.Context) {
|
||||
|
||||
@@ -136,6 +136,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldImagePrice4k,
|
||||
group.FieldClaudeCodeOnly,
|
||||
group.FieldFallbackGroupID,
|
||||
group.FieldFallbackGroupIDOnInvalidRequest,
|
||||
group.FieldModelRoutingEnabled,
|
||||
group.FieldModelRouting,
|
||||
)
|
||||
@@ -406,28 +407,29 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
return nil
|
||||
}
|
||||
return &service.Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: derefString(g.Description),
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||
ImagePrice1K: g.ImagePrice1k,
|
||||
ImagePrice2K: g.ImagePrice2k,
|
||||
ImagePrice4K: g.ImagePrice4k,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: derefString(g.Description),
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||
ImagePrice1K: g.ImagePrice1k,
|
||||
ImagePrice2K: g.ImagePrice2k,
|
||||
ImagePrice4K: g.ImagePrice4k,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
||||
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
|
||||
|
||||
// 设置模型路由配置
|
||||
@@ -116,6 +117,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
} else {
|
||||
builder = builder.ClearFallbackGroupID()
|
||||
}
|
||||
// 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置
|
||||
if groupIn.FallbackGroupIDOnInvalidRequest != nil {
|
||||
builder = builder.SetFallbackGroupIDOnInvalidRequest(*groupIn.FallbackGroupIDOnInvalidRequest)
|
||||
} else {
|
||||
builder = builder.ClearFallbackGroupIDOnInvalidRequest()
|
||||
}
|
||||
|
||||
// 处理 ModelRouting:nil 时清除,否则设置
|
||||
if groupIn.ModelRouting != nil {
|
||||
|
||||
@@ -108,6 +108,8 @@ type CreateGroupInput struct {
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool // 是否启用模型路由
|
||||
@@ -130,6 +132,8 @@ type UpdateGroupInput struct {
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||
@@ -572,24 +576,35 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest
|
||||
if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 {
|
||||
fallbackOnInvalidRequest = nil
|
||||
}
|
||||
// 校验无效请求兜底分组
|
||||
if fallbackOnInvalidRequest != nil {
|
||||
if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Platform: platform,
|
||||
RateMultiplier: input.RateMultiplier,
|
||||
IsExclusive: input.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: subscriptionType,
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
ModelRouting: input.ModelRouting,
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Platform: platform,
|
||||
RateMultiplier: input.RateMultiplier,
|
||||
IsExclusive: input.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: subscriptionType,
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
||||
ModelRouting: input.ModelRouting,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -651,6 +666,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
|
||||
}
|
||||
}
|
||||
|
||||
// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性
|
||||
// currentGroupID: 当前分组 ID(新建时为 0)
|
||||
// platform/subscriptionType: 当前分组的有效平台/订阅类型
|
||||
// fallbackGroupID: 兜底分组 ID
|
||||
func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error {
|
||||
if platform != PlatformAnthropic && platform != PlatformAntigravity {
|
||||
return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups")
|
||||
}
|
||||
if subscriptionType == SubscriptionTypeSubscription {
|
||||
return fmt.Errorf("subscription groups cannot set invalid request fallback")
|
||||
}
|
||||
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
|
||||
return fmt.Errorf("cannot set self as invalid request fallback group")
|
||||
}
|
||||
|
||||
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback group not found: %w", err)
|
||||
}
|
||||
if fallbackGroup.Platform != PlatformAnthropic {
|
||||
return fmt.Errorf("fallback group must be anthropic platform")
|
||||
}
|
||||
if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription {
|
||||
return fmt.Errorf("fallback group cannot be subscription type")
|
||||
}
|
||||
if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
|
||||
return fmt.Errorf("fallback group cannot have invalid request fallback configured")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -717,6 +763,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
group.FallbackGroupID = nil
|
||||
}
|
||||
}
|
||||
fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest
|
||||
if input.FallbackGroupIDOnInvalidRequest != nil {
|
||||
if *input.FallbackGroupIDOnInvalidRequest > 0 {
|
||||
fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest
|
||||
} else {
|
||||
fallbackOnInvalidRequest = nil
|
||||
}
|
||||
}
|
||||
if fallbackOnInvalidRequest != nil {
|
||||
if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest
|
||||
|
||||
// 模型路由配置
|
||||
if input.ModelRouting != nil {
|
||||
|
||||
@@ -378,3 +378,374 @@ func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
type groupRepoStubForInvalidRequestFallback struct {
|
||||
groups map[int64]*Group
|
||||
created *Group
|
||||
updated *Group
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error {
|
||||
s.created = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error {
|
||||
s.updated = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
return s.GetByIDLite(ctx, id)
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||
if g, ok := s.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformOpenAI,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeSubscription,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fallback *Group
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "openai_target",
|
||||
fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard},
|
||||
wantMessage: "fallback group must be anthropic platform",
|
||||
},
|
||||
{
|
||||
name: "antigravity_target",
|
||||
fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard},
|
||||
wantMessage: "fallback group must be anthropic platform",
|
||||
},
|
||||
{
|
||||
name: "subscription_group",
|
||||
fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
|
||||
wantMessage: "fallback group cannot be subscription type",
|
||||
},
|
||||
{
|
||||
name: "nested_fallback",
|
||||
fallback: &Group{
|
||||
ID: 10,
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(),
|
||||
},
|
||||
wantMessage: "fallback group cannot have invalid request fallback configured",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
fallbackID := tc.fallback.ID
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: tc.fallback,
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tc.wantMessage)
|
||||
require.Nil(t, repo.created)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group not found")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAntigravity,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.created)
|
||||
require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
|
||||
zero := int64(0)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &zero,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.created)
|
||||
require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
Platform: PlatformOpenAI,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
SubscriptionType: SubscriptionTypeSubscription,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
clear := int64(0)
|
||||
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
Platform: PlatformOpenAI,
|
||||
FallbackGroupIDOnInvalidRequest: &clear,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group cannot be subscription type")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAntigravity,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
@@ -62,6 +62,17 @@ type antigravityRetryLoopResult struct {
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
// PromptTooLongError 表示上游明确返回 prompt too long
|
||||
type PromptTooLongError struct {
|
||||
StatusCode int
|
||||
RequestID string
|
||||
Body []byte
|
||||
}
|
||||
|
||||
func (e *PromptTooLongError) Error() string {
|
||||
return fmt.Sprintf("prompt too long: status=%d", e.StatusCode)
|
||||
}
|
||||
|
||||
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
||||
func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
@@ -930,6 +941,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
// 处理错误响应(重试后仍失败或不触发重试)
|
||||
if resp.StatusCode >= 400 {
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
log.Printf("%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s", prefix, isPromptTooLongError(respBody), upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, 500))
|
||||
}
|
||||
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||
maxBytes := 2048
|
||||
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
}
|
||||
upstreamDetail := ""
|
||||
if logBody {
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "prompt_too_long",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &PromptTooLongError{
|
||||
StatusCode: resp.StatusCode,
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Body: respBody,
|
||||
}
|
||||
}
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
@@ -1019,21 +1063,55 @@ func isSignatureRelatedError(respBody []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func isPromptTooLongError(respBody []byte) bool {
|
||||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||
if msg == "" {
|
||||
msg = strings.ToLower(string(respBody))
|
||||
}
|
||||
return strings.Contains(msg, "prompt is too long")
|
||||
}
|
||||
|
||||
func extractAntigravityErrorMessage(body []byte) string {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
parseNestedMessage := func(msg string) string {
|
||||
trimmed := strings.TrimSpace(msg)
|
||||
if trimmed == "" || !strings.HasPrefix(trimmed, "{") {
|
||||
return ""
|
||||
}
|
||||
var nested map[string]any
|
||||
if err := json.Unmarshal([]byte(trimmed), &nested); err != nil {
|
||||
return ""
|
||||
}
|
||||
if errObj, ok := nested["error"].(map[string]any); ok {
|
||||
if innerMsg, ok := errObj["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
|
||||
return innerMsg
|
||||
}
|
||||
}
|
||||
if innerMsg, ok := nested["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
|
||||
return innerMsg
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Google-style: {"error": {"message": "..."}}
|
||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
||||
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
|
||||
return innerMsg
|
||||
}
|
||||
return msg
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: top-level message
|
||||
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
|
||||
return innerMsg
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
@@ -2209,6 +2287,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
|
||||
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
|
||||
return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
||||
statusStr := "UNKNOWN"
|
||||
switch status {
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -81,3 +87,77 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
|
||||
require.Equal(t, "secret plan", blocks[0]["text"])
|
||||
require.Equal(t, "tool_use", blocks[1]["type"])
|
||||
}
|
||||
|
||||
func TestIsPromptTooLongError(t *testing.T) {
|
||||
require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`)))
|
||||
require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`)))
|
||||
require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`)))
|
||||
}
|
||||
|
||||
type httpUpstreamStub struct {
|
||||
resp *http.Response
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||
return s.resp, s.err
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"model": "claude-opus-4-5",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "hi"},
|
||||
},
|
||||
"max_tokens": 1,
|
||||
"stream": false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
c.Request = req
|
||||
|
||||
respBody := []byte(`{"error":{"message":"Prompt is too long"}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{"X-Request-Id": []string{"req-1"}},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "acc-1",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.Nil(t, result)
|
||||
|
||||
var promptErr *PromptTooLongError
|
||||
require.ErrorAs(t, err, &promptErr)
|
||||
require.Equal(t, http.StatusBadRequest, promptErr.StatusCode)
|
||||
require.Equal(t, "req-1", promptErr.RequestID)
|
||||
require.NotEmpty(t, promptErr.Body)
|
||||
|
||||
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.Len(t, events, 1)
|
||||
require.Equal(t, "prompt_too_long", events[0].Kind)
|
||||
}
|
||||
|
||||
@@ -23,20 +23,21 @@ type APIKeyAuthUserSnapshot struct {
|
||||
|
||||
// APIKeyAuthGroupSnapshot 分组快照
|
||||
type APIKeyAuthGroupSnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Status string `json:"status"`
|
||||
SubscriptionType string `json:"subscription_type"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Status string `json:"status"`
|
||||
SubscriptionType string `json:"subscription_type"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||
|
||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||
// Only anthropic groups use these fields; others may leave them empty.
|
||||
|
||||
@@ -207,22 +207,23 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
}
|
||||
if apiKey.Group != nil {
|
||||
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||
ID: apiKey.Group.ID,
|
||||
Name: apiKey.Group.Name,
|
||||
Platform: apiKey.Group.Platform,
|
||||
Status: apiKey.Group.Status,
|
||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
ModelRouting: apiKey.Group.ModelRouting,
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
ID: apiKey.Group.ID,
|
||||
Name: apiKey.Group.Name,
|
||||
Platform: apiKey.Group.Platform,
|
||||
Status: apiKey.Group.Status,
|
||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: apiKey.Group.ModelRouting,
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
}
|
||||
}
|
||||
return snapshot
|
||||
@@ -250,23 +251,24 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
}
|
||||
if snapshot.Group != nil {
|
||||
apiKey.Group = &Group{
|
||||
ID: snapshot.Group.ID,
|
||||
Name: snapshot.Group.Name,
|
||||
Platform: snapshot.Group.Platform,
|
||||
Status: snapshot.Group.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
ModelRouting: snapshot.Group.ModelRouting,
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
ID: snapshot.Group.ID,
|
||||
Name: snapshot.Group.Name,
|
||||
Platform: snapshot.Group.Platform,
|
||||
Status: snapshot.Group.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: snapshot.Group.ModelRouting,
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
}
|
||||
}
|
||||
return apiKey
|
||||
|
||||
@@ -55,6 +55,15 @@ func shortSessionHash(sessionHash string) string {
|
||||
return sessionHash[:8]
|
||||
}
|
||||
|
||||
func normalizeClaudeModelForAnthropic(requestedModel string) string {
|
||||
for _, prefix := range anthropicPrefixMappings {
|
||||
if strings.HasPrefix(requestedModel, prefix) {
|
||||
return prefix
|
||||
}
|
||||
}
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var (
|
||||
@@ -71,6 +80,12 @@ var (
|
||||
"You are a file search specialist for Claude Code", // Explore Agent 版
|
||||
"You are a helpful AI assistant tasked with summarizing conversations", // Compact 版
|
||||
}
|
||||
|
||||
anthropicPrefixMappings = []string{
|
||||
"claude-opus-4-5",
|
||||
"claude-haiku-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
}
|
||||
)
|
||||
|
||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||
@@ -951,6 +966,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
|
||||
return s.resolveGroupByID(ctx, groupID)
|
||||
}
|
||||
|
||||
func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
|
||||
if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
|
||||
return nil
|
||||
@@ -1016,7 +1035,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID
|
||||
}
|
||||
|
||||
// 强制平台模式不检查 Claude Code 限制
|
||||
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
|
||||
if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" {
|
||||
return nil, groupID, nil
|
||||
}
|
||||
|
||||
@@ -1719,6 +1738,9 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
// Antigravity 平台使用专门的模型支持检查
|
||||
return IsAntigravityModelSupported(requestedModel)
|
||||
}
|
||||
if account.Platform == PlatformAnthropic {
|
||||
requestedModel = normalizeClaudeModelForAnthropic(requestedModel)
|
||||
}
|
||||
// 其他平台使用账户的模型支持检查
|
||||
return account.IsModelSupported(requestedModel)
|
||||
}
|
||||
@@ -2115,17 +2137,29 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
body = enforceCacheControlLimit(body)
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
// 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射)
|
||||
originalModel := reqModel
|
||||
mappedModel := reqModel
|
||||
mappingSource := ""
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
mappedModel = account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
// 替换请求体中的模型名
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
|
||||
mappingSource = "account"
|
||||
}
|
||||
}
|
||||
if mappingSource == "" && account.Platform == PlatformAnthropic {
|
||||
normalized := normalizeClaudeModelForAnthropic(reqModel)
|
||||
if normalized != reqModel {
|
||||
mappedModel = normalized
|
||||
mappingSource = "prefix"
|
||||
}
|
||||
}
|
||||
if mappedModel != reqModel {
|
||||
// 替换请求体中的模型名
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
|
||||
}
|
||||
|
||||
// 获取凭证
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
@@ -3426,16 +3460,28 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对 apikey 类型账号)
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if reqModel != "" {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
// 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射)
|
||||
if reqModel != "" {
|
||||
mappedModel := reqModel
|
||||
mappingSource := ""
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
|
||||
mappingSource = "account"
|
||||
}
|
||||
}
|
||||
if mappingSource == "" && account.Platform == PlatformAnthropic {
|
||||
normalized := normalizeClaudeModelForAnthropic(reqModel)
|
||||
if normalized != reqModel {
|
||||
mappedModel = normalized
|
||||
mappingSource = "prefix"
|
||||
}
|
||||
}
|
||||
if mappedModel != reqModel {
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取凭证
|
||||
|
||||
@@ -29,6 +29,8 @@ type Group struct {
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool
|
||||
FallbackGroupID *int64
|
||||
// 无效请求兜底分组(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
|
||||
// 模型路由配置
|
||||
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
|
||||
|
||||
Reference in New Issue
Block a user