Merge pull request #1420 from YanzheL/fix/1202-gemini-customtools-404
Fix Gemini CLI 404s for gemini-3.1-pro-preview-customtools
This commit is contained in:
@@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||||
googleError(c, http.StatusBadGateway, err.Error())
|
||||
return
|
||||
}
|
||||
if shouldFallbackGeminiModels(res) {
|
||||
if shouldFallbackGeminiModel(modelName, res) {
|
||||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||||
return
|
||||
}
|
||||
@@ -674,6 +674,16 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool {
|
||||
if shouldFallbackGeminiModels(res) {
|
||||
return true
|
||||
}
|
||||
if res == nil || res.StatusCode != http.StatusNotFound {
|
||||
return false
|
||||
}
|
||||
return gemini.HasFallbackModel(modelName)
|
||||
}
|
||||
|
||||
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
|
||||
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
|
||||
//
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -141,3 +142,28 @@ func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldFallbackGeminiModel_KnownFallbackOn404(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound}
|
||||
require.True(t, shouldFallbackGeminiModel("gemini-3.1-pro-preview-customtools", res))
|
||||
}
|
||||
|
||||
func TestShouldFallbackGeminiModel_UnknownModelOn404(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound}
|
||||
require.False(t, shouldFallbackGeminiModel("gemini-future-model", res))
|
||||
}
|
||||
|
||||
func TestShouldFallbackGeminiModel_DelegatesScopeFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res := &service.UpstreamHTTPResult{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Headers: http.Header{"Www-Authenticate": []string{"Bearer error=\"insufficient_scope\""}},
|
||||
Body: []byte("insufficient authentication scopes"),
|
||||
}
|
||||
require.True(t, shouldFallbackGeminiModel("gemini-future-model", res))
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
|
||||
package gemini
|
||||
|
||||
import "strings"
|
||||
|
||||
type Model struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
@@ -23,10 +25,27 @@ func DefaultModels() []Model {
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-pro-preview-customtools", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods},
|
||||
}
|
||||
}
|
||||
|
||||
func HasFallbackModel(model string) bool {
|
||||
trimmed := strings.TrimSpace(model)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
if !strings.HasPrefix(trimmed, "models/") {
|
||||
trimmed = "models/" + trimmed
|
||||
}
|
||||
for _, model := range DefaultModels() {
|
||||
if model.Name == trimmed {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func FallbackModelsList() ModelsListResponse {
|
||||
return ModelsListResponse{Models: DefaultModels()}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package gemini
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||
func TestDefaultModels_ContainsFallbackCatalogModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
models := DefaultModels()
|
||||
@@ -13,6 +13,7 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||
|
||||
required := []string{
|
||||
"models/gemini-2.5-flash-image",
|
||||
"models/gemini-3.1-pro-preview-customtools",
|
||||
"models/gemini-3.1-flash-image",
|
||||
}
|
||||
|
||||
@@ -26,3 +27,17 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasFallbackModel_RecognizesCustomtoolsModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if !HasFallbackModel("gemini-3.1-pro-preview-customtools") {
|
||||
t.Fatalf("expected customtools model to exist in fallback catalog")
|
||||
}
|
||||
if !HasFallbackModel("models/gemini-3.1-pro-preview-customtools") {
|
||||
t.Fatalf("expected prefixed customtools model to exist in fallback catalog")
|
||||
}
|
||||
if HasFallbackModel("gemini-unknown") {
|
||||
t.Fatalf("did not expect unknown model to exist in fallback catalog")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -515,6 +515,45 @@ func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []st
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeRequestedModelForLookup(platform, requestedModel string) string {
|
||||
trimmed := strings.TrimSpace(requestedModel)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if platform != PlatformGemini && platform != PlatformAntigravity {
|
||||
return trimmed
|
||||
}
|
||||
if trimmed == "gemini-3.1-pro-preview-customtools" {
|
||||
return "gemini-3.1-pro-preview"
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func mappingSupportsRequestedModel(mapping map[string]string, requestedModel string) bool {
|
||||
if requestedModel == "" {
|
||||
return false
|
||||
}
|
||||
if _, exists := mapping[requestedModel]; exists {
|
||||
return true
|
||||
}
|
||||
for pattern := range mapping {
|
||||
if matchWildcard(pattern, requestedModel) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func resolveRequestedModelInMapping(mapping map[string]string, requestedModel string) (mappedModel string, matched bool) {
|
||||
if requestedModel == "" {
|
||||
return "", false
|
||||
}
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
return mappedModel, true
|
||||
}
|
||||
return matchWildcardMappingResult(mapping, requestedModel)
|
||||
}
|
||||
|
||||
// IsModelSupported 检查模型是否在 model_mapping 中(支持通配符)
|
||||
// 如果未配置 mapping,返回 true(允许所有模型)
|
||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
@@ -522,17 +561,11 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
if len(mapping) == 0 {
|
||||
return true // 无映射 = 允许所有
|
||||
}
|
||||
// 精确匹配
|
||||
if _, exists := mapping[requestedModel]; exists {
|
||||
if mappingSupportsRequestedModel(mapping, requestedModel) {
|
||||
return true
|
||||
}
|
||||
// 通配符匹配
|
||||
for pattern := range mapping {
|
||||
if matchWildcard(pattern, requestedModel) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel)
|
||||
return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized)
|
||||
}
|
||||
|
||||
// GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配)
|
||||
@@ -549,12 +582,16 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string,
|
||||
if len(mapping) == 0 {
|
||||
return requestedModel, false
|
||||
}
|
||||
// 精确匹配优先
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched {
|
||||
return mappedModel, true
|
||||
}
|
||||
// 通配符匹配(最长优先)
|
||||
return matchWildcardMappingResult(mapping, requestedModel)
|
||||
normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel)
|
||||
if normalized != requestedModel {
|
||||
if mappedModel, matched := resolveRequestedModelInMapping(mapping, normalized); matched {
|
||||
return mappedModel, true
|
||||
}
|
||||
}
|
||||
return requestedModel, false
|
||||
}
|
||||
|
||||
func (a *Account) GetBaseURL() string {
|
||||
|
||||
@@ -133,6 +133,7 @@ func TestMatchWildcardMappingResult(t *testing.T) {
|
||||
func TestAccountIsModelSupported(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
credentials map[string]any
|
||||
requestedModel string
|
||||
expected bool
|
||||
@@ -184,6 +185,17 @@ func TestAccountIsModelSupported(t *testing.T) {
|
||||
requestedModel: "claude-opus-4-5-thinking",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "gemini customtools alias matches normalized mapping",
|
||||
platform: PlatformGemini,
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
|
||||
},
|
||||
},
|
||||
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard match not supported",
|
||||
credentials: map[string]any{
|
||||
@@ -199,6 +211,7 @@ func TestAccountIsModelSupported(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: tt.platform,
|
||||
Credentials: tt.credentials,
|
||||
}
|
||||
result := account.IsModelSupported(tt.requestedModel)
|
||||
@@ -212,6 +225,7 @@ func TestAccountIsModelSupported(t *testing.T) {
|
||||
func TestAccountGetMappedModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
credentials map[string]any
|
||||
requestedModel string
|
||||
expected string
|
||||
@@ -223,6 +237,13 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "no mapping preserves gemini customtools model",
|
||||
platform: PlatformGemini,
|
||||
credentials: nil,
|
||||
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||
expected: "gemini-3.1-pro-preview-customtools",
|
||||
},
|
||||
|
||||
// 精确匹配
|
||||
{
|
||||
@@ -250,6 +271,29 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
},
|
||||
|
||||
// 无匹配返回原始模型
|
||||
{
|
||||
name: "gemini customtools alias resolves through normalized mapping",
|
||||
platform: PlatformGemini,
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
|
||||
},
|
||||
},
|
||||
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||
expected: "gemini-3.1-pro-preview",
|
||||
},
|
||||
{
|
||||
name: "gemini customtools exact mapping wins over normalized fallback",
|
||||
platform: PlatformGemini,
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
|
||||
"gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools",
|
||||
},
|
||||
},
|
||||
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||
expected: "gemini-3.1-pro-preview-customtools",
|
||||
},
|
||||
{
|
||||
name: "no match returns original",
|
||||
credentials: map[string]any{
|
||||
@@ -265,6 +309,7 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: tt.platform,
|
||||
Credentials: tt.credentials,
|
||||
}
|
||||
result := account.GetMappedModel(tt.requestedModel)
|
||||
@@ -278,6 +323,7 @@ func TestAccountGetMappedModel(t *testing.T) {
|
||||
func TestAccountResolveMappedModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
credentials map[string]any
|
||||
requestedModel string
|
||||
expectedModel string
|
||||
@@ -312,6 +358,31 @@ func TestAccountResolveMappedModel(t *testing.T) {
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "gemini customtools alias reports normalized match",
|
||||
platform: PlatformGemini,
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
|
||||
},
|
||||
},
|
||||
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||
expectedModel: "gemini-3.1-pro-preview",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "gemini customtools exact mapping reports exact match",
|
||||
platform: PlatformGemini,
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-3.1-pro-preview": "gemini-3.1-pro-preview",
|
||||
"gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools",
|
||||
},
|
||||
},
|
||||
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||
expectedModel: "gemini-3.1-pro-preview-customtools",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "missing mapping reports unmatched",
|
||||
credentials: map[string]any{
|
||||
@@ -328,6 +399,7 @@ func TestAccountResolveMappedModel(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: tt.platform,
|
||||
Credentials: tt.credentials,
|
||||
}
|
||||
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
|
||||
|
||||
@@ -268,6 +268,12 @@ func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) {
|
||||
requestedModel: "gemini-2.5-flash",
|
||||
expected: "gemini-2.5-flash",
|
||||
},
|
||||
{
|
||||
name: "customtools alias falls back to normalized preview mapping",
|
||||
modelMapping: map[string]any{"gemini-3.1-pro-preview": "gemini-3.1-pro-high"},
|
||||
requestedModel: "gemini-3.1-pro-preview-customtools",
|
||||
expected: "gemini-3.1-pro-high",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
Reference in New Issue
Block a user