feat(antigravity): 添加专用路由,支持仅使用 antigravity 账户
添加 /antigravity/v1/* 和 /antigravity/v1beta/* 路由: - 通过 ForcePlatform 中间件强制使用 antigravity 平台 - 跳过混合调度逻辑,仅调度 antigravity 账户 - 支持按分组优先查找,找不到时回退查询全部 antigravity 账户 修复 context key 类型不匹配问题: - middleware 和 service 统一使用字符串常量 "ctx_force_platform" - 解决 Go context.Value() 类型+值匹配导致的读取失败 其他改动: - 嵌入式前端中间件白名单添加 /antigravity/ 路径 - e2e 测试 Gemini 端点 URL 添加 endpointPrefix 支持
This commit is contained in:
@@ -126,8 +126,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 计算粘性会话hash
|
// 计算粘性会话hash
|
||||||
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
sessionHash := h.gatewayService.GenerateSessionHash(body)
|
||||||
|
|
||||||
|
// 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台
|
||||||
platform := ""
|
platform := ""
|
||||||
if apiKey.Group != nil {
|
if forcePlatform, ok := middleware2.GetForcePlatformFromContext(c); ok {
|
||||||
|
platform = forcePlatform
|
||||||
|
} else if apiKey.Group != nil {
|
||||||
platform = apiKey.Group.Platform
|
platform = apiKey.Group.Platform
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,11 +25,19 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
|||||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
|
// 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
|
||||||
|
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
|
||||||
|
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
|
||||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 强制 antigravity 模式:直接返回静态模型列表
|
||||||
|
if forcePlatform == service.PlatformAntigravity {
|
||||||
|
c.JSON(http.StatusOK, gemini.FallbackModelsList())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 没有 gemini 账户,检查是否有 antigravity 账户可用
|
// 没有 gemini 账户,检查是否有 antigravity 账户可用
|
||||||
@@ -63,7 +71,9 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
|||||||
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
googleError(c, http.StatusUnauthorized, "Invalid API key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
|
// 检查平台:优先使用强制平台(/antigravity 路由),否则要求 gemini 分组
|
||||||
|
forcePlatform, hasForcePlatform := middleware.GetForcePlatformFromContext(c)
|
||||||
|
if !hasForcePlatform && (apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini) {
|
||||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -74,6 +84,12 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 强制 antigravity 模式:直接返回静态模型信息
|
||||||
|
if forcePlatform == service.PlatformAntigravity {
|
||||||
|
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
account, err := h.geminiCompatService.SelectAccountForAIStudioEndpoints(c.Request.Context(), apiKey.GroupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 没有 gemini 账户,检查是否有 antigravity 账户可用
|
// 没有 gemini 账户,检查是否有 antigravity 账户可用
|
||||||
@@ -114,9 +130,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
|
// 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组
|
||||||
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
if !middleware.HasForcePlatform(c) {
|
||||||
return
|
if apiKey.Group == nil || apiKey.Group.Platform != service.PlatformGemini {
|
||||||
|
googleError(c, http.StatusBadRequest, "API key group platform is not gemini")
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/"))
|
modelName, action, err := parseGeminiModelAction(strings.TrimPrefix(c.Param("modelAction"), "/"))
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) {
|
|||||||
if stream {
|
if stream {
|
||||||
action = "streamGenerateContent"
|
action = "streamGenerateContent"
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/v1beta/models/%s:%s", baseURL, model, action)
|
url := fmt.Sprintf("%s%s/v1beta/models/%s:%s", baseURL, endpointPrefix, model, action)
|
||||||
if stream {
|
if stream {
|
||||||
url += "?alt=sse"
|
url += "?alt=sse"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import "github.com/gin-gonic/gin"
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
// ContextKey 定义上下文键类型
|
// ContextKey 定义上下文键类型
|
||||||
type ContextKey string
|
type ContextKey string
|
||||||
@@ -14,8 +18,43 @@ const (
|
|||||||
ContextKeyApiKey ContextKey = "api_key"
|
ContextKeyApiKey ContextKey = "api_key"
|
||||||
// ContextKeySubscription 订阅上下文键
|
// ContextKeySubscription 订阅上下文键
|
||||||
ContextKeySubscription ContextKey = "subscription"
|
ContextKeySubscription ContextKey = "subscription"
|
||||||
|
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
|
||||||
|
ContextKeyForcePlatform ContextKey = "force_platform"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ctxKeyForcePlatformStr 用于 request.Context 的字符串 key(供 Service 读取)
|
||||||
|
// 注意:service 包中也需要使用相同的字符串 "ctx_force_platform"
|
||||||
|
const ctxKeyForcePlatformStr = "ctx_force_platform"
|
||||||
|
|
||||||
|
// ForcePlatform 返回设置强制平台的中间件
|
||||||
|
// 同时设置 request.Context(供 Service 使用)和 gin.Context(供 Handler 快速检查)
|
||||||
|
func ForcePlatform(platform string) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// 设置到 request.Context,使用字符串 key 供 Service 层读取
|
||||||
|
ctx := context.WithValue(c.Request.Context(), ctxKeyForcePlatformStr, platform)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
// 同时设置到 gin.Context,供 Handler 快速检查
|
||||||
|
c.Set(string(ContextKeyForcePlatform), platform)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasForcePlatform 检查是否有强制平台(用于 Handler 跳过分组检查)
|
||||||
|
func HasForcePlatform(c *gin.Context) bool {
|
||||||
|
_, exists := c.Get(string(ContextKeyForcePlatform))
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetForcePlatformFromContext 从 gin.Context 获取强制平台
|
||||||
|
func GetForcePlatformFromContext(c *gin.Context) (string, bool) {
|
||||||
|
value, exists := c.Get(string(ContextKeyForcePlatform))
|
||||||
|
if !exists {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
platform, ok := value.(string)
|
||||||
|
return platform, ok
|
||||||
|
}
|
||||||
|
|
||||||
// ErrorResponse 标准错误响应结构
|
// ErrorResponse 标准错误响应结构
|
||||||
type ErrorResponse struct {
|
type ErrorResponse struct {
|
||||||
Code string `json:"code"`
|
Code string `json:"code"`
|
||||||
|
|||||||
@@ -40,4 +40,24 @@ func RegisterGatewayRoutes(
|
|||||||
|
|
||||||
// OpenAI Responses API(不带v1前缀的别名)
|
// OpenAI Responses API(不带v1前缀的别名)
|
||||||
r.POST("/responses", gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
|
r.POST("/responses", gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
|
||||||
|
|
||||||
|
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
|
||||||
|
antigravityV1 := r.Group("/antigravity/v1")
|
||||||
|
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||||
|
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||||
|
{
|
||||||
|
antigravityV1.POST("/messages", h.Gateway.Messages)
|
||||||
|
antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||||
|
antigravityV1.GET("/models", h.Gateway.Models)
|
||||||
|
antigravityV1.GET("/usage", h.Gateway.Usage)
|
||||||
|
}
|
||||||
|
|
||||||
|
antigravityV1Beta := r.Group("/antigravity/v1beta")
|
||||||
|
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||||
|
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService))
|
||||||
|
{
|
||||||
|
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||||
|
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
|
||||||
|
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,6 +30,10 @@ const (
|
|||||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ctxKeyForcePlatform 用于从 context 读取强制平台(由 middleware.ForcePlatform 设置)
|
||||||
|
// 必须与 middleware.ctxKeyForcePlatformStr 使用相同的字符串值
|
||||||
|
const ctxKeyForcePlatform = "ctx_force_platform"
|
||||||
|
|
||||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||||
var sseDataRe = regexp.MustCompile(`^data:\s*`)
|
var sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||||
@@ -294,9 +298,13 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
|||||||
|
|
||||||
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
|
||||||
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||||
// 根据分组 platform 决定查询哪种账号
|
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||||
var platform string
|
var platform string
|
||||||
if groupID != nil {
|
forcePlatform, hasForcePlatform := ctx.Value(ctxKeyForcePlatform).(string)
|
||||||
|
if hasForcePlatform && forcePlatform != "" {
|
||||||
|
platform = forcePlatform
|
||||||
|
} else if groupID != nil {
|
||||||
|
// 根据分组 platform 决定查询哪种账号
|
||||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get group failed: %w", err)
|
return nil, fmt.Errorf("get group failed: %w", err)
|
||||||
@@ -308,11 +316,22 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
|||||||
}
|
}
|
||||||
|
|
||||||
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||||
if platform == PlatformAnthropic || platform == PlatformGemini {
|
// 注意:强制平台模式不走混合调度
|
||||||
|
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
||||||
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||||
}
|
}
|
||||||
|
|
||||||
// antigravity 分组或无分组使用单平台选择
|
// 强制平台模式:优先按分组查找,找不到再查全部该平台账户
|
||||||
|
if hasForcePlatform && groupID != nil {
|
||||||
|
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||||
|
if err == nil {
|
||||||
|
return account, nil
|
||||||
|
}
|
||||||
|
// 分组中找不到,回退查询全部该平台账户
|
||||||
|
groupID = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// antigravity 分组、强制平台模式或无分组使用单平台选择
|
||||||
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -72,9 +72,13 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||||
// 根据分组 platform 决定查询哪种账号
|
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||||
var platform string
|
var platform string
|
||||||
if groupID != nil {
|
forcePlatform, hasForcePlatform := ctx.Value(ctxKeyForcePlatform).(string)
|
||||||
|
if hasForcePlatform && forcePlatform != "" {
|
||||||
|
platform = forcePlatform
|
||||||
|
} else if groupID != nil {
|
||||||
|
// 根据分组 platform 决定查询哪种账号
|
||||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get group failed: %w", err)
|
return nil, fmt.Errorf("get group failed: %w", err)
|
||||||
@@ -86,7 +90,8 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
}
|
}
|
||||||
|
|
||||||
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
// gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||||
useMixedScheduling := platform == PlatformGemini
|
// 注意:强制平台模式不走混合调度
|
||||||
|
useMixedScheduling := platform == PlatformGemini && !hasForcePlatform
|
||||||
var queryPlatforms []string
|
var queryPlatforms []string
|
||||||
if useMixedScheduling {
|
if useMixedScheduling {
|
||||||
queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
|
queryPlatforms = []string{PlatformGemini, PlatformAntigravity}
|
||||||
@@ -118,11 +123,18 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 查询可调度账户
|
// 查询可调度账户(强制平台模式:优先按分组查找,找不到再查全部)
|
||||||
var accounts []Account
|
var accounts []Account
|
||||||
var err error
|
var err error
|
||||||
if groupID != nil {
|
if groupID != nil {
|
||||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||||
|
}
|
||||||
|
// 强制平台模式下,分组中找不到账户时回退查询全部
|
||||||
|
if len(accounts) == 0 && hasForcePlatform {
|
||||||
|
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
|
|||||||
if strings.HasPrefix(path, "/api/") ||
|
if strings.HasPrefix(path, "/api/") ||
|
||||||
strings.HasPrefix(path, "/v1/") ||
|
strings.HasPrefix(path, "/v1/") ||
|
||||||
strings.HasPrefix(path, "/v1beta/") ||
|
strings.HasPrefix(path, "/v1beta/") ||
|
||||||
|
strings.HasPrefix(path, "/antigravity/") ||
|
||||||
strings.HasPrefix(path, "/setup/") ||
|
strings.HasPrefix(path, "/setup/") ||
|
||||||
path == "/health" ||
|
path == "/health" ||
|
||||||
path == "/responses" {
|
path == "/responses" {
|
||||||
|
|||||||
Reference in New Issue
Block a user