diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go
index ac7d28a7..663066a3 100644
--- a/backend/internal/service/openai_gateway_chat_completions.go
+++ b/backend/internal/service/openai_gateway_chat_completions.go
@@ -107,11 +107,15 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
 			responsesBody = stripped
 		}
 	}
+	responsesBody, normalizedServiceTier, err := normalizeResponsesBodyServiceTier(responsesBody)
+	if err != nil {
+		return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err)
+	}
 	// Minimal stub populated from the raw body so downstream billing
 	// propagation (ServiceTier, ReasoningEffort) keeps working.
 	responsesReq = &apicompat.ResponsesRequest{
 		Model:       upstreamModel,
-		ServiceTier: gjson.GetBytes(responsesBody, "service_tier").String(),
+		ServiceTier: normalizedServiceTier,
 	}
 	if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
 		responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
@@ -124,6 +128,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
 		return nil, fmt.Errorf("convert chat completions to responses: %w", err)
 	}
 	responsesReq.Model = upstreamModel
+	normalizeResponsesRequestServiceTier(responsesReq)
 	responsesBody, err = json.Marshal(responsesReq)
 	if err != nil {
 		return nil, fmt.Errorf("marshal responses request: %w", err)
@@ -274,6 +279,41 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
 	return result, handleErr
 }
 
+func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) {
+	if req == nil {
+		return
+	}
+	req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier)
+}
+
+func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) {
+	if len(body) == 0 {
+		return body, "", nil
+	}
+	rawServiceTier := gjson.GetBytes(body, "service_tier").String()
+	if rawServiceTier == "" {
+		return body, "", nil
+	}
+	normalizedServiceTier := normalizedOpenAIServiceTierValue(rawServiceTier)
+	if normalizedServiceTier == "" {
+		trimmed, err := sjson.DeleteBytes(body, "service_tier")
+		return trimmed, "", err
+	}
+	if normalizedServiceTier == rawServiceTier {
+		return body, normalizedServiceTier, nil
+	}
+	trimmed, err := sjson.SetBytes(body, "service_tier", normalizedServiceTier)
+	return trimmed, normalizedServiceTier, err
+}
+
+func normalizedOpenAIServiceTierValue(raw string) string {
+	normalized := normalizeOpenAIServiceTier(raw)
+	if normalized == nil {
+		return ""
+	}
+	return *normalized
+}
+
 // handleChatCompletionsErrorResponse reads an upstream error and returns it in
 // OpenAI Chat Completions error format.
 func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go
new file mode 100644
index 00000000..a00fb71c
--- /dev/null
+++ b/backend/internal/service/openai_gateway_chat_completions_test.go
@@ -0,0 +1,44 @@
+package service
+
+import (
+	"testing"
+
+	"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
+	"github.com/stretchr/testify/require"
+	"github.com/tidwall/gjson"
+)
+
+func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
+	t.Parallel()
+
+	req := &apicompat.ResponsesRequest{ServiceTier: " fast "}
+	normalizeResponsesRequestServiceTier(req)
+	require.Equal(t, "priority", req.ServiceTier)
+
+	req.ServiceTier = "flex"
+	normalizeResponsesRequestServiceTier(req)
+	require.Equal(t, "flex", req.ServiceTier)
+
+	req.ServiceTier = "default"
+	normalizeResponsesRequestServiceTier(req)
+	require.Empty(t, req.ServiceTier)
+}
+
+func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
+	t.Parallel()
+
+	body, tier, err := normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"fast"}`))
+	require.NoError(t, err)
+	require.Equal(t, "priority", tier)
+	require.Equal(t, "priority", gjson.GetBytes(body, "service_tier").String())
+
+	body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"flex"}`))
+	require.NoError(t, err)
+	require.Equal(t, "flex", tier)
+	require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
+
+	body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
+	require.NoError(t, err)
+	require.Empty(t, tier)
+	require.False(t, gjson.GetBytes(body, "service_tier").Exists())
+}