Merge branch 'main' into test-sora

This commit is contained in:
yangjianbo
2026-02-09 20:40:09 +08:00
104 changed files with 8062 additions and 2481 deletions

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
@@ -116,7 +117,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
@@ -205,6 +206,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 计算粘性会话hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context否则使用分组平台
@@ -336,7 +342,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
@@ -345,6 +351,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// 错误响应已在Forward中处理这里只记录日志
@@ -484,7 +495,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if switchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity {
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
} else {
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
@@ -532,7 +543,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
if needForceCacheBilling(hasBoundSession, failoverErr) {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
@@ -541,6 +552,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
if account.Platform == service.PlatformAntigravity {
if !sleepFailoverDelay(c.Request.Context(), switchCount) {
return
}
}
continue
}
// 错误响应已在Forward中处理这里只记录日志
@@ -814,6 +830,27 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
// needForceCacheBilling 判断 failover 时是否需要强制缓存计费
// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费
func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool {
return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling)
}
// sleepFailoverDelay 账号切换线性递增延时第1次0s、第2次1s、第3次2s…
// 返回 false 表示 context 已取消。
func sleepFailoverDelay(ctx context.Context, switchCount int) bool {
delay := time.Duration(switchCount-1) * time.Second
if delay <= 0 {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(delay):
return true
}
}
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
@@ -947,7 +984,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
@@ -975,6 +1012,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
// 计算粘性会话 hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
UserAgent: c.GetHeader("User-Agent"),
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
// 选择支持该模型的账号