diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 457309d3..a68c9b67 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -187,9 +187,13 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact } func normalizeCodexModel(model string) string { + model = strings.TrimSpace(model) if model == "" { return "gpt-5.4" } + if isOpenAIImageGenerationModel(model) { + return model + } modelID := model if strings.Contains(modelID, "/") { @@ -231,6 +235,78 @@ func normalizeCodexModel(model string) string { return "gpt-5.4" } +func hasOpenAIImageGenerationTool(reqBody map[string]any) bool { + rawTools, ok := reqBody["tools"] + if !ok || rawTools == nil { + return false + } + tools, ok := rawTools.([]any) + if !ok { + return false + } + for _, rawTool := range tools { + toolMap, ok := rawTool.(map[string]any) + if !ok { + continue + } + if strings.TrimSpace(firstNonEmptyString(toolMap["type"])) == "image_generation" { + return true + } + } + return false +} + +func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool { + rawTools, ok := reqBody["tools"] + if !ok || rawTools == nil { + return false + } + tools, ok := rawTools.([]any) + if !ok { + return false + } + + modified := false + for _, rawTool := range tools { + toolMap, ok := rawTool.(map[string]any) + if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" { + continue + } + if _, ok := toolMap["output_format"]; !ok { + if value := strings.TrimSpace(firstNonEmptyString(toolMap["format"])); value != "" { + toolMap["output_format"] = value + modified = true + } + } + if _, ok := toolMap["output_compression"]; !ok { + if value, exists := toolMap["compression"]; exists && value != nil { + toolMap["output_compression"] = value + modified = true + } + } + if _, ok := toolMap["format"]; ok { + delete(toolMap, "format") + modified = true + } + if _, ok := toolMap["compression"]; ok { + delete(toolMap, "compression") + modified = true + } + } + return modified +} + +func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error { + if !hasOpenAIImageGenerationTool(reqBody) { + return nil + } + model = strings.TrimSpace(model) + if !isOpenAIImageGenerationModel(model) { + return nil + } + return fmt.Errorf("/v1/responses image_generation requests require a Responses-capable text model; image-only model %q is not allowed", model) +} + func normalizeOpenAIModelForUpstream(account *Account, model string) string { if account == nil || account.Type == AccountTypeOAuth { return normalizeCodexModel(model) diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 22264f5e..f08e4b15 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -217,6 +217,42 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction require.Equal(t, "bash", first["name"]) } +func TestNormalizeOpenAIResponsesImageGenerationTools_RewritesLegacyFields(t *testing.T) { + reqBody := map[string]any{ + "tools": []any{ + map[string]any{ + "type": "image_generation", + "format": "png", + "compression": 60, + }, + }, + } + + modified := normalizeOpenAIResponsesImageGenerationTools(reqBody) + require.True(t, modified) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + first, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "png", first["output_format"]) + require.Equal(t, 60, first["output_compression"]) + _, hasFormat := first["format"] + require.False(t, hasFormat) + _, hasCompression := first["compression"] + require.False(t, hasCompression) +} + +func TestValidateOpenAIResponsesImageModel_RejectsImageOnlyModel(t *testing.T) { + err := validateOpenAIResponsesImageModel(map[string]any{ + "tools": []any{ + map[string]any{"type": "image_generation"}, + }, + }, "gpt-image-2") + + require.ErrorContains(t, err, `/v1/responses image_generation requests require a Responses-capable text model`) +} + func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { // 空 input 应保持为空且不触发异常。 diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index a4a7ff1b..534ffeee 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1935,6 +1935,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco markPatchSet("instructions", "You are a helpful coding assistant.") } + if normalizeOpenAIResponsesImageGenerationTools(reqBody) { + bodyModified = true + disablePatch() + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload") + } + // 对所有请求执行模型映射(包含 Codex CLI)。 billingModel := account.GetMappedModel(reqModel) if billingModel != reqModel { @@ -1944,6 +1950,26 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco markPatchSet("model", billingModel) } upstreamModel := billingModel + if err := validateOpenAIResponsesImageModel(reqBody, upstreamModel); err != nil { + setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "") + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": err.Error(), + "param": "model", + }, + }) + return nil, err + } + if hasOpenAIImageGenerationTool(reqBody) { + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] /responses image_generation request inbound_model=%s mapped_model=%s account_type=%s", + reqModel, + upstreamModel, + account.Type, + ) + } // OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为 // 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名, diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index fb6bdc7f..f11a2278 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -47,6 +47,7 @@ const ( openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" openAIImageRequirementsDiff = "0fffff" + openAIImageLifecycleTimeout = 2 * time.Minute ) type OpenAIImagesCapability string @@ -148,6 +149,9 @@ func (s *OpenAIGatewayService) ParseOpenAIImagesRequest(c *gin.Context, body []b } applyOpenAIImagesDefaults(req) + if err := validateOpenAIImagesModel(req.Model); err != nil { + return nil, err + } req.SizeTier = normalizeOpenAIImageSizeTier(req.Size) req.RequiredCapability = classifyOpenAIImagesCapability(req) return req, nil @@ -295,6 +299,21 @@ func applyOpenAIImagesDefaults(req *OpenAIImagesRequest) { req.Model = "gpt-image-2" } +func isOpenAIImageGenerationModel(model string) bool { + return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "gpt-image-") +} + +func validateOpenAIImagesModel(model string) error { + model = strings.TrimSpace(model) + if isOpenAIImageGenerationModel(model) { + return nil + } + if model == "" { + return fmt.Errorf("images endpoint requires an image model") + } + return fmt.Errorf("images endpoint requires an image model, got %q", model) +} + func normalizeOpenAIImagesEndpointPath(path string) string { trimmed := strings.TrimSpace(path) switch { @@ -400,7 +419,21 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( if mapped := strings.TrimSpace(channelMappedModel); mapped != "" { requestModel = mapped } + if err := validateOpenAIImagesModel(requestModel); err != nil { + return nil, err + } upstreamModel := account.GetMappedModel(requestModel) + if err := validateOpenAIImagesModel(upstreamModel); err != nil { + return nil, err + } + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] Images request routing request_model=%s upstream_model=%s endpoint=%s account_type=%s", + strings.TrimSpace(parsed.Model), + upstreamModel, + parsed.Endpoint, + account.Type, + ) forwardBody, forwardContentType, err := rewriteOpenAIImagesModel(body, parsed.ContentType, upstreamModel) if err != nil { return nil, err @@ -759,6 +792,17 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( if mapped := strings.TrimSpace(channelMappedModel); mapped != "" { requestModel = mapped } + if err := validateOpenAIImagesModel(requestModel); err != nil { + return nil, err + } + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d", + requestModel, + parsed.Endpoint, + account.Type, + len(parsed.Uploads), + ) token, _, err := s.GetAccessToken(ctx, account) if err != nil { @@ -844,8 +888,18 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( return nil, err } pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil) + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] Image extraction stream conversation_id=%s total_assets=%d file_service_assets=%d direct_assets=%d", + conversationID, + len(pointerInfos), + countOpenAIFileServicePointerInfos(pointerInfos), + countOpenAIDirectImageAssets(pointerInfos), + ) + lifecycleCtx, releaseLifecycleCtx := detachOpenAIImageLifecycleContext(ctx, openAIImageLifecycleTimeout) + defer releaseLifecycleCtx() if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) { - polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID) + polledPointers, pollErr := pollOpenAIImageConversation(lifecycleCtx, client, headers, conversationID) if pollErr != nil { return nil, s.wrapOpenAIImageBackendError(ctx, c, account, pollErr) } @@ -853,10 +907,11 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( } pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos) if len(pointerInfos) == 0 { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Image extraction yielded no assets conversation_id=%s", conversationID) return nil, fmt.Errorf("openai image conversation returned no downloadable images") } - responseBody, imageCount, err := buildOpenAIImageResponse(ctx, client, headers, conversationID, pointerInfos) + responseBody, imageCount, err := buildOpenAIImageResponse(lifecycleCtx, client, headers, conversationID, pointerInfos) if err != nil { return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err) } @@ -1283,8 +1338,11 @@ func buildOpenAIImageConversationRequest(parsed *OpenAIImagesRequest, parentMess } type openAIImagePointerInfo struct { - Pointer string - Prompt string + Pointer string + DownloadURL string + B64JSON string + MimeType string + Prompt string } type openAIImageToolMessage struct { @@ -1336,10 +1394,6 @@ func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo { if len(body) == 0 { return nil } - matches := openAIImagePointerMatches(body) - if len(matches) == 0 { - return nil - } prompt := "" for _, path := range []string{ "message.metadata.dalle.prompt", @@ -1351,11 +1405,12 @@ func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo { break } } + matches := openAIImagePointerMatches(body) out := make([]openAIImagePointerInfo, 0, len(matches)) for _, pointer := range matches { out = append(out, openAIImagePointerInfo{Pointer: pointer, Prompt: prompt}) } - return out + return mergeOpenAIImagePointerInfos(out, collectOpenAIImageInlineAssets(body, prompt)) } func openAIImagePointerMatches(body []byte) []string { @@ -1394,27 +1449,72 @@ func mergeOpenAIImagePointerInfos(existing []openAIImagePointerInfo, next []open seen := make(map[string]openAIImagePointerInfo, len(existing)+len(next)) out := make([]openAIImagePointerInfo, 0, len(existing)+len(next)) for _, item := range existing { - seen[item.Pointer] = item + if key := item.identityKey(); key != "" { + seen[key] = item + } out = append(out, item) } for _, item := range next { - if existingItem, ok := seen[item.Pointer]; ok { - if existingItem.Prompt == "" && item.Prompt != "" { + key := item.identityKey() + if key == "" { + continue + } + if existingItem, ok := seen[key]; ok { + merged := mergeOpenAIImagePointerInfo(existingItem, item) + if merged != existingItem { for i := range out { - if out[i].Pointer == item.Pointer { - out[i].Prompt = item.Prompt + if out[i].identityKey() == key { + out[i] = merged break } } + seen[key] = merged } continue } - seen[item.Pointer] = item + seen[key] = item out = append(out, item) } return out } +func (i openAIImagePointerInfo) identityKey() string { + switch { + case strings.TrimSpace(i.Pointer) != "": + return "pointer:" + strings.TrimSpace(i.Pointer) + case strings.TrimSpace(i.DownloadURL) != "": + return "download:" + strings.TrimSpace(i.DownloadURL) + case strings.TrimSpace(i.B64JSON) != "": + b64 := strings.TrimSpace(i.B64JSON) + if len(b64) > 64 { + b64 = b64[:64] + } + return "b64:" + b64 + default: + return "" + } +} + +func mergeOpenAIImagePointerInfo(existing, next openAIImagePointerInfo) openAIImagePointerInfo { + merged := existing + if strings.TrimSpace(merged.Pointer) == "" { + merged.Pointer = next.Pointer + } + if strings.TrimSpace(merged.DownloadURL) == "" { + merged.DownloadURL = next.DownloadURL + } + if strings.TrimSpace(merged.B64JSON) == "" { + merged.B64JSON = next.B64JSON + } + if strings.TrimSpace(merged.MimeType) == "" { + merged.MimeType = next.MimeType + } + if strings.TrimSpace(merged.Prompt) == "" { + merged.Prompt = next.Prompt + } + return merged +} + func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool { for _, item := range items { if strings.HasPrefix(item.Pointer, "file-service://") { @@ -1424,6 +1524,26 @@ func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool { return false } +func countOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) int { + count := 0 + for _, item := range items { + if strings.HasPrefix(item.Pointer, "file-service://") { + count++ + } + } + return count +} + +func countOpenAIDirectImageAssets(items []openAIImagePointerInfo) int { + count := 0 + for _, item := range items { + if strings.TrimSpace(item.DownloadURL) != "" || strings.TrimSpace(item.B64JSON) != "" { + count++ + } + } + return count +} + func preferOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) []openAIImagePointerInfo { if !hasOpenAIFileServicePointerInfos(items) { return items @@ -1591,11 +1711,7 @@ func buildOpenAIImageResponse( } items := make([]responseItem, 0, len(pointers)) for _, pointer := range pointers { - downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer) - if err != nil { - return nil, 0, err - } - data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL) + data, err := resolveOpenAIImageBytes(ctx, client, headers, conversationID, pointer) if err != nil { return nil, 0, err } @@ -1615,6 +1731,136 @@ func buildOpenAIImageResponse( return body, len(items), nil } +func resolveOpenAIImageBytes( + ctx context.Context, + client *req.Client, + headers http.Header, + conversationID string, + pointer openAIImagePointerInfo, +) ([]byte, error) { + if normalized := normalizeOpenAIImageBase64(pointer.B64JSON); normalized != "" { + return base64.StdEncoding.DecodeString(normalized) + } + if downloadURL := strings.TrimSpace(pointer.DownloadURL); downloadURL != "" { + return downloadOpenAIImageBytes(ctx, client, headers, downloadURL) + } + if strings.TrimSpace(pointer.Pointer) == "" { + return nil, fmt.Errorf("image asset is missing pointer, url, and base64 data") + } + downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer) + if err != nil { + return nil, err + } + return downloadOpenAIImageBytes(ctx, client, headers, downloadURL) +} + +func normalizeOpenAIImageBase64(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + if strings.HasPrefix(strings.ToLower(raw), "data:") { + if idx := strings.Index(raw, ","); idx >= 0 && idx+1 < len(raw) { + raw = raw[idx+1:] + } + } + raw = strings.TrimSpace(raw) + raw = strings.TrimRight(raw, "=") + strings.Repeat("=", (4-len(raw)%4)%4) + if raw == "" { + return "" + } + if _, err := base64.StdEncoding.DecodeString(raw); err != nil { + return "" + } + return raw +} + +func collectOpenAIImageInlineAssets(body []byte, fallbackPrompt string) []openAIImagePointerInfo { + if len(body) == 0 || !gjson.ValidBytes(body) { + return nil + } + var decoded any + if err := json.Unmarshal(body, &decoded); err != nil { + return nil + } + var out []openAIImagePointerInfo + walkOpenAIImageInlineAssets(decoded, strings.TrimSpace(fallbackPrompt), &out) + return out +} + +func walkOpenAIImageInlineAssets(node any, prompt string, out *[]openAIImagePointerInfo) { + switch value := node.(type) { + case map[string]any: + localPrompt := prompt + for _, key := range []string{"revised_prompt", "image_gen_title", "prompt"} { + if v, ok := value[key].(string); ok && strings.TrimSpace(v) != "" { + localPrompt = strings.TrimSpace(v) + break + } + } + item := openAIImagePointerInfo{ + Prompt: localPrompt, + Pointer: firstNonEmptyString(value["asset_pointer"], value["pointer"]), + DownloadURL: firstNonEmptyString(value["download_url"], value["url"], value["image_url"]), + B64JSON: firstNonEmptyString(value["b64_json"], value["base64"], value["image_base64"]), + MimeType: firstNonEmptyString(value["mime_type"], value["mimeType"], value["content_type"]), + } + switch { + case strings.HasPrefix(strings.TrimSpace(item.Pointer), "file-service://"), + strings.HasPrefix(strings.TrimSpace(item.Pointer), "sediment://"), + isLikelyOpenAIImageDownloadURL(item.DownloadURL), + normalizeOpenAIImageBase64(item.B64JSON) != "": + *out = append(*out, item) + } + for _, child := range value { + walkOpenAIImageInlineAssets(child, localPrompt, out) + } + case []any: + for _, child := range value { + walkOpenAIImageInlineAssets(child, prompt, out) + } + } +} + +func firstNonEmptyString(values ...any) string { + for _, value := range values { + if s, ok := value.(string); ok && strings.TrimSpace(s) != "" { + return strings.TrimSpace(s) + } + } + return "" +} + +func isLikelyOpenAIImageDownloadURL(raw string) bool { + raw = strings.TrimSpace(raw) + if raw == "" { + return false + } + if strings.HasPrefix(strings.ToLower(raw), "data:image/") { + return true + } + if !strings.HasPrefix(strings.ToLower(raw), "http://") && !strings.HasPrefix(strings.ToLower(raw), "https://") { + return false + } + lower := strings.ToLower(raw) + return strings.Contains(lower, "/download") || + strings.Contains(lower, ".png") || + strings.Contains(lower, ".jpg") || + strings.Contains(lower, ".jpeg") || + strings.Contains(lower, ".webp") +} + +func detachOpenAIImageLifecycleContext(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + base := context.Background() + if ctx != nil { + base = context.WithoutCancel(ctx) + } + if timeout <= 0 { + return base, func() {} + } + return context.WithTimeout(base, timeout) +} + func fetchOpenAIImageDownloadURL( ctx context.Context, client *req.Client, diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 173d69ba..6aa1d5e5 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -2,6 +2,7 @@ package service import ( "bytes" + "context" "mime/multipart" "net/http" "net/http/httptest" @@ -103,3 +104,56 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_ExplicitSizeRequiresNative require.NotNil(t, parsed) require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) } + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_RejectsNonImageModel(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-5.4","prompt":"draw a cat"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.Nil(t, parsed) + require.ErrorContains(t, err, `images endpoint requires an image model, got "gpt-5.4"`) +} + +func TestCollectOpenAIImagePointers_RecognizesDirectAssets(t *testing.T) { + items := collectOpenAIImagePointers([]byte(`{ + "revised_prompt": "cat astronaut", + "parts": [ + {"b64_json":"QUJD"}, + {"download_url":"https://files.example.com/image.png?sig=1"}, + {"asset_pointer":"file-service://file_123"} + ] + }`)) + + require.Len(t, items, 3) + var sawBase64, sawURL, sawPointer bool + for _, item := range items { + if item.B64JSON == "QUJD" { + sawBase64 = true + require.Equal(t, "cat astronaut", item.Prompt) + } + if item.DownloadURL == "https://files.example.com/image.png?sig=1" { + sawURL = true + } + if item.Pointer == "file-service://file_123" { + sawPointer = true + } + } + require.True(t, sawBase64) + require.True(t, sawURL) + require.True(t, sawPointer) +} + +func TestResolveOpenAIImageBytes_PrefersInlineBase64(t *testing.T) { + data, err := resolveOpenAIImageBytes(context.Background(), nil, nil, "", openAIImagePointerInfo{ + B64JSON: "data:image/png;base64,QUJD", + }) + require.NoError(t, err) + require.Equal(t, []byte("ABC"), data) +} diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index 35e7c250..f25863a8 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -91,6 +91,7 @@ func TestNormalizeCodexModel(t *testing.T) { "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt-5.3": "gpt-5.3-codex", + "gpt-image-2": "gpt-image-2", } for input, expected := range cases { diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 2bf48702..106ec9f7 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -812,6 +812,16 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { return openAIGPT54FallbackPricing } + if isOpenAIImageGenerationModel(model) { + for _, candidate := range []string{"gpt-image-2", "gpt-image-1.5", "gpt-image-1"} { + if pricing, ok := s.pricingData[candidate]; ok { + logger.LegacyPrintf("service.pricing", "[Pricing] OpenAI image fallback matched %s -> %s", model, candidate) + return pricing + } + } + return nil + } + // 最终回退到 DefaultTestModel defaultModel := strings.ToLower(openai.DefaultTestModel) if pricing, ok := s.pricingData[defaultModel]; ok { diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go index 13a5c70c..e2bd7cf3 100644 --- a/backend/internal/service/pricing_service_test.go +++ b/backend/internal/service/pricing_service_test.go @@ -128,6 +128,21 @@ func TestGetModelPricing_Gpt54NanoUsesDedicatedStaticFallbackWhenRemoteMissing(t require.Zero(t, got.LongContextInputTokenThreshold) } +func TestGetModelPricing_ImageModelDoesNotFallbackToTextModel(t *testing.T) { + imagePricing := &LiteLLMModelPricing{InputCostPerToken: 3} + textPricing := &LiteLLMModelPricing{InputCostPerToken: 9} + + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-image-2": imagePricing, + "gpt-5.4": textPricing, + }, + } + + got := svc.GetModelPricing("gpt-image-3") + require.Same(t, imagePricing, got) +} + func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) { raw := map[string]any{ "gpt-5.4": map[string]any{