feat(backend): 添加 Gemini V1beta Handler 和路由

- 新增 gemini_v1beta_handler.go: 代理原生 Google API 格式
- 更新 gemini_oauth_handler.go: 移除 redirectUri,新增 oauthType
- 更新 account_handler.go: 账户 Handler 增强
- 更新 router.go: 注册 v1beta 路由
- 更新 config.go: Gemini OAuth 通过环境变量配置
- 更新 wire_gen.go: 依赖注入
This commit is contained in:
ianshaw
2025-12-25 21:24:53 -08:00
parent b2d71da2a2
commit 46cb82bac0
7 changed files with 413 additions and 22 deletions

View File

@@ -85,8 +85,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
geminiTokenCache := repository.NewGeminiTokenCache(client)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, geminiTokenProvider, httpUpstream)
concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
@@ -115,8 +117,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityCache := repository.NewIdentityCache(client)
identityService := service.NewIdentityService(identityCache)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream)
geminiTokenCache := repository.NewGeminiTokenCache(client)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream)
@@ -226,6 +226,10 @@ func provideCleanup(
services.OpenAIOAuth.Stop()
return nil
}},
{"GeminiOAuthService", func() error {
services.GeminiOAuth.Stop()
return nil
}},
{"Redis", func() error {
return rdb.Close()
}},

View File

@@ -221,12 +221,14 @@ func setDefaults() {
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新适配Google 1小时token
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
// Gemini (optional)
// Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
// Default: uses Gemini CLI public credentials (set via environment)
viper.SetDefault("gemini.oauth.client_id", "")
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")

View File

@@ -6,6 +6,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
@@ -879,6 +880,44 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return
}
// Handle Gemini accounts
if account.IsGemini() {
// For OAuth accounts: return default Gemini models
if account.IsOAuth() {
response.Success(c, geminicli.DefaultModels)
return
}
// For API Key accounts: return models based on model_mapping
mapping := account.GetModelMapping()
if len(mapping) == 0 {
response.Success(c, geminicli.DefaultModels)
return
}
var models []geminicli.Model
for requestedModel := range mapping {
var found bool
for _, dm := range geminicli.DefaultModels {
if dm.ID == requestedModel {
models = append(models, dm)
found = true
break
}
}
if !found {
models = append(models, geminicli.Model{
ID: requestedModel,
Type: "model",
DisplayName: requestedModel,
CreatedAt: "",
})
}
}
response.Success(c, models)
return
}
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {

View File

@@ -1,6 +1,10 @@
package admin
import (
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -16,8 +20,11 @@ func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *Gemi
}
type GeminiGenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
RedirectURI string `json:"redirect_uri" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
ProjectID string `json:"project_id"`
// OAuth 类型: "code_assist" (需要 project_id) 或 "ai_studio" (不需要 project_id)
// 默认为 "code_assist" 以保持向后兼容
OAuthType string `json:"oauth_type"`
}
// GenerateAuthURL generates Google OAuth authorization URL for Gemini.
@@ -29,9 +36,31 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
return
}
result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
// 默认使用 code_assist 以保持向后兼容
oauthType := strings.TrimSpace(req.OAuthType)
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
return
}
redirectURI := deriveGeminiRedirectURI(c)
if oauthType == "ai_studio" {
// AI Studio OAuth uses a localhost redirect URI to support the "copy/paste callback URL"
// flow (no server-side callback endpoint needed).
redirectURI = geminicli.AIStudioOAuthRedirectURI
}
result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, redirectURI, req.ProjectID, oauthType)
if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
msg := err.Error()
// Treat missing/invalid OAuth client configuration as a user/config error.
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
response.BadRequest(c, "Failed to generate auth URL: "+msg)
return
}
response.InternalError(c, "Failed to generate auth URL: "+msg)
return
}
@@ -39,11 +68,12 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
}
type GeminiExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
State string `json:"state" binding:"required"`
Code string `json:"code" binding:"required"`
RedirectURI string `json:"redirect_uri" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
SessionID string `json:"session_id" binding:"required"`
State string `json:"state" binding:"required"`
Code string `json:"code" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
// OAuth 类型: "code_assist" 或 "ai_studio",需要与 GenerateAuthURL 时的类型一致
OAuthType string `json:"oauth_type"`
}
// ExchangeCode exchanges authorization code for tokens.
@@ -55,12 +85,22 @@ func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
return
}
// 默认使用 code_assist 以保持向后兼容
oauthType := strings.TrimSpace(req.OAuthType)
if oauthType == "" {
oauthType = "code_assist"
}
if oauthType != "code_assist" && oauthType != "ai_studio" {
response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'")
return
}
tokenInfo, err := h.geminiOAuthService.ExchangeCode(c.Request.Context(), &service.GeminiExchangeCodeInput{
SessionID: req.SessionID,
State: req.State,
Code: req.Code,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
SessionID: req.SessionID,
State: req.State,
Code: req.Code,
ProxyID: req.ProxyID,
OAuthType: oauthType,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
@@ -69,3 +109,25 @@ func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) {
response.Success(c, tokenInfo)
}
func deriveGeminiRedirectURI(c *gin.Context) string {
origin := strings.TrimSpace(c.GetHeader("Origin"))
if origin != "" {
return strings.TrimRight(origin, "/") + "/auth/callback"
}
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
if xfProto := strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")); xfProto != "" {
scheme = strings.TrimSpace(strings.Split(xfProto, ",")[0])
}
host := strings.TrimSpace(c.Request.Host)
if xfHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); xfHost != "" {
host = strings.TrimSpace(strings.Split(xfHost, ",")[0])
}
return fmt.Sprintf("%s://%s/auth/callback", scheme, host)
}

View File

@@ -0,0 +1,273 @@
package handler
import (
"context"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/middleware"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models")
if err != nil {
googleError(c, http.StatusBadGateway, err.Error())
return
}
if shouldFallbackGeminiModels(res) {
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
writeUpstreamResponse(c, res)
}
// GeminiV1BetaGetModel proxies:
// GET /v1beta/models/{model}
func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
modelName := strings.TrimSpace(c.Param("model"))
if modelName == "" {
googleError(c, http.StatusBadRequest, "Missing model in URL")
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models/"+modelName)
if err != nil {
googleError(c, http.StatusBadGateway, err.Error())
return
}
if shouldFallbackGeminiModels(res) {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
writeUpstreamResponse(c, res)
}
// GeminiV1BetaModels proxies Gemini native REST endpoints like:
// POST /v1beta/models/{model}:generateContent
// POST /v1beta/models/{model}:streamGenerateContent?alt=sse
func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok || apiKey == nil {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok || user == nil {
googleError(c, http.StatusInternalServerError, "User context not found")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/"))
if err != nil {
googleError(c, http.StatusNotFound, err.Error())
return
}
stream := action == "streamGenerateContent"
body, err := io.ReadAll(c.Request.Body)
if err != nil {
googleError(c, http.StatusBadRequest, "Failed to read request body")
return
}
if len(body) == 0 {
googleError(c, http.StatusBadRequest, "Request body is empty")
return
}
// Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
// 0) wait queue check
maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
} else if !canWait {
googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later")
return
}
defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), user.ID)
// 1) user concurrency slot
streamStarted := false
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, user, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
googleError(c, http.StatusForbidden, err.Error())
return
}
// 3) select account (sticky session based on request body)
sessionHash := h.gatewayService.GenerateSessionHash(body)
account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, modelName)
if err != nil {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
// 4) account concurrency slot
accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
// 5) forward (writes response to client)
result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body)
if err != nil {
// ForwardNative already wrote the response
log.Printf("Gemini native forward failed: %v", err)
return
}
// 6) record usage async
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: user,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}()
}
func parseGeminiModelAction(rest string) (model string, action string, err error) {
rest = strings.TrimSpace(rest)
if rest == "" {
return "", "", &pathParseError{"missing path"}
}
// Standard: {model}:{action}
if i := strings.Index(rest, ":"); i > 0 && i < len(rest)-1 {
return rest[:i], rest[i+1:], nil
}
// Fallback: {model}/{action}
if i := strings.Index(rest, "/"); i > 0 && i < len(rest)-1 {
return rest[:i], rest[i+1:], nil
}
return "", "", &pathParseError{"invalid model action path"}
}
type pathParseError struct{ msg string }
func (e *pathParseError) Error() string { return e.msg }
func googleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": googleapi.HTTPStatusToGoogleStatus(status),
},
})
}
func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
if res == nil {
googleError(c, http.StatusBadGateway, "Empty upstream response")
return
}
for k, vv := range res.Headers {
// Avoid overriding content-length and hop-by-hop headers.
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
continue
}
for _, v := range vv {
c.Writer.Header().Add(k, v)
}
}
contentType := res.Headers.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}
c.Data(res.StatusCode, contentType, res.Body)
}
func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
if res == nil {
return true
}
if res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden {
return false
}
if strings.Contains(strings.ToLower(res.Headers.Get("Www-Authenticate")), "insufficient_scope") {
return true
}
if strings.Contains(strings.ToLower(string(res.Body)), "insufficient authentication scopes") {
return true
}
if strings.Contains(strings.ToLower(string(res.Body)), "access_token_scope_insufficient") {
return true
}
return false
}

View File

@@ -314,6 +314,16 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep
gateway.POST("/responses", h.OpenAIGateway.Responses)
}
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta")
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(s.ApiKey, s.Subscription))
{
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
// Gin treats ":" as a param marker, but Gemini uses "{model}:{action}" in the same segment.
gemini.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
}
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription), h.OpenAIGateway.Responses)
}

View File

@@ -27,6 +27,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/setup/") ||
path == "/health" {
c.Next()