diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 5dc03b6d..524c6b6d 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { googleError(c, http.StatusBadGateway, err.Error()) return } - if shouldFallbackGeminiModels(res) { + if shouldFallbackGeminiModel(modelName, res) { c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) return } @@ -674,6 +674,16 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { return false } +func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool { + if shouldFallbackGeminiModels(res) { + return true + } + if res == nil || res.StatusCode != http.StatusNotFound { + return false + } + return gemini.HasFallbackModel(modelName) +} + // extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。 // 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。 // diff --git a/backend/internal/handler/gemini_v1beta_handler_test.go b/backend/internal/handler/gemini_v1beta_handler_test.go index 82b30ee4..29d7cc41 100644 --- a/backend/internal/handler/gemini_v1beta_handler_test.go +++ b/backend/internal/handler/gemini_v1beta_handler_test.go @@ -3,6 +3,7 @@ package handler import ( + "net/http" "testing" "github.com/Wei-Shaw/sub2api/internal/service" @@ -141,3 +142,28 @@ func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) { }) } } + +func TestShouldFallbackGeminiModel_KnownFallbackOn404(t *testing.T) { + t.Parallel() + + res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound} + require.True(t, shouldFallbackGeminiModel("gemini-3.1-pro-preview-customtools", res)) +} + +func TestShouldFallbackGeminiModel_UnknownModelOn404(t *testing.T) { + t.Parallel() + + res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound} + require.False(t, shouldFallbackGeminiModel("gemini-future-model", res)) +} + +func TestShouldFallbackGeminiModel_DelegatesScopeFallback(t *testing.T) { + t.Parallel() + + res := &service.UpstreamHTTPResult{ + StatusCode: http.StatusForbidden, + Headers: http.Header{"Www-Authenticate": []string{"Bearer error=\"insufficient_scope\""}}, + Body: []byte("insufficient authentication scopes"), + } + require.True(t, shouldFallbackGeminiModel("gemini-future-model", res)) +} diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go index 882d2ebd..fac79d18 100644 --- a/backend/internal/pkg/gemini/models.go +++ b/backend/internal/pkg/gemini/models.go @@ -2,6 +2,8 @@ // It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes). package gemini +import "strings" + type Model struct { Name string `json:"name"` DisplayName string `json:"displayName,omitempty"` @@ -23,10 +25,27 @@ func DefaultModels() []Model { {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods}, + {Name: "models/gemini-3.1-pro-preview-customtools", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods}, } } +func HasFallbackModel(model string) bool { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return false + } + if !strings.HasPrefix(trimmed, "models/") { + trimmed = "models/" + trimmed + } + for _, model := range DefaultModels() { + if model.Name == trimmed { + return true + } + } + return false +} + func FallbackModelsList() ModelsListResponse { return ModelsListResponse{Models: DefaultModels()} } diff --git a/backend/internal/pkg/gemini/models_test.go b/backend/internal/pkg/gemini/models_test.go index b80047fb..1d20c0e6 100644 --- a/backend/internal/pkg/gemini/models_test.go +++ b/backend/internal/pkg/gemini/models_test.go @@ -2,7 +2,7 @@ package gemini import "testing" -func TestDefaultModels_ContainsImageModels(t *testing.T) { +func TestDefaultModels_ContainsFallbackCatalogModels(t *testing.T) { t.Parallel() models := DefaultModels() @@ -13,6 +13,7 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) { required := []string{ "models/gemini-2.5-flash-image", + "models/gemini-3.1-pro-preview-customtools", "models/gemini-3.1-flash-image", } @@ -26,3 +27,17 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) { } } } + +func TestHasFallbackModel_RecognizesCustomtoolsModel(t *testing.T) { + t.Parallel() + + if !HasFallbackModel("gemini-3.1-pro-preview-customtools") { + t.Fatalf("expected customtools model to exist in fallback catalog") + } + if !HasFallbackModel("models/gemini-3.1-pro-preview-customtools") { + t.Fatalf("expected prefixed customtools model to exist in fallback catalog") + } + if HasFallbackModel("gemini-unknown") { + t.Fatalf("did not expect unknown model to exist in fallback catalog") + } +} diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 10a8e880..512195e3 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -515,6 +515,45 @@ func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []st } } +func normalizeRequestedModelForLookup(platform, requestedModel string) string { + trimmed := strings.TrimSpace(requestedModel) + if trimmed == "" { + return "" + } + if platform != PlatformGemini && platform != PlatformAntigravity { + return trimmed + } + if trimmed == "gemini-3.1-pro-preview-customtools" { + return "gemini-3.1-pro-preview" + } + return trimmed +} + +func mappingSupportsRequestedModel(mapping map[string]string, requestedModel string) bool { + if requestedModel == "" { + return false + } + if _, exists := mapping[requestedModel]; exists { + return true + } + for pattern := range mapping { + if matchWildcard(pattern, requestedModel) { + return true + } + } + return false +} + +func resolveRequestedModelInMapping(mapping map[string]string, requestedModel string) (mappedModel string, matched bool) { + if requestedModel == "" { + return "", false + } + if mappedModel, exists := mapping[requestedModel]; exists { + return mappedModel, true + } + return matchWildcardMappingResult(mapping, requestedModel) +} + // IsModelSupported 检查模型是否在 model_mapping 中(支持通配符) // 如果未配置 mapping,返回 true(允许所有模型) func (a *Account) IsModelSupported(requestedModel string) bool { @@ -522,17 +561,11 @@ func (a *Account) IsModelSupported(requestedModel string) bool { if len(mapping) == 0 { return true // 无映射 = 允许所有 } - // 精确匹配 - if _, exists := mapping[requestedModel]; exists { + if mappingSupportsRequestedModel(mapping, requestedModel) { return true } - // 通配符匹配 - for pattern := range mapping { - if matchWildcard(pattern, requestedModel) { - return true - } - } - return false + normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel) + return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized) } // GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) @@ -549,12 +582,16 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, if len(mapping) == 0 { return requestedModel, false } - // 精确匹配优先 - if mappedModel, exists := mapping[requestedModel]; exists { + if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched { return mappedModel, true } - // 通配符匹配(最长优先) - return matchWildcardMappingResult(mapping, requestedModel) + normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel) + if normalized != requestedModel { + if mappedModel, matched := resolveRequestedModelInMapping(mapping, normalized); matched { + return mappedModel, true + } + } + return requestedModel, false } func (a *Account) GetBaseURL() string { diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go index 0d7ffffa..d903b940 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -133,6 +133,7 @@ func TestMatchWildcardMappingResult(t *testing.T) { func TestAccountIsModelSupported(t *testing.T) { tests := []struct { name string + platform string credentials map[string]any requestedModel string expected bool @@ -184,6 +185,17 @@ func TestAccountIsModelSupported(t *testing.T) { requestedModel: "claude-opus-4-5-thinking", expected: true, }, + { + name: "gemini customtools alias matches normalized mapping", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: true, + }, { name: "wildcard match not supported", credentials: map[string]any{ @@ -199,6 +211,7 @@ func TestAccountIsModelSupported(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { account := &Account{ + Platform: tt.platform, Credentials: tt.credentials, } result := account.IsModelSupported(tt.requestedModel) @@ -212,6 +225,7 @@ func TestAccountIsModelSupported(t *testing.T) { func TestAccountGetMappedModel(t *testing.T) { tests := []struct { name string + platform string credentials map[string]any requestedModel string expected string @@ -223,6 +237,13 @@ func TestAccountGetMappedModel(t *testing.T) { requestedModel: "claude-sonnet-4-5", expected: "claude-sonnet-4-5", }, + { + name: "no mapping preserves gemini customtools model", + platform: PlatformGemini, + credentials: nil, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-preview-customtools", + }, // 精确匹配 { @@ -250,6 +271,29 @@ func TestAccountGetMappedModel(t *testing.T) { }, // 无匹配返回原始模型 + { + name: "gemini customtools alias resolves through normalized mapping", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-preview", + }, + { + name: "gemini customtools exact mapping wins over normalized fallback", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + "gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-preview-customtools", + }, { name: "no match returns original", credentials: map[string]any{ @@ -265,6 +309,7 @@ func TestAccountGetMappedModel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { account := &Account{ + Platform: tt.platform, Credentials: tt.credentials, } result := account.GetMappedModel(tt.requestedModel) @@ -278,6 +323,7 @@ func TestAccountGetMappedModel(t *testing.T) { func TestAccountResolveMappedModel(t *testing.T) { tests := []struct { name string + platform string credentials map[string]any requestedModel string expectedModel string @@ -312,6 +358,31 @@ func TestAccountResolveMappedModel(t *testing.T) { expectedModel: "gpt-5.4", expectedMatch: true, }, + { + name: "gemini customtools alias reports normalized match", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expectedModel: "gemini-3.1-pro-preview", + expectedMatch: true, + }, + { + name: "gemini customtools exact mapping reports exact match", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + "gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expectedModel: "gemini-3.1-pro-preview-customtools", + expectedMatch: true, + }, { name: "missing mapping reports unmatched", credentials: map[string]any{ @@ -328,6 +399,7 @@ func TestAccountResolveMappedModel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { account := &Account{ + Platform: tt.platform, Credentials: tt.credentials, } mappedModel, matched := account.ResolveMappedModel(tt.requestedModel) diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go index 1dbe9870..a29000e7 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -268,6 +268,12 @@ func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) { requestedModel: "gemini-2.5-flash", expected: "gemini-2.5-flash", }, + { + name: "customtools alias falls back to normalized preview mapping", + modelMapping: map[string]any{"gemini-3.1-pro-preview": "gemini-3.1-pro-high"}, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-high", + }, } for _, tt := range tests {