Merge upstream/main
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -161,25 +162,32 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, modelName, stream, body)
|
||||
|
||||
// Get subscription (may be nil)
|
||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||
|
||||
// 获取 User-Agent
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
// For Gemini native API, do not send Claude-style ping frames.
|
||||
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
|
||||
|
||||
// 0) wait queue check
|
||||
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
|
||||
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
|
||||
waitCounted := false
|
||||
if err != nil {
|
||||
log.Printf("Increment wait count failed: %v", err)
|
||||
} else if !canWait {
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
|
||||
if err == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
// 1) user concurrency slot
|
||||
streamStarted := false
|
||||
@@ -188,6 +196,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
if waitCounted {
|
||||
geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
@@ -203,10 +215,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
|
||||
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
sessionKey := sessionHash
|
||||
if sessionHash != "" {
|
||||
@@ -218,7 +226,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
lastFailoverStatus := 0
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
@@ -228,15 +236,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
var accountWaitRelease func()
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
|
||||
return
|
||||
}
|
||||
accountWaitCounted := false
|
||||
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
log.Printf("Increment account wait count failed: %v", err)
|
||||
@@ -244,12 +253,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
log.Printf("Account wait queue full: account=%d", account.ID)
|
||||
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
|
||||
return
|
||||
} else {
|
||||
// Only set release function if increment succeeded
|
||||
accountWaitRelease = func() {
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if accountWaitCounted {
|
||||
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
@@ -260,19 +272,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||
return
|
||||
}
|
||||
if accountWaitCounted {
|
||||
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
|
||||
|
||||
// 5) forward (根据平台分流)
|
||||
var result *service.ForwardResult
|
||||
@@ -284,9 +296,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
if accountWaitRelease != nil {
|
||||
accountWaitRelease()
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
@@ -306,8 +315,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 6) record usage async
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
@@ -317,10 +330,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: ip,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent)
|
||||
}(result, account, userAgent, clientIP)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user