diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 9c77bafa..59ab429c 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -126,8 +126,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 计算粘性会话hash sessionHash := h.gatewayService.GenerateSessionHash(body) + // 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台 platform := "" - if apiKey.Group != nil { + if forcePlatform, ok := middleware2.GetForcePlatformFromContext(c); ok { + platform = forcePlatform + } else if apiKey.Group != nil { platform = apiKey.Group.Platform } diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 613d4c86..ea1bdf5a 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -25,11 +25,19 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { googleError(c, http.StatusUnauthorized, "Invalid API key") return } - if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini { + // 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组 + forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c) + if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) { googleError(c, http.StatusBadRequest, "API key group platform is not gemini") return } + // 强制 antigravity 模式:直接返回静态模型列表 + if forcePlatform == service.PlatformAntigravity { + c.JSON(http.StatusOK, gemini.FallbackModelsList()) + return + } + account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID) if err != nil { // 没有 gemini 账户,检查是否有 antigravity 账户可用 @@ -63,7 +71,9 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { googleError(c, http.StatusUnauthorized, "Invalid API key") return } - if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini { + // 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组 + forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c) + if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) { googleError(c, http.StatusBadRequest, "API key group platform is not gemini") return } @@ -74,6 +84,12 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { return } + // 强制 antigravity 模式:直接返回静态模型信息 + if forcePlatform == service.PlatformAntigravity { + c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) + return + } + account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID) if err != nil { // 没有 gemini 账户,检查是否有 antigravity 账户可用 @@ -114,9 +130,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { return } - if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini { - googleError(c, http.StatusBadRequest, "API key group platform is not gemini") - return + // 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组 + if !middleware.HasForcePlatform(c) { + if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini { + googleError(c, http.StatusBadRequest, "API key group platform is not gemini") + return + } } modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/")) diff --git a/backend/internal/integration/e2e_gateway_test.go b/backend/internal/integration/e2e_gateway_test.go index ae0b138a..05cdc85f 100644 --- a/backend/internal/integration/e2e_gateway_test.go +++ b/backend/internal/integration/e2e_gateway_test.go @@ -231,7 +231,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) { if stream { action = "streamGenerateContent" } - url := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, model, action) + url := fmt.Sprintf("%s%s/v1beta/models/%s:%s", baseURL, endpointPrefix, model, action) if stream { url += "?alt=sse" } diff --git a/backend/internal/server/middleware/middleware.go b/backend/internal/server/middleware/middleware.go index 1af8dbef..45643164 100644 --- a/backend/internal/server/middleware/middleware.go +++ b/backend/internal/server/middleware/middleware.go @@ -1,6 +1,10 @@ package middleware -import "github.com/gin-gonic/gin" +import ( + "context" + + "github.com/gin-gonic/gin" +) // ContextKey 定义上下文键类型 type ContextKey string @@ -14,8 +18,43 @@ const ( ContextKeyApiKey ContextKey = "api_key" // ContextKeySubscription 订阅上下文键 ContextKeySubscription ContextKey = "subscription" + // ContextKeyForcePlatform 强制平台(用于 /antigravity 路由) + ContextKeyForcePlatform ContextKey = "force_platform" ) +// ctxKeyForcePlatformStr 用于 request.Context 的字符串 key(供 Service 读取) +// 注意:service 包中也需要使用相同的字符串 "ctx_force_platform" +const ctxKeyForcePlatformStr = "ctx_force_platform" + +// ForcePlatform 返回设置强制平台的中间件 +// 同时设置 request.Context(供 Service 使用)和 gin.Context(供 Handler 快速检查) +func ForcePlatform(platform string) gin.HandlerFunc { + return func(c *gin.Context) { + // 设置到 request.Context,使用字符串 key 供 Service 层读取 + ctx := context.WithValue(c.Request.Context(), ctxKeyForcePlatformStr, platform) + c.Request = c.Request.WithContext(ctx) + // 同时设置到 gin.Context,供 Handler 快速检查 + c.Set(string(ContextKeyForcePlatform), platform) + c.Next() + } +} + +// HasForcePlatform 检查是否有强制平台(用于 Handler 跳过分组检查) +func HasForcePlatform(c *gin.Context) bool { + _, exists := c.Get(string(ContextKeyForcePlatform)) + return exists +} + +// GetForcePlatformFromContext 从 gin.Context 获取强制平台 +func GetForcePlatformFromContext(c *gin.Context) (string, bool) { + value, exists := c.Get(string(ContextKeyForcePlatform)) + if !exists { + return "", false + } + platform, ok := value.(string) + return platform, ok +} + // ErrorResponse 标准错误响应结构 type ErrorResponse struct { Code string `json:"code"` diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index eab36ef8..2bf388f8 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -40,4 +40,24 @@ func RegisterGatewayRoutes( // OpenAI Responses API(不带v1前缀的别名) r.POST("/responses", gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) + + // Antigravity 专用路由(仅使用 antigravity 账户,不混合调度) + antigravityV1 := r.Group("/antigravity/v1") + antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity)) + antigravityV1.Use(gin.HandlerFunc(apiKeyAuth)) + { + antigravityV1.POST("/messages", h.Gateway.Messages) + antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens) + antigravityV1.GET("/models", h.Gateway.Models) + antigravityV1.GET("/usage", h.Gateway.Usage) + } + + antigravityV1Beta := r.Group("/antigravity/v1beta") + antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) + antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService)) + { + antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels) + antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) + antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) + } } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 6b286599..641962ea 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -30,6 +30,10 @@ const ( stickySessionTTL = time.Hour // 粘性会话TTL ) +// ctxKeyForcePlatform 用于从 context 读取强制平台(由 middleware.ForcePlatform 设置) +// 必须与 middleware.ctxKeyForcePlatformStr 使用相同的字符串值 +const ctxKeyForcePlatform = "ctx_force_platform" + // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). var sseDataRe = regexp.MustCompile(`^data:\s*`) @@ -294,9 +298,13 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - // 根据分组 platform 决定查询哪种账号 + // 优先检查 context 中的强制平台(/antigravity 路由) var platform string - if groupID != nil { + forcePlatform, hasForcePlatform := ctx.Value(ctxKeyForcePlatform).(string) + if hasForcePlatform && forcePlatform != "" { + platform = forcePlatform + } else if groupID != nil { + // 根据分组 platform 决定查询哪种账号 group, err := s.groupRepo.GetByID(ctx, *groupID) if err != nil { return nil, fmt.Errorf("get group failed: %w", err) @@ -308,11 +316,22 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context } // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) - if platform == PlatformAnthropic || platform == PlatformGemini { + // 注意:强制平台模式不走混合调度 + if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } - // antigravity 分组或无分组使用单平台选择 + // 强制平台模式:优先按分组查找,找不到再查全部该平台账户 + if hasForcePlatform && groupID != nil { + account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + if err == nil { + return account, nil + } + // 分组中找不到,回退查询全部该平台账户 + groupID = nil + } + + // antigravity 分组、强制平台模式或无分组使用单平台选择 return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 2f92abfc..025ca888 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -72,9 +72,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, } func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - // 根据分组 platform 决定查询哪种账号 + // 优先检查 context 中的强制平台(/antigravity 路由) var platform string - if groupID != nil { + forcePlatform, hasForcePlatform := ctx.Value(ctxKeyForcePlatform).(string) + if hasForcePlatform && forcePlatform != "" { + platform = forcePlatform + } else if groupID != nil { + // 根据分组 platform 决定查询哪种账号 group, err := s.groupRepo.GetByID(ctx, *groupID) if err != nil { return nil, fmt.Errorf("get group failed: %w", err) @@ -86,7 +90,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co } // gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) - useMixedScheduling := platform == PlatformGemini + // 注意:强制平台模式不走混合调度 + useMixedScheduling := platform == PlatformGemini && !hasForcePlatform var queryPlatforms []string if useMixedScheduling { queryPlatforms = []string{PlatformGemini, PlatformAntigravity} @@ -118,11 +123,18 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co } } - // 查询可调度账户 + // 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部) var accounts []Account var err error if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + // 强制平台模式下,分组中找不到账户时回退查询全部 + if len(accounts) == 0 && hasForcePlatform { + accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms) + } } else { accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms) } diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index 4bf46897..0ee8d614 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -28,6 +28,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { if strings.HasPrefix(path, "/api/") || strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1beta/") || + strings.HasPrefix(path, "/antigravity/") || strings.HasPrefix(path, "/setup/") || path == "/health" || path == "/responses" {