Merge upstream/main
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -100,6 +101,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
@@ -108,8 +114,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
reqModel := parsedReq.Model
|
||||
reqStream := parsedReq.Stream
|
||||
|
||||
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
// 验证 model 必填
|
||||
if reqModel == "" {
|
||||
@@ -123,12 +128,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 获取订阅信息(可能为nil)- 提前获取用于后续检查
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
// 获取 User-Agent
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
// 0. 检查wait队列是否已满
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
// On error, allow request to proceed
|
||||
@@ -136,8 +139,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
// 确保在函数退出时减少wait计数
|
||||
defer h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
// Ensure we decrement if we exit before acquiring the user slot.
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
// 1. 首先获取用户并发槽位
|
||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
@@ -146,6 +156,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
// User slot acquired: no longer waiting in the queue.
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
// 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
@@ -182,7 +197,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -192,6 +207,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
@@ -208,12 +224,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
@@ -221,12 +237,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
// Ensure the wait counter is decremented if we exit before acquiring the slot.
|
||||
defer func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
@@ -237,20 +257,21 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
@@ -262,19 +283,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
@@ -284,8 +301,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -295,10 +316,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: clientIP,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent)
|
||||
}(result, account, userAgent, clientIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -310,7 +332,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
@@ -320,6 +342,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
@@ -336,12 +359,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
@@ -349,12 +372,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
@@ -365,20 +391,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
var result *service.ForwardResult
|
||||
@@ -390,19 +416,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
@@ -412,8 +434,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -423,10 +449,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: clientIP,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent)
|
||||
}(result, account, userAgent, clientIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -692,21 +719,22 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
// 验证 model 必填
|
||||
if parsedReq.Model == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body)
|
||||
|
||||
// 获取订阅信息(可能为nil)
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
@@ -727,6 +755,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 转发请求(不记录使用量)
|
||||
if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user