feat(antigravity): 添加专用路由,支持仅使用 antigravity 账户

添加 /antigravity/v1/* 和 /antigravity/v1beta/* 路由:
- 通过 ForcePlatform 中间件强制使用 antigravity 平台
- 跳过混合调度逻辑,仅调度 antigravity 账户
- 支持按分组优先查找,找不到时回退查询全部 antigravity 账户

修复 context key 类型不匹配问题:
- middleware 和 service 统一使用字符串常量 "ctx_force_platform"
- 解决 Go context.Value() 类型+值匹配导致的读取失败

其他改动:
- 嵌入式前端中间件白名单添加 /antigravity/ 路径
- e2e 测试 Gemini 端点 URL 添加 endpointPrefix 支持
This commit is contained in:
song
2025-12-29 16:52:55 +08:00
parent 1ad29032d3
commit b31bfd53ab
8 changed files with 129 additions and 16 deletions

View File

@@ -126,8 +126,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 计算粘性会话hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context否则使用分组平台
platform := ""
if apiKey.Group != nil {
if forcePlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
platform = forcePlatform
} else if apiKey.Group != nil {
platform = apiKey.Group.Platform
}

View File

@@ -25,11 +25,19 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
// 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
// 强制 antigravity 模式:直接返回静态模型列表
if forcePlatform == service.PlatformAntigravity {
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
// 没有 gemini 账户,检查是否有 antigravity 账户可用
@@ -63,7 +71,9 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
googleError(c, http.StatusUnauthorized, "Invalid API key")
return
}
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
// 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
@@ -74,6 +84,12 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
return
}
// 强制 antigravity 模式:直接返回静态模型信息
if forcePlatform == service.PlatformAntigravity {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
if err != nil {
// 没有 gemini 账户,检查是否有 antigravity 账户可用
@@ -114,9 +130,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return
}
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context否则要求 gemini 分组
if !middleware.HasForcePlatform(c) {
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
return
}
}
modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/"))

View File

@@ -231,7 +231,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) {
if stream {
action = "streamGenerateContent"
}
url := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, model, action)
url := fmt.Sprintf("%s%s/v1beta/models/%s:%s", baseURL, endpointPrefix, model, action)
if stream {
url += "?alt=sse"
}

View File

@@ -1,6 +1,10 @@
package middleware
import "github.com/gin-gonic/gin"
import (
"context"
"github.com/gin-gonic/gin"
)
// ContextKey 定义上下文键类型
type ContextKey string
@@ -14,8 +18,43 @@ const (
ContextKeyApiKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription"
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
ContextKeyForcePlatform ContextKey = "force_platform"
)
// ctxKeyForcePlatformStr 用于 request.Context 的字符串 key供 Service 读取)
// 注意service 包中也需要使用相同的字符串 "ctx_force_platform"
const ctxKeyForcePlatformStr = "ctx_force_platform"
// ForcePlatform 返回设置强制平台的中间件
// 同时设置 request.Context供 Service 使用)和 gin.Context供 Handler 快速检查)
func ForcePlatform(platform string) gin.HandlerFunc {
return func(c *gin.Context) {
// 设置到 request.Context使用字符串 key 供 Service 层读取
ctx := context.WithValue(c.Request.Context(), ctxKeyForcePlatformStr, platform)
c.Request = c.Request.WithContext(ctx)
// 同时设置到 gin.Context供 Handler 快速检查
c.Set(string(ContextKeyForcePlatform), platform)
c.Next()
}
}
// HasForcePlatform 检查是否有强制平台(用于 Handler 跳过分组检查)
func HasForcePlatform(c *gin.Context) bool {
_, exists := c.Get(string(ContextKeyForcePlatform))
return exists
}
// GetForcePlatformFromContext 从 gin.Context 获取强制平台
func GetForcePlatformFromContext(c *gin.Context) (string, bool) {
value, exists := c.Get(string(ContextKeyForcePlatform))
if !exists {
return "", false
}
platform, ok := value.(string)
return platform, ok
}
// ErrorResponse 标准错误响应结构
type ErrorResponse struct {
Code string `json:"code"`

View File

@@ -40,4 +40,24 @@ func RegisterGatewayRoutes(
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
antigravityV1 := r.Group("/antigravity/v1")
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
{
antigravityV1.POST("/messages", h.Gateway.Messages)
antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens)
antigravityV1.GET("/models", h.Gateway.Models)
antigravityV1.GET("/usage", h.Gateway.Usage)
}
antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService))
{
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
}
}

View File

@@ -30,6 +30,10 @@ const (
stickySessionTTL = time.Hour // 粘性会话TTL
)
// ctxKeyForcePlatform 用于从 context 读取强制平台(由 middleware.ForcePlatform 设置)
// 必须与 middleware.ctxKeyForcePlatformStr 使用相同的字符串值
const ctxKeyForcePlatform = "ctx_force_platform"
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataRe = regexp.MustCompile(`^data:\s*`)
@@ -294,9 +298,13 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 根据分组 platform 决定查询哪种账号
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
if groupID != nil {
forcePlatform, hasForcePlatform := ctx.Value(ctxKeyForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
// 根据分组 platform 决定查询哪种账号
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
@@ -308,11 +316,22 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
if platform == PlatformAnthropic || platform == PlatformGemini {
// 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// antigravity 分组或无分组使用单平台选择
// 强制平台模式:优先按分组查找,找不到再查全部该平台账户
if hasForcePlatform && groupID != nil {
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err == nil {
return account, nil
}
// 分组中找不到,回退查询全部该平台账户
groupID = nil
}
// antigravity 分组、强制平台模式或无分组使用单平台选择
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}

View File

@@ -72,9 +72,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
}
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 根据分组 platform 决定查询哪种账号
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
if groupID != nil {
forcePlatform, hasForcePlatform := ctx.Value(ctxKeyForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
// 根据分组 platform 决定查询哪种账号
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
@@ -86,7 +90,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
useMixedScheduling := platform == PlatformGemini
// 注意:强制平台模式不走混合调度
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
var queryPlatforms []string
if useMixedScheduling {
queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
@@ -118,11 +123,18 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
}
// 查询可调度账户
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 强制平台模式下,分组中找不到账户时回退查询全部
if len(accounts) == 0 && hasForcePlatform {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
}

View File

@@ -28,6 +28,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/setup/") ||
path == "/health" ||
path == "/responses" {