diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 6ed9897b..06447ed0 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -85,8 +85,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { rateLimitService := service.NewRateLimitService(accountRepository, configConfig) claudeUsageFetcher := repository.NewClaudeUsageFetcher() accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher) + geminiTokenCache := repository.NewGeminiTokenCache(client) + geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) httpUpstream := repository.NewHTTPUpstream(configConfig) - accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, httpUpstream) + accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, geminiTokenProvider, httpUpstream) concurrencyCache := repository.NewConcurrencyCache(client) concurrencyService := service.NewConcurrencyService(concurrencyCache) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) @@ -115,8 +117,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityCache := repository.NewIdentityCache(client) identityService := service.NewIdentityService(identityCache) gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, identityService, httpUpstream) - geminiTokenCache := repository.NewGeminiTokenCache(client) - geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, userService, concurrencyService, billingCacheService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, billingService, rateLimitService, billingCacheService, httpUpstream) @@ -226,6 +226,10 @@ func provideCleanup( services.OpenAIOAuth.Stop() return nil }}, + {"GeminiOAuthService", func() error { + services.GeminiOAuth.Stop() + return nil + }}, {"Redis", func() error { return rdb.Close() }}, diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 2050a0cf..1b6825fd 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -221,12 +221,14 @@ func setDefaults() { // TokenRefresh viper.SetDefault("token_refresh.enabled", true) - viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 - viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新 - viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 - viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 + viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 + viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token) + viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 + viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 - // Gemini (optional) + // Gemini OAuth - configure via environment variables or config file + // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET + // Default: uses Gemini CLI public credentials (set via environment) viper.SetDefault("gemini.oauth.client_id", "") viper.SetDefault("gemini.oauth.client_secret", "") viper.SetDefault("gemini.oauth.scopes", "") diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 8351d1af..13bc489d 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -6,6 +6,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" @@ -879,6 +880,44 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { return } + // Handle Gemini accounts + if account.IsGemini() { + // For OAuth accounts: return default Gemini models + if account.IsOAuth() { + response.Success(c, geminicli.DefaultModels) + return + } + + // For API Key accounts: return models based on model_mapping + mapping := account.GetModelMapping() + if len(mapping) == 0 { + response.Success(c, geminicli.DefaultModels) + return + } + + var models []geminicli.Model + for requestedModel := range mapping { + var found bool + for _, dm := range geminicli.DefaultModels { + if dm.ID == requestedModel { + models = append(models, dm) + found = true + break + } + } + if !found { + models = append(models, geminicli.Model{ + ID: requestedModel, + Type: "model", + DisplayName: requestedModel, + CreatedAt: "", + }) + } + } + response.Success(c, models) + return + } + // Handle Claude/Anthropic accounts // For OAuth and Setup-Token accounts: return default models if account.IsOAuth() { diff --git a/backend/internal/handler/admin/gemini_oauth_handler.go b/backend/internal/handler/admin/gemini_oauth_handler.go index 4d39700b..f6827735 100644 --- a/backend/internal/handler/admin/gemini_oauth_handler.go +++ b/backend/internal/handler/admin/gemini_oauth_handler.go @@ -1,6 +1,10 @@ package admin import ( + "fmt" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -16,8 +20,11 @@ func NewGeminiOAuthHandler(geminiOAuthService *service.GeminiOAuthService) *Gemi } type GeminiGenerateAuthURLRequest struct { - ProxyID *int64 `json:"proxy_id"` - RedirectURI string `json:"redirect_uri" binding:"required"` + ProxyID *int64 `json:"proxy_id"` + ProjectID string `json:"project_id"` + // OAuth 类型: "code_assist" (需要 project_id) 或 "ai_studio" (不需要 project_id) + // 默认为 "code_assist" 以保持向后兼容 + OAuthType string `json:"oauth_type"` } // GenerateAuthURL generates Google OAuth authorization URL for Gemini. @@ -29,9 +36,31 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) { return } - result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI) + // 默认使用 code_assist 以保持向后兼容 + oauthType := strings.TrimSpace(req.OAuthType) + if oauthType == "" { + oauthType = "code_assist" + } + if oauthType != "code_assist" && oauthType != "ai_studio" { + response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'") + return + } + + redirectURI := deriveGeminiRedirectURI(c) + if oauthType == "ai_studio" { + // AI Studio OAuth uses a localhost redirect URI to support the "copy/paste callback URL" + // flow (no server-side callback endpoint needed). + redirectURI = geminicli.AIStudioOAuthRedirectURI + } + result, err := h.geminiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, redirectURI, req.ProjectID, oauthType) if err != nil { - response.InternalError(c, "Failed to generate auth URL: "+err.Error()) + msg := err.Error() + // Treat missing/invalid OAuth client configuration as a user/config error. + if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") { + response.BadRequest(c, "Failed to generate auth URL: "+msg) + return + } + response.InternalError(c, "Failed to generate auth URL: "+msg) return } @@ -39,11 +68,12 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) { } type GeminiExchangeCodeRequest struct { - SessionID string `json:"session_id" binding:"required"` - State string `json:"state" binding:"required"` - Code string `json:"code" binding:"required"` - RedirectURI string `json:"redirect_uri" binding:"required"` - ProxyID *int64 `json:"proxy_id"` + SessionID string `json:"session_id" binding:"required"` + State string `json:"state" binding:"required"` + Code string `json:"code" binding:"required"` + ProxyID *int64 `json:"proxy_id"` + // OAuth 类型: "code_assist" 或 "ai_studio",需要与 GenerateAuthURL 时的类型一致 + OAuthType string `json:"oauth_type"` } // ExchangeCode exchanges authorization code for tokens. @@ -55,12 +85,22 @@ func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) { return } + // 默认使用 code_assist 以保持向后兼容 + oauthType := strings.TrimSpace(req.OAuthType) + if oauthType == "" { + oauthType = "code_assist" + } + if oauthType != "code_assist" && oauthType != "ai_studio" { + response.BadRequest(c, "Invalid oauth_type: must be 'code_assist' or 'ai_studio'") + return + } + tokenInfo, err := h.geminiOAuthService.ExchangeCode(c.Request.Context(), &service.GeminiExchangeCodeInput{ - SessionID: req.SessionID, - State: req.State, - Code: req.Code, - RedirectURI: req.RedirectURI, - ProxyID: req.ProxyID, + SessionID: req.SessionID, + State: req.State, + Code: req.Code, + ProxyID: req.ProxyID, + OAuthType: oauthType, }) if err != nil { response.BadRequest(c, "Failed to exchange code: "+err.Error()) @@ -69,3 +109,25 @@ func (h *GeminiOAuthHandler) ExchangeCode(c *gin.Context) { response.Success(c, tokenInfo) } + +func deriveGeminiRedirectURI(c *gin.Context) string { + origin := strings.TrimSpace(c.GetHeader("Origin")) + if origin != "" { + return strings.TrimRight(origin, "/") + "/auth/callback" + } + + scheme := "http" + if c.Request.TLS != nil { + scheme = "https" + } + if xfProto := strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")); xfProto != "" { + scheme = strings.TrimSpace(strings.Split(xfProto, ",")[0]) + } + + host := strings.TrimSpace(c.Request.Host) + if xfHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); xfHost != "" { + host = strings.TrimSpace(strings.Split(xfHost, ",")[0]) + } + + return fmt.Sprintf("%s://%s/auth/callback", scheme, host) +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go new file mode 100644 index 00000000..8e8920a7 --- /dev/null +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -0,0 +1,273 @@ +package handler + +import ( + "context" + "io" + "log" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/middleware" + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" + "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// GeminiV1BetaListModels proxies: +// GET /v1beta/models +func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { + apiKey, ok := middleware.GetApiKeyFromContext(c) + if !ok || apiKey == nil { + googleError(c, http.StatusUnauthorized, "Invalid API key") + return + } + if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini { + googleError(c, http.StatusBadRequest, "API key group platform is not gemini") + return + } + + account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID) + if err != nil { + googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) + return + } + + res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models") + if err != nil { + googleError(c, http.StatusBadGateway, err.Error()) + return + } + if shouldFallbackGeminiModels(res) { + c.JSON(http.StatusOK, gemini.FallbackModelsList()) + return + } + writeUpstreamResponse(c, res) +} + +// GeminiV1BetaGetModel proxies: +// GET /v1beta/models/{model} +func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { + apiKey, ok := middleware.GetApiKeyFromContext(c) + if !ok || apiKey == nil { + googleError(c, http.StatusUnauthorized, "Invalid API key") + return + } + if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini { + googleError(c, http.StatusBadRequest, "API key group platform is not gemini") + return + } + + modelName := strings.TrimSpace(c.Param("model")) + if modelName == "" { + googleError(c, http.StatusBadRequest, "Missing model in URL") + return + } + + account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID) + if err != nil { + googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) + return + } + + res, err := h.geminiCompatService.ForwardAIStudioGET(c.Request.Context(), account, "/v1beta/models/"+modelName) + if err != nil { + googleError(c, http.StatusBadGateway, err.Error()) + return + } + if shouldFallbackGeminiModels(res) { + c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) + return + } + writeUpstreamResponse(c, res) +} + +// GeminiV1BetaModels proxies Gemini native REST endpoints like: +// POST /v1beta/models/{model}:generateContent +// POST /v1beta/models/{model}:streamGenerateContent?alt=sse +func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { + apiKey, ok := middleware.GetApiKeyFromContext(c) + if !ok || apiKey == nil { + googleError(c, http.StatusUnauthorized, "Invalid API key") + return + } + user, ok := middleware.GetUserFromContext(c) + if !ok || user == nil { + googleError(c, http.StatusInternalServerError, "User context not found") + return + } + + if apiKey.Group == nil || apiKey.Group.Platform != model.PlatformGemini { + googleError(c, http.StatusBadRequest, "API key group platform is not gemini") + return + } + + modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/")) + if err != nil { + googleError(c, http.StatusNotFound, err.Error()) + return + } + + stream := action == "streamGenerateContent" + + body, err := io.ReadAll(c.Request.Body) + if err != nil { + googleError(c, http.StatusBadRequest, "Failed to read request body") + return + } + if len(body) == 0 { + googleError(c, http.StatusBadRequest, "Request body is empty") + return + } + + // Get subscription (may be nil) + subscription, _ := middleware.GetSubscriptionFromContext(c) + + // For Gemini native API, do not send Claude-style ping frames. + geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone) + + // 0) wait queue check + maxWait := service.CalculateMaxWait(user.Concurrency) + canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), user.ID, maxWait) + if err != nil { + log.Printf("Increment wait count failed: %v", err) + } else if !canWait { + googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") + return + } + defer geminiConcurrency.DecrementWaitCount(c.Request.Context(), user.ID) + + // 1) user concurrency slot + streamStarted := false + userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, user, stream, &streamStarted) + if err != nil { + googleError(c, http.StatusTooManyRequests, err.Error()) + return + } + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + // 2) billing eligibility check (after wait) + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil { + googleError(c, http.StatusForbidden, err.Error()) + return + } + + // 3) select account (sticky session based on request body) + sessionHash := h.gatewayService.GenerateSessionHash(body) + account, err := h.geminiCompatService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, modelName) + if err != nil { + googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) + return + } + + // 4) account concurrency slot + accountReleaseFunc, err := geminiConcurrency.AcquireAccountSlotWithWait(c, account, stream, &streamStarted) + if err != nil { + googleError(c, http.StatusTooManyRequests, err.Error()) + return + } + if accountReleaseFunc != nil { + defer accountReleaseFunc() + } + + // 5) forward (writes response to client) + result, err := h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) + if err != nil { + // ForwardNative already wrote the response + log.Printf("Gemini native forward failed: %v", err) + return + } + + // 6) record usage async + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + ApiKey: apiKey, + User: user, + Account: account, + Subscription: subscription, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }() +} + +func parseGeminiModelAction(rest string) (model string, action string, err error) { + rest = strings.TrimSpace(rest) + if rest == "" { + return "", "", &pathParseError{"missing path"} + } + + // Standard: {model}:{action} + if i := strings.Index(rest, ":"); i > 0 && i < len(rest)-1 { + return rest[:i], rest[i+1:], nil + } + + // Fallback: {model}/{action} + if i := strings.Index(rest, "/"); i > 0 && i < len(rest)-1 { + return rest[:i], rest[i+1:], nil + } + + return "", "", &pathParseError{"invalid model action path"} +} + +type pathParseError struct{ msg string } + +func (e *pathParseError) Error() string { return e.msg } + +func googleError(c *gin.Context, status int, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "code": status, + "message": message, + "status": googleapi.HTTPStatusToGoogleStatus(status), + }, + }) +} + +func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) { + if res == nil { + googleError(c, http.StatusBadGateway, "Empty upstream response") + return + } + for k, vv := range res.Headers { + // Avoid overriding content-length and hop-by-hop headers. + if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") { + continue + } + for _, v := range vv { + c.Writer.Header().Add(k, v) + } + } + contentType := res.Headers.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(res.StatusCode, contentType, res.Body) +} + +func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { + if res == nil { + return true + } + if res.StatusCode != http.StatusUnauthorized && res.StatusCode != http.StatusForbidden { + return false + } + if strings.Contains(strings.ToLower(res.Headers.Get("Www-Authenticate")), "insufficient_scope") { + return true + } + if strings.Contains(strings.ToLower(string(res.Body)), "insufficient authentication scopes") { + return true + } + if strings.Contains(strings.ToLower(string(res.Body)), "access_token_scope_insufficient") { + return true + } + return false +} diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index d1a73c43..847e64b1 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -314,6 +314,16 @@ func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, rep gateway.POST("/responses", h.OpenAIGateway.Responses) } + // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) + gemini := r.Group("/v1beta") + gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(s.ApiKey, s.Subscription)) + { + gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) + gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) + // Gin treats ":" as a param marker, but Gemini uses "{model}:{action}" in the same segment. + gemini.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) + } + // OpenAI Responses API(不带v1前缀的别名) r.POST("/responses", middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription), h.OpenAIGateway.Responses) } diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index 287de3e3..62f1069c 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -27,6 +27,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { if strings.HasPrefix(path, "/api/") || strings.HasPrefix(path, "/v1/") || + strings.HasPrefix(path, "/v1beta/") || strings.HasPrefix(path, "/setup/") || path == "/health" { c.Next()