fix: normalize chat completions service tier

This commit is contained in:
IanShaw027
2026-04-21 13:56:02 +08:00
parent 0fcddce69e
commit 62ff2d803f
2 changed files with 85 additions and 1 deletions

View File

@@ -107,11 +107,15 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
responsesBody = stripped
}
}
responsesBody, normalizedServiceTier, err := normalizeResponsesBodyServiceTier(responsesBody)
if err != nil {
return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err)
}
// Minimal stub populated from the raw body so downstream billing
// propagation (ServiceTier, ReasoningEffort) keeps working.
responsesReq = &apicompat.ResponsesRequest{
Model: upstreamModel,
ServiceTier: gjson.GetBytes(responsesBody, "service_tier").String(),
ServiceTier: normalizedServiceTier,
}
if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
@@ -124,6 +128,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
}
responsesReq.Model = upstreamModel
normalizeResponsesRequestServiceTier(responsesReq)
responsesBody, err = json.Marshal(responsesReq)
if err != nil {
return nil, fmt.Errorf("marshal responses request: %w", err)
@@ -274,6 +279,41 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
return result, handleErr
}
// normalizeResponsesRequestServiceTier rewrites req.ServiceTier in place to
// its canonical OpenAI value via normalizedOpenAIServiceTierValue.
// A nil request is left untouched.
func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) {
	if req != nil {
		req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier)
	}
}
// normalizeResponsesBodyServiceTier canonicalizes the "service_tier" field of
// a Responses-shape JSON body. It returns the (possibly rewritten) body, the
// normalized tier value, and any JSON-manipulation error.
//
// An empty body, or a body without a non-empty "service_tier", passes through
// unchanged. A tier that normalizes to the empty string is deleted from the
// body; a tier already in canonical form leaves the body untouched.
func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) {
	if len(body) == 0 {
		return body, "", nil
	}
	raw := gjson.GetBytes(body, "service_tier").String()
	if raw == "" {
		return body, "", nil
	}
	switch normalized := normalizedOpenAIServiceTierValue(raw); normalized {
	case "":
		// Unknown/opted-out tier: drop the field rather than forward it.
		updated, err := sjson.DeleteBytes(body, "service_tier")
		return updated, "", err
	case raw:
		// Already canonical — avoid an unnecessary rewrite of the body.
		return body, normalized, nil
	default:
		updated, err := sjson.SetBytes(body, "service_tier", normalized)
		return updated, normalized, err
	}
}
// normalizedOpenAIServiceTierValue returns the canonical service-tier string
// for raw, or "" when normalizeOpenAIServiceTier yields no value.
func normalizedOpenAIServiceTierValue(raw string) string {
	if normalized := normalizeOpenAIServiceTier(raw); normalized != nil {
		return *normalized
	}
	return ""
}
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
// OpenAI Chat Completions error format.
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(

View File

@@ -0,0 +1,44 @@
package service
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
// TestNormalizeResponsesRequestServiceTier verifies that ServiceTier on a
// ResponsesRequest is rewritten in place to its canonical value: whitespace
// and aliases are normalized (" fast " -> "priority"), canonical values are
// preserved ("flex"), and the default tier is cleared ("default" -> "").
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
	t.Parallel()

	cases := []struct {
		in   string
		want string
	}{
		{in: " fast ", want: "priority"},
		{in: "flex", want: "flex"},
		{in: "default", want: ""},
	}
	req := &apicompat.ResponsesRequest{}
	for _, tc := range cases {
		req.ServiceTier = tc.in
		normalizeResponsesRequestServiceTier(req)
		require.Equal(t, tc.want, req.ServiceTier)
	}
}
// TestNormalizeResponsesBodyServiceTier verifies that the "service_tier"
// field of a JSON body is canonicalized ("fast" -> "priority"), left alone
// when already canonical ("flex"), and removed entirely when it normalizes
// to the empty string ("default").
func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
	t.Parallel()

	cases := []struct {
		body string
		want string
	}{
		{body: `{"model":"gpt-5.1","service_tier":"fast"}`, want: "priority"},
		{body: `{"model":"gpt-5.1","service_tier":"flex"}`, want: "flex"},
		{body: `{"model":"gpt-5.1","service_tier":"default"}`, want: ""},
	}
	for _, tc := range cases {
		normalized, tier, err := normalizeResponsesBodyServiceTier([]byte(tc.body))
		require.NoError(t, err)
		require.Equal(t, tc.want, tier)
		field := gjson.GetBytes(normalized, "service_tier")
		if tc.want == "" {
			require.False(t, field.Exists())
		} else {
			require.Equal(t, tc.want, field.String())
		}
	}
}