fix: normalize chat completions service tier
This commit is contained in:
@@ -107,11 +107,15 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
responsesBody = stripped
|
||||
}
|
||||
}
|
||||
responsesBody, normalizedServiceTier, err := normalizeResponsesBodyServiceTier(responsesBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err)
|
||||
}
|
||||
// Minimal stub populated from the raw body so downstream billing
|
||||
// propagation (ServiceTier, ReasoningEffort) keeps working.
|
||||
responsesReq = &apicompat.ResponsesRequest{
|
||||
Model: upstreamModel,
|
||||
ServiceTier: gjson.GetBytes(responsesBody, "service_tier").String(),
|
||||
ServiceTier: normalizedServiceTier,
|
||||
}
|
||||
if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
|
||||
responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
|
||||
@@ -124,6 +128,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
||||
}
|
||||
responsesReq.Model = upstreamModel
|
||||
normalizeResponsesRequestServiceTier(responsesReq)
|
||||
responsesBody, err = json.Marshal(responsesReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal responses request: %w", err)
|
||||
@@ -274,6 +279,41 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
return result, handleErr
|
||||
}
|
||||
|
||||
func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier)
|
||||
}
|
||||
|
||||
func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) {
|
||||
if len(body) == 0 {
|
||||
return body, "", nil
|
||||
}
|
||||
rawServiceTier := gjson.GetBytes(body, "service_tier").String()
|
||||
if rawServiceTier == "" {
|
||||
return body, "", nil
|
||||
}
|
||||
normalizedServiceTier := normalizedOpenAIServiceTierValue(rawServiceTier)
|
||||
if normalizedServiceTier == "" {
|
||||
trimmed, err := sjson.DeleteBytes(body, "service_tier")
|
||||
return trimmed, "", err
|
||||
}
|
||||
if normalizedServiceTier == rawServiceTier {
|
||||
return body, normalizedServiceTier, nil
|
||||
}
|
||||
trimmed, err := sjson.SetBytes(body, "service_tier", normalizedServiceTier)
|
||||
return trimmed, normalizedServiceTier, err
|
||||
}
|
||||
|
||||
func normalizedOpenAIServiceTierValue(raw string) string {
|
||||
normalized := normalizeOpenAIServiceTier(raw)
|
||||
if normalized == nil {
|
||||
return ""
|
||||
}
|
||||
return *normalized
|
||||
}
|
||||
|
||||
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
|
||||
// OpenAI Chat Completions error format.
|
||||
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := &apicompat.ResponsesRequest{ServiceTier: " fast "}
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "priority", req.ServiceTier)
|
||||
|
||||
req.ServiceTier = "flex"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Equal(t, "flex", req.ServiceTier)
|
||||
|
||||
req.ServiceTier = "default"
|
||||
normalizeResponsesRequestServiceTier(req)
|
||||
require.Empty(t, req.ServiceTier)
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body, tier, err := normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"fast"}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "priority", tier)
|
||||
require.Equal(t, "priority", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"flex"}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "flex", tier)
|
||||
require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
|
||||
|
||||
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, tier)
|
||||
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
|
||||
}
|
||||
Reference in New Issue
Block a user