diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 5dc03b6d..524c6b6d 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { googleError(c, http.StatusBadGateway, err.Error()) return } - if shouldFallbackGeminiModels(res) { + if shouldFallbackGeminiModel(modelName, res) { c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) return } @@ -674,6 +674,16 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { return false } +func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool { + if shouldFallbackGeminiModels(res) { + return true + } + if res == nil || res.StatusCode != http.StatusNotFound { + return false + } + return gemini.HasFallbackModel(modelName) +} + // extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。 // 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。 // diff --git a/backend/internal/handler/gemini_v1beta_handler_test.go b/backend/internal/handler/gemini_v1beta_handler_test.go index 82b30ee4..29d7cc41 100644 --- a/backend/internal/handler/gemini_v1beta_handler_test.go +++ b/backend/internal/handler/gemini_v1beta_handler_test.go @@ -3,6 +3,7 @@ package handler import ( + "net/http" "testing" "github.com/Wei-Shaw/sub2api/internal/service" @@ -141,3 +142,28 @@ func TestGeminiV1BetaHandler_GetModelAntigravityFallback(t *testing.T) { }) } } + +func TestShouldFallbackGeminiModel_KnownFallbackOn404(t *testing.T) { + t.Parallel() + + res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound} + require.True(t, shouldFallbackGeminiModel("gemini-3.1-pro-preview-customtools", res)) +} + +func TestShouldFallbackGeminiModel_UnknownModelOn404(t *testing.T) { + t.Parallel() + + res := &service.UpstreamHTTPResult{StatusCode: http.StatusNotFound} + require.False(t, 
shouldFallbackGeminiModel("gemini-future-model", res)) +} + +func TestShouldFallbackGeminiModel_DelegatesScopeFallback(t *testing.T) { + t.Parallel() + + res := &service.UpstreamHTTPResult{ + StatusCode: http.StatusForbidden, + Headers: http.Header{"Www-Authenticate": []string{"Bearer error=\"insufficient_scope\""}}, + Body: []byte("insufficient authentication scopes"), + } + require.True(t, shouldFallbackGeminiModel("gemini-future-model", res)) +} diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go index 882d2ebd..fac79d18 100644 --- a/backend/internal/pkg/gemini/models.go +++ b/backend/internal/pkg/gemini/models.go @@ -2,6 +2,8 @@ // It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes). package gemini +import "strings" + type Model struct { Name string `json:"name"` DisplayName string `json:"displayName,omitempty"` @@ -23,10 +25,27 @@ func DefaultModels() []Model { {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods}, + {Name: "models/gemini-3.1-pro-preview-customtools", SupportedGenerationMethods: methods}, {Name: "models/gemini-3.1-flash-image", SupportedGenerationMethods: methods}, } } +func HasFallbackModel(model string) bool { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return false + } + if !strings.HasPrefix(trimmed, "models/") { + trimmed = "models/" + trimmed + } + for _, m := range DefaultModels() { + if m.Name == trimmed { + return true + } + } + return false +} + func FallbackModelsList() ModelsListResponse { return ModelsListResponse{Models: DefaultModels()} } diff --git a/backend/internal/pkg/gemini/models_test.go b/backend/internal/pkg/gemini/models_test.go index b80047fb..1d20c0e6 100644 --- a/backend/internal/pkg/gemini/models_test.go +++ 
b/backend/internal/pkg/gemini/models_test.go @@ -2,7 +2,7 @@ package gemini import "testing" -func TestDefaultModels_ContainsImageModels(t *testing.T) { +func TestDefaultModels_ContainsFallbackCatalogModels(t *testing.T) { t.Parallel() models := DefaultModels() @@ -13,6 +13,7 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) { required := []string{ "models/gemini-2.5-flash-image", + "models/gemini-3.1-pro-preview-customtools", "models/gemini-3.1-flash-image", } @@ -26,3 +27,17 @@ func TestDefaultModels_ContainsImageModels(t *testing.T) { } } } + +func TestHasFallbackModel_RecognizesCustomtoolsModel(t *testing.T) { + t.Parallel() + + if !HasFallbackModel("gemini-3.1-pro-preview-customtools") { + t.Fatalf("expected customtools model to exist in fallback catalog") + } + if !HasFallbackModel("models/gemini-3.1-pro-preview-customtools") { + t.Fatalf("expected prefixed customtools model to exist in fallback catalog") + } + if HasFallbackModel("gemini-unknown") { + t.Fatalf("did not expect unknown model to exist in fallback catalog") + } +} diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 10a8e880..512195e3 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -515,6 +515,45 @@ func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []st } } +func normalizeRequestedModelForLookup(platform, requestedModel string) string { + trimmed := strings.TrimSpace(requestedModel) + if trimmed == "" { + return "" + } + if platform != PlatformGemini && platform != PlatformAntigravity { + return trimmed + } + if trimmed == "gemini-3.1-pro-preview-customtools" { + return "gemini-3.1-pro-preview" + } + return trimmed +} + +func mappingSupportsRequestedModel(mapping map[string]string, requestedModel string) bool { + if requestedModel == "" { + return false + } + if _, exists := mapping[requestedModel]; exists { + return true + } + for pattern := range mapping { + if 
matchWildcard(pattern, requestedModel) { + return true + } + } + return false +} + +func resolveRequestedModelInMapping(mapping map[string]string, requestedModel string) (mappedModel string, matched bool) { + if requestedModel == "" { + return "", false + } + if mappedModel, exists := mapping[requestedModel]; exists { + return mappedModel, true + } + return matchWildcardMappingResult(mapping, requestedModel) +} + // IsModelSupported 检查模型是否在 model_mapping 中(支持通配符) // 如果未配置 mapping,返回 true(允许所有模型) func (a *Account) IsModelSupported(requestedModel string) bool { @@ -522,17 +561,11 @@ func (a *Account) IsModelSupported(requestedModel string) bool { if len(mapping) == 0 { return true // 无映射 = 允许所有 } - // 精确匹配 - if _, exists := mapping[requestedModel]; exists { + if mappingSupportsRequestedModel(mapping, requestedModel) { return true } - // 通配符匹配 - for pattern := range mapping { - if matchWildcard(pattern, requestedModel) { - return true - } - } - return false + normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel) + return normalized != requestedModel && mappingSupportsRequestedModel(mapping, normalized) } // GetMappedModel 获取映射后的模型名(支持通配符,最长优先匹配) @@ -549,12 +582,16 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string, if len(mapping) == 0 { return requestedModel, false } - // 精确匹配优先 - if mappedModel, exists := mapping[requestedModel]; exists { + if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched { return mappedModel, true } - // 通配符匹配(最长优先) - return matchWildcardMappingResult(mapping, requestedModel) + normalized := normalizeRequestedModelForLookup(a.Platform, requestedModel) + if normalized != requestedModel { + if mappedModel, matched := resolveRequestedModelInMapping(mapping, normalized); matched { + return mappedModel, true + } + } + return requestedModel, false } func (a *Account) GetBaseURL() string { diff --git a/backend/internal/service/account_wildcard_test.go 
b/backend/internal/service/account_wildcard_test.go index 0d7ffffa..d903b940 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -133,6 +133,7 @@ func TestMatchWildcardMappingResult(t *testing.T) { func TestAccountIsModelSupported(t *testing.T) { tests := []struct { name string + platform string credentials map[string]any requestedModel string expected bool @@ -184,6 +185,17 @@ func TestAccountIsModelSupported(t *testing.T) { requestedModel: "claude-opus-4-5-thinking", expected: true, }, + { + name: "gemini customtools alias matches normalized mapping", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: true, + }, { name: "wildcard match not supported", credentials: map[string]any{ @@ -199,6 +211,7 @@ func TestAccountIsModelSupported(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { account := &Account{ + Platform: tt.platform, Credentials: tt.credentials, } result := account.IsModelSupported(tt.requestedModel) @@ -212,6 +225,7 @@ func TestAccountIsModelSupported(t *testing.T) { func TestAccountGetMappedModel(t *testing.T) { tests := []struct { name string + platform string credentials map[string]any requestedModel string expected string @@ -223,6 +237,13 @@ func TestAccountGetMappedModel(t *testing.T) { requestedModel: "claude-sonnet-4-5", expected: "claude-sonnet-4-5", }, + { + name: "no mapping preserves gemini customtools model", + platform: PlatformGemini, + credentials: nil, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-preview-customtools", + }, // 精确匹配 { @@ -250,6 +271,29 @@ func TestAccountGetMappedModel(t *testing.T) { }, // 无匹配返回原始模型 + { + name: "gemini customtools alias resolves through normalized mapping", + platform: PlatformGemini, + 
credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-preview", + }, + { + name: "gemini customtools exact mapping wins over normalized fallback", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + "gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-preview-customtools", + }, { name: "no match returns original", credentials: map[string]any{ @@ -265,6 +309,7 @@ func TestAccountGetMappedModel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { account := &Account{ + Platform: tt.platform, Credentials: tt.credentials, } result := account.GetMappedModel(tt.requestedModel) @@ -278,6 +323,7 @@ func TestAccountGetMappedModel(t *testing.T) { func TestAccountResolveMappedModel(t *testing.T) { tests := []struct { name string + platform string credentials map[string]any requestedModel string expectedModel string @@ -312,6 +358,31 @@ func TestAccountResolveMappedModel(t *testing.T) { expectedModel: "gpt-5.4", expectedMatch: true, }, + { + name: "gemini customtools alias reports normalized match", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + }, + }, + requestedModel: "gemini-3.1-pro-preview-customtools", + expectedModel: "gemini-3.1-pro-preview", + expectedMatch: true, + }, + { + name: "gemini customtools exact mapping reports exact match", + platform: PlatformGemini, + credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3.1-pro-preview": "gemini-3.1-pro-preview", + "gemini-3.1-pro-preview-customtools": "gemini-3.1-pro-preview-customtools", + }, + }, + 
requestedModel: "gemini-3.1-pro-preview-customtools", + expectedModel: "gemini-3.1-pro-preview-customtools", + expectedMatch: true, + }, { name: "missing mapping reports unmatched", credentials: map[string]any{ @@ -328,6 +399,7 @@ func TestAccountResolveMappedModel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { account := &Account{ + Platform: tt.platform, Credentials: tt.credentials, } mappedModel, matched := account.ResolveMappedModel(tt.requestedModel) diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go index 1dbe9870..a29000e7 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -268,6 +268,12 @@ func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) { requestedModel: "gemini-2.5-flash", expected: "gemini-2.5-flash", }, + { + name: "customtools alias falls back to normalized preview mapping", + modelMapping: map[string]any{"gemini-3.1-pro-preview": "gemini-3.1-pro-high"}, + requestedModel: "gemini-3.1-pro-preview-customtools", + expected: "gemini-3.1-pro-high", + }, } for _, tt := range tests { diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index d0534d8c..21b4874e 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -85,7 +85,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact if v, ok := reqBody["model"].(string); ok { model = v } - normalizedModel := normalizeCodexModel(model) + normalizedModel := strings.TrimSpace(model) if normalizedModel != "" { if model != normalizedModel { reqBody["model"] = normalizedModel diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index eab88c09..889ac615 100644 --- 
a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -246,6 +246,7 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { "gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-spark": "gpt-5.3-codex", + "gpt 5.3 codex spark": "gpt-5.3-codex", "gpt-5.3-codex-spark-high": "gpt-5.3-codex", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", "gpt 5.3 codex": "gpt-5.3-codex", @@ -256,6 +257,34 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { } } +func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) { + reqBody := map[string]any{ + "model": "gpt-5.3-codex-spark", + "input": []any{}, + } + + result := applyCodexOAuthTransform(reqBody, false, false) + + require.Equal(t, "gpt-5.3-codex-spark", reqBody["model"]) + require.Equal(t, "gpt-5.3-codex-spark", result.NormalizedModel) + store, ok := reqBody["store"].(bool) + require.True(t, ok) + require.False(t, store) +} + +func TestApplyCodexOAuthTransform_TrimmedModelWithoutPolicyRewrite(t *testing.T) { + reqBody := map[string]any{ + "model": " gpt-5.3-codex-spark ", + "input": []any{}, + } + + result := applyCodexOAuthTransform(reqBody, false, false) + + require.Equal(t, "gpt-5.3-codex-spark", reqBody["model"]) + require.Equal(t, "gpt-5.3-codex-spark", result.NormalizedModel) + require.True(t, result.Modified) +} + func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) { // Codex CLI 场景:已有 instructions 时不修改 diff --git a/backend/internal/service/openai_compat_prompt_cache_key.go b/backend/internal/service/openai_compat_prompt_cache_key.go index 88e16a4d..46381838 100644 --- a/backend/internal/service/openai_compat_prompt_cache_key.go +++ b/backend/internal/service/openai_compat_prompt_cache_key.go @@ -10,8 +10,8 @@ import ( const compatPromptCacheKeyPrefix = "compat_cc_" func shouldAutoInjectPromptCacheKeyForCompat(model string) bool { - switch 
normalizeCodexModel(strings.TrimSpace(model)) { - case "gpt-5.4", "gpt-5.3-codex": + switch resolveOpenAIUpstreamModel(strings.TrimSpace(model)) { + case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark": return true default: return false @@ -23,9 +23,9 @@ func deriveCompatPromptCacheKey(req *apicompat.ChatCompletionsRequest, mappedMod return "" } - normalizedModel := normalizeCodexModel(strings.TrimSpace(mappedModel)) + normalizedModel := resolveOpenAIUpstreamModel(strings.TrimSpace(mappedModel)) if normalizedModel == "" { - normalizedModel = normalizeCodexModel(strings.TrimSpace(req.Model)) + normalizedModel = resolveOpenAIUpstreamModel(strings.TrimSpace(req.Model)) } if normalizedModel == "" { normalizedModel = strings.TrimSpace(req.Model) diff --git a/backend/internal/service/openai_compat_prompt_cache_key_test.go b/backend/internal/service/openai_compat_prompt_cache_key_test.go index eb9148de..6ca3e85c 100644 --- a/backend/internal/service/openai_compat_prompt_cache_key_test.go +++ b/backend/internal/service/openai_compat_prompt_cache_key_test.go @@ -17,6 +17,7 @@ func TestShouldAutoInjectPromptCacheKeyForCompat(t *testing.T) { require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.4")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3")) require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex")) + require.True(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-5.3-codex-spark")) require.False(t, shouldAutoInjectPromptCacheKeyForCompat("gpt-4o")) } @@ -62,3 +63,17 @@ func TestDeriveCompatPromptCacheKey_DiffersAcrossSessions(t *testing.T) { k2 := deriveCompatPromptCacheKey(req2, "gpt-5.4") require.NotEqual(t, k1, k2, "different first user messages should yield different keys") } + +func TestDeriveCompatPromptCacheKey_UsesResolvedSparkFamily(t *testing.T) { + req := &apicompat.ChatCompletionsRequest{ + Model: "gpt-5.3-codex-spark", + Messages: []apicompat.ChatMessage{ + {Role: "user", Content: mustRawJSON(t, `"Question A"`)}, + 
}, + } + + k1 := deriveCompatPromptCacheKey(req, "gpt-5.3-codex-spark") + k2 := deriveCompatPromptCacheKey(req, " openai/gpt-5.3-codex-spark ") + require.NotEmpty(t, k1) + require.Equal(t, k1, k2, "resolved spark family should derive a stable compat cache key") +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index a442da33..1d5bf0d0 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -45,12 +45,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( // 2. Resolve model mapping early so compat prompt_cache_key injection can // derive a stable seed from the final upstream model family. - mappedModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + billingModel := resolveOpenAIForwardModel(account, originalModel, defaultMappedModel) + upstreamModel := resolveOpenAIUpstreamModel(billingModel) promptCacheKey = strings.TrimSpace(promptCacheKey) compatPromptCacheInjected := false - if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(mappedModel) { - promptCacheKey = deriveCompatPromptCacheKey(&chatReq, mappedModel) + if promptCacheKey == "" && account.Type == AccountTypeOAuth && shouldAutoInjectPromptCacheKeyForCompat(upstreamModel) { + promptCacheKey = deriveCompatPromptCacheKey(&chatReq, upstreamModel) compatPromptCacheInjected = promptCacheKey != "" } @@ -60,12 +61,13 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( if err != nil { return nil, fmt.Errorf("convert chat completions to responses: %w", err) } - responsesReq.Model = mappedModel + responsesReq.Model = upstreamModel logFields := []zap.Field{ zap.Int64("account_id", account.ID), zap.String("original_model", originalModel), - zap.String("mapped_model", mappedModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", 
upstreamModel), zap.Bool("stream", clientStream), } if compatPromptCacheInjected { @@ -88,6 +90,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( return nil, fmt.Errorf("unmarshal for codex transform: %w", err) } codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.NormalizedModel != "" { + upstreamModel = codexResult.NormalizedModel + } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey } else if promptCacheKey != "" { @@ -180,9 +185,9 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( var result *OpenAIForwardResult var handleErr error if clientStream { - result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, mappedModel, includeUsage, startTime) + result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime) } else { - result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } // Propagate ServiceTier and ReasoningEffort to result for billing @@ -224,7 +229,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, startTime time.Time, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -295,8 +301,8 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( RequestID: requestID, Usage: usage, Model: originalModel, - BillingModel: mappedModel, - UpstreamModel: mappedModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, Stream: false, Duration: time.Since(startTime), }, nil @@ -308,7 +314,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + 
billingModel string, + upstreamModel string, includeUsage bool, startTime time.Time, ) (*OpenAIForwardResult, error) { @@ -343,8 +350,8 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( RequestID: requestID, Usage: usage, Model: originalModel, - BillingModel: mappedModel, - UpstreamModel: mappedModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, Stream: true, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 3df91b56..8c389556 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -41,6 +41,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } originalModel := anthropicReq.Model applyOpenAICompatModelNormalization(&anthropicReq) + normalizedModel := anthropicReq.Model clientStream := anthropicReq.Stream // client's original stream preference // 2. Convert Anthropic → Responses @@ -60,13 +61,16 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } // 3. 
Model mapping - mappedModel := resolveOpenAIForwardModel(account, anthropicReq.Model, defaultMappedModel) - responsesReq.Model = mappedModel + billingModel := resolveOpenAIForwardModel(account, normalizedModel, defaultMappedModel) + upstreamModel := resolveOpenAIUpstreamModel(billingModel) + responsesReq.Model = upstreamModel logger.L().Debug("openai messages: model mapping applied", zap.Int64("account_id", account.ID), zap.String("original_model", originalModel), - zap.String("mapped_model", mappedModel), + zap.String("normalized_model", normalizedModel), + zap.String("billing_model", billingModel), + zap.String("upstream_model", upstreamModel), zap.Bool("stream", isStream), ) @@ -82,6 +86,9 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( return nil, fmt.Errorf("unmarshal for codex transform: %w", err) } codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.NormalizedModel != "" { + upstreamModel = codexResult.NormalizedModel + } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey } else if promptCacheKey != "" { @@ -182,10 +189,10 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( var result *OpenAIForwardResult var handleErr error if clientStream { - result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleAnthropicStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } else { // Client wants JSON: buffer the streaming response and assemble a JSON reply. 
- result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleAnthropicBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } // Propagate ServiceTier and ReasoningEffort to result for billing @@ -230,7 +237,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, startTime time.Time, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -303,8 +311,8 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( RequestID: requestID, Usage: usage, Model: originalModel, - BillingModel: mappedModel, - UpstreamModel: mappedModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, Stream: false, Duration: time.Since(startTime), }, nil @@ -319,7 +327,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( resp *http.Response, c *gin.Context, originalModel string, - mappedModel string, + billingModel string, + upstreamModel string, startTime time.Time, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -352,8 +361,8 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( RequestID: requestID, Usage: usage, Model: originalModel, - BillingModel: mappedModel, - UpstreamModel: mappedModel, + BillingModel: billingModel, + UpstreamModel: upstreamModel, Stream: true, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index b9f42cd7..3355d3d5 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1818,29 +1818,29 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // 对所有请求执行模型映射(包含 
Codex CLI)。 - mappedModel := account.GetMappedModel(reqModel) - if mappedModel != reqModel { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) - reqBody["model"] = mappedModel + billingModel := account.GetMappedModel(reqModel) + if billingModel != reqModel { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, billingModel, account.Name, isCodexCLI) + reqBody["model"] = billingModel bodyModified = true - markPatchSet("model", mappedModel) + markPatchSet("model", billingModel) } + upstreamModel := billingModel // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 if model, ok := reqBody["model"].(string); ok { - normalizedModel := normalizeCodexModel(model) - if normalizedModel != "" && normalizedModel != model { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", - model, normalizedModel, account.Name, account.Type, isCodexCLI) - reqBody["model"] = normalizedModel - mappedModel = normalizedModel + upstreamModel = resolveOpenAIUpstreamModel(model) + if upstreamModel != "" && upstreamModel != model { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", + model, upstreamModel, account.Name, account.Type, isCodexCLI) + reqBody["model"] = upstreamModel bodyModified = true - markPatchSet("model", normalizedModel) + markPatchSet("model", upstreamModel) } // 移除 gpt-5.2-codex 以下的版本 verbosity 参数 // 确保高版本模型向低版本模型映射不报错 - if !SupportsVerbosity(normalizedModel) { + if !SupportsVerbosity(upstreamModel) { if text, ok := reqBody["text"].(map[string]any); ok { delete(text, "verbosity") } @@ -1864,7 +1864,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco disablePatch() } if codexResult.NormalizedModel 
!= "" { - mappedModel = codexResult.NormalizedModel + upstreamModel = codexResult.NormalizedModel } if codexResult.PromptCacheKey != "" { promptCacheKey = codexResult.PromptCacheKey @@ -1981,7 +1981,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco "forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v", account.ID, account.Type, - mappedModel, + upstreamModel, reqStream, hasPreviousResponseID, ) @@ -2070,7 +2070,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco isCodexCLI, reqStream, originalModel, - mappedModel, + upstreamModel, startTime, attempt, wsLastFailureReason, @@ -2171,7 +2171,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco firstTokenMs, wsAttempts, ) - wsResult.UpstreamModel = mappedModel + wsResult.UpstreamModel = upstreamModel return wsResult, nil } s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) @@ -2276,14 +2276,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco var usage *OpenAIUsage var firstTokenMs *int if reqStream { - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel) + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel) if err != nil { return nil, err } usage = streamResult.usage firstTokenMs = streamResult.firstTokenMs } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel) + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel) if err != nil { return nil, err } @@ -2307,7 +2307,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, - UpstreamModel: mappedModel, + UpstreamModel: upstreamModel, ServiceTier: serviceTier, ReasoningEffort: 
reasoningEffort, Stream: reqStream, diff --git a/backend/internal/service/openai_model_mapping.go b/backend/internal/service/openai_model_mapping.go index 9bf3fba3..4f8c094b 100644 --- a/backend/internal/service/openai_model_mapping.go +++ b/backend/internal/service/openai_model_mapping.go @@ -1,8 +1,10 @@ package service -// resolveOpenAIForwardModel determines the upstream model for OpenAI-compatible -// forwarding. Group-level default mapping only applies when the account itself -// did not match any explicit model_mapping rule. +import "strings" + +// resolveOpenAIForwardModel resolves the account/group mapping result for +// OpenAI-compatible forwarding. Group-level default mapping only applies when +// the account itself did not match any explicit model_mapping rule. func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedModel string) string { if account == nil { if defaultMappedModel != "" { @@ -17,3 +19,23 @@ func resolveOpenAIForwardModel(account *Account, requestedModel, defaultMappedMo } return mappedModel } + +func resolveOpenAIUpstreamModel(model string) string { + if isBareGPT53CodexSparkModel(model) { + return "gpt-5.3-codex-spark" + } + return normalizeCodexModel(strings.TrimSpace(model)) +} + +func isBareGPT53CodexSparkModel(model string) bool { + modelID := strings.TrimSpace(model) + if modelID == "" { + return false + } + if strings.Contains(modelID, "/") { + parts := strings.Split(modelID, "/") + modelID = parts[len(parts)-1] + } + normalized := strings.ToLower(strings.TrimSpace(modelID)) + return normalized == "gpt-5.3-codex-spark" || normalized == "gpt 5.3 codex spark" +} diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index edbb968b..42f58b37 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -74,13 +74,30 @@ func 
TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t * Credentials: map[string]any{}, } - withoutDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "") - if got := normalizeCodexModel(withoutDefault); got != "gpt-5.1" { - t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withoutDefault, got, "gpt-5.1") + withoutDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) + if withoutDefault != "gpt-5.1" { + t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withoutDefault, "gpt-5.1") } - withDefault := resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4") - if got := normalizeCodexModel(withDefault); got != "gpt-5.4" { - t.Fatalf("normalizeCodexModel(%q) = %q, want %q", withDefault, got, "gpt-5.4") + withDefault := resolveOpenAIUpstreamModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) + if withDefault != "gpt-5.4" { + t.Fatalf("resolveOpenAIUpstreamModel(...) = %q, want %q", withDefault, "gpt-5.4") + } +} + +func TestResolveOpenAIUpstreamModel(t *testing.T) { + cases := map[string]string{ + "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt 5.3 codex spark": "gpt-5.3-codex-spark", + " openai/gpt-5.3-codex-spark ": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3": "gpt-5.3-codex", + } + + for input, expected := range cases { + if got := resolveOpenAIUpstreamModel(input); got != expected { + t.Fatalf("resolveOpenAIUpstreamModel(%q) = %q, want %q", input, got, expected) + } } } diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 4f1837c4..1ebe5542 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -2515,12 +2515,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } normalized = next } - mappedModel := 
account.GetMappedModel(originalModel) - if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { - mappedModel = normalizedModel - } - if mappedModel != originalModel { - next, setErr := applyPayloadMutation(normalized, "model", mappedModel) + upstreamModel := resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) + if upstreamModel != originalModel { + next, setErr := applyPayloadMutation(normalized, "model", upstreamModel) if setErr != nil { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) } @@ -2776,10 +2773,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( mappedModel := "" var mappedModelBytes []byte if originalModel != "" { - mappedModel = account.GetMappedModel(originalModel) - if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { - mappedModel = normalizedModel - } + mappedModel = resolveOpenAIUpstreamModel(account.GetMappedModel(originalModel)) needModelReplace = mappedModel != "" && mappedModel != originalModel if needModelReplace { mappedModelBytes = []byte(mappedModel)