fix: normalize chat completions service tier
This commit is contained in:
@@ -107,11 +107,15 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
|||||||
responsesBody = stripped
|
responsesBody = stripped
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
responsesBody, normalizedServiceTier, err := normalizeResponsesBodyServiceTier(responsesBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("normalize service_tier in responses-shape body: %w", err)
|
||||||
|
}
|
||||||
// Minimal stub populated from the raw body so downstream billing
|
// Minimal stub populated from the raw body so downstream billing
|
||||||
// propagation (ServiceTier, ReasoningEffort) keeps working.
|
// propagation (ServiceTier, ReasoningEffort) keeps working.
|
||||||
responsesReq = &apicompat.ResponsesRequest{
|
responsesReq = &apicompat.ResponsesRequest{
|
||||||
Model: upstreamModel,
|
Model: upstreamModel,
|
||||||
ServiceTier: gjson.GetBytes(responsesBody, "service_tier").String(),
|
ServiceTier: normalizedServiceTier,
|
||||||
}
|
}
|
||||||
if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
|
if effort := gjson.GetBytes(responsesBody, "reasoning.effort").String(); effort != "" {
|
||||||
responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
|
responsesReq.Reasoning = &apicompat.ResponsesReasoning{Effort: effort}
|
||||||
@@ -124,6 +128,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
|||||||
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
return nil, fmt.Errorf("convert chat completions to responses: %w", err)
|
||||||
}
|
}
|
||||||
responsesReq.Model = upstreamModel
|
responsesReq.Model = upstreamModel
|
||||||
|
normalizeResponsesRequestServiceTier(responsesReq)
|
||||||
responsesBody, err = json.Marshal(responsesReq)
|
responsesBody, err = json.Marshal(responsesReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("marshal responses request: %w", err)
|
return nil, fmt.Errorf("marshal responses request: %w", err)
|
||||||
@@ -274,6 +279,41 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
|||||||
return result, handleErr
|
return result, handleErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeResponsesRequestServiceTier(req *apicompat.ResponsesRequest) {
|
||||||
|
if req == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.ServiceTier = normalizedOpenAIServiceTierValue(req.ServiceTier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeResponsesBodyServiceTier(body []byte) ([]byte, string, error) {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return body, "", nil
|
||||||
|
}
|
||||||
|
rawServiceTier := gjson.GetBytes(body, "service_tier").String()
|
||||||
|
if rawServiceTier == "" {
|
||||||
|
return body, "", nil
|
||||||
|
}
|
||||||
|
normalizedServiceTier := normalizedOpenAIServiceTierValue(rawServiceTier)
|
||||||
|
if normalizedServiceTier == "" {
|
||||||
|
trimmed, err := sjson.DeleteBytes(body, "service_tier")
|
||||||
|
return trimmed, "", err
|
||||||
|
}
|
||||||
|
if normalizedServiceTier == rawServiceTier {
|
||||||
|
return body, normalizedServiceTier, nil
|
||||||
|
}
|
||||||
|
trimmed, err := sjson.SetBytes(body, "service_tier", normalizedServiceTier)
|
||||||
|
return trimmed, normalizedServiceTier, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizedOpenAIServiceTierValue(raw string) string {
|
||||||
|
normalized := normalizeOpenAIServiceTier(raw)
|
||||||
|
if normalized == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *normalized
|
||||||
|
}
|
||||||
|
|
||||||
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
|
// handleChatCompletionsErrorResponse reads an upstream error and returns it in
|
||||||
// OpenAI Chat Completions error format.
|
// OpenAI Chat Completions error format.
|
||||||
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
|
func (s *OpenAIGatewayService) handleChatCompletionsErrorResponse(
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeResponsesRequestServiceTier(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
req := &apicompat.ResponsesRequest{ServiceTier: " fast "}
|
||||||
|
normalizeResponsesRequestServiceTier(req)
|
||||||
|
require.Equal(t, "priority", req.ServiceTier)
|
||||||
|
|
||||||
|
req.ServiceTier = "flex"
|
||||||
|
normalizeResponsesRequestServiceTier(req)
|
||||||
|
require.Equal(t, "flex", req.ServiceTier)
|
||||||
|
|
||||||
|
req.ServiceTier = "default"
|
||||||
|
normalizeResponsesRequestServiceTier(req)
|
||||||
|
require.Empty(t, req.ServiceTier)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesBodyServiceTier(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
body, tier, err := normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"fast"}`))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "priority", tier)
|
||||||
|
require.Equal(t, "priority", gjson.GetBytes(body, "service_tier").String())
|
||||||
|
|
||||||
|
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"flex"}`))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "flex", tier)
|
||||||
|
require.Equal(t, "flex", gjson.GetBytes(body, "service_tier").String())
|
||||||
|
|
||||||
|
body, tier, err = normalizeResponsesBodyServiceTier([]byte(`{"model":"gpt-5.1","service_tier":"default"}`))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, tier)
|
||||||
|
require.False(t, gjson.GetBytes(body, "service_tier").Exists())
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user