Files
xinghuoapi/backend/internal/handler/gemini_v1beta_handler.go
IanShaw027 9c88980483 fix(lint): 修复 golangci-lint 报错
- 修复 gofmt 格式问题
- 修复 staticcheck SA4031 nil check 问题(只在成功时设置 release 函数)
- 删除未使用的 sortAccountsByPriority 函数
2026-01-01 04:26:01 +08:00

407 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handler
import (
"context"
"errors"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
// 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
// 强制 antigravity 模式:直接返回静态模型列表
if forcePlatform == service.PlatformAntigravity {
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
// 没有 gemini 账户,检查是否有 antigravity 账户可用
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
if hasAntigravity {
// antigravity 账户使用静态模型列表
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models")
if err != nil {
googleError(c, http.StatusBadGateway, err.Error())
return
}
if shouldFallbackGeminiModels(res) {
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
writeUpstreamResponse(c, res)
}
// GeminiV1BetaGetModel proxies:
// GET /v1beta/models/{model}
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
// 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
modelName := strings.TrimSpace(c.Param("model"))
if modelName == "" {
googleError(c, http.StatusBadRequest, "Missing model in URL")
return
}
// 强制 antigravity 模式:直接返回静态模型信息
if forcePlatform == service.PlatformAntigravity {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
// 没有 gemini 账户,检查是否有 antigravity 账户可用
hasAntigravity, _ := h.geminiCompatService.HasAntigravityAccounts(c.Request.Context(), apiKey.GroupID)
if hasAntigravity {
// antigravity 账户使用静态模型信息
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models/"+modelName)
if err != nil {
googleError(c, http.StatusBadGateway, err.Error())
return
}
if shouldFallbackGeminiModels(res) {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
writeUpstreamResponse(c, res)
}
// GeminiV1BetaModels proxies Gemini native REST endpoints like:
// POST /v1beta/models/{model}:generateContent
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
authSubject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok {
googleError(c, http.StatusInternalServerError, "User context not found")
return
}
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context否则要求 gemini 分组
if !middleware.HasForcePlatform(c) {
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
}
modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/"))
if err != nil {
googleError(c, http.StatusNotFound, err.Error())
return
}
stream := action == "streamGenerateContent"
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit))
return
}
googleError(c, http.StatusBadRequest, "Failed to read request body")
return
}
if len(body) == 0 {
googleError(c, http.StatusBadRequest, "Request body is empty")
return
}
// Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
// 0) wait queue check
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
} else if !canWait {
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), authSubject.UserID)
// 1) user concurrency slot
streamStarted := false
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
googleError(c, http.StatusForbidden, err.Error())
return
}
// 3) select account (sticky session based on request body)
parsedReq, _ := service.ParseGatewayRequest(body)
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
const maxAccountSwitches = 3
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
if err != nil {
if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
handleGeminiFailoverExhausted(c, lastFailoverStatus)
return
}
account := selection.Account
// 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
var accountWaitRelease func()
if !selection.Acquired {
if selection.WaitPlan == nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
return
}
canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
} else {
// Only set release function if increment succeeded
accountWaitRelease = func() {
geminiConcurrency.DecrementAccountWaitCount(c.Request.Context(), account.ID)
}
}
accountReleaseFunc, err = geminiConcurrency.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
stream,
&streamStarted,
)
if err != nil {
if accountWaitRelease != nil {
accountWaitRelease()
}
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 5) forward (根据平台分流)
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body)
} else {
result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
}
if accountWaitRelease != nil {
accountWaitRelease()
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
handleGeminiFailoverExhausted(c, lastFailoverStatus)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
}
// ForwardNative already wrote the response
log.Printf("Gemini native forward failed: %v", err)
return
}
// 6) record usage async
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
}
func parseGeminiModelAction(rest string) (model string, action string, err error) {
rest = strings.TrimSpace(rest)
if rest == "" {
return "", "", &pathParseError{"missing path"}
}
// Standard: {model}:{action}
if i := strings.Index(rest, ":"); i > 0 && i < len(rest)-1 {
return rest[:i], rest[i+1:], nil
}
// Fallback: {model}/{action}
if i := strings.Index(rest, "/"); i > 0 && i < len(rest)-1 {
return rest[:i], rest[i+1:], nil
}
return "", "", &pathParseError{"invalid model action path"}
}
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
status, message := mapGeminiUpstreamError(statusCode)
googleError(c, status, message)
}
func mapGeminiUpstreamError(statusCode int) (int, string) {
switch statusCode {
case 401:
return http.StatusBadGateway, "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "Upstream access forbidden, please contact administrator"
case 429:
return http.StatusTooManyRequests, "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "Upstream request failed"
}
}
type pathParseError struct{ msg string }
func (e *pathParseError) Error() string { return e.msg }
func googleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": googleapi.HTTPStatusToGoogleStatus(status),
},
})
}
func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
if res == nil {
googleError(c, http.StatusBadGateway, "Empty upstream response")
return
}
for k, vv := range res.Headers {
// Avoid overriding content-length and hop-by-hop headers.
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
continue
}
for _, v := range vv {
c.Writer.Header().Add(k, v)
}
}
contentType := res.Headers.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(res.StatusCode, contentType, res.Body)
}
func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
if res == nil {
return true
}
if res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden {
return false
}
if strings.Contains(strings.ToLower(res.Headers.Get("Www-Authenticate")), "insufficient_scope") {
return true
}
if strings.Contains(strings.ToLower(string(res.Body)), "insufficient authentication scopes") {
return true
}
if strings.Contains(strings.ToLower(string(res.Body)), "access_token_scope_insufficient") {
return true
}
return false
}