diff --git a/backend/internal/service/gateway_forward_as_chat_completions.go b/backend/internal/service/gateway_forward_as_chat_completions.go index 6913447b..d3c611e2 100644 --- a/backend/internal/service/gateway_forward_as_chat_completions.go +++ b/backend/internal/service/gateway_forward_as_chat_completions.go @@ -17,6 +17,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" "go.uber.org/zap" ) @@ -171,19 +172,40 @@ func (s *GatewayService) ForwardAsChatCompletions( return nil, fmt.Errorf("upstream error: %d %s", resp.StatusCode, upstreamMsg) } - // 13. Handle normal response + // 13. Extract reasoning effort from CC request body + reasoningEffort := extractCCReasoningEffortFromBody(body) + + // 14. Handle normal response // Read Anthropic SSE → convert to Responses events → convert to CC format var result *ForwardResult var handleErr error if clientStream { - result, handleErr = s.handleCCStreamingFromAnthropic(resp, c, originalModel, mappedModel, startTime, includeUsage) + result, handleErr = s.handleCCStreamingFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime, includeUsage) } else { - result, handleErr = s.handleCCBufferedFromAnthropic(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleCCBufferedFromAnthropic(resp, c, originalModel, mappedModel, reasoningEffort, startTime) } return result, handleErr } +// extractCCReasoningEffortFromBody reads reasoning effort from a Chat Completions +// request body. It checks both nested (reasoning.effort) and flat (reasoning_effort) +// formats used by OpenAI-compatible clients. +func extractCCReasoningEffortFromBody(body []byte) *string { + raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) + if raw == "" { + raw = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String()) + } + if raw == "" { + return nil + } + normalized := normalizeOpenAIReasoningEffort(raw) + if normalized == "" { + return nil + } + return &normalized +} + // handleCCBufferedFromAnthropic reads Anthropic SSE events, assembles the full // response, then converts Anthropic → Responses → Chat Completions. func (s *GatewayService) handleCCBufferedFromAnthropic( @@ -191,6 +213,7 @@ func (s *GatewayService) handleCCBufferedFromAnthropic( c *gin.Context, originalModel string, mappedModel string, + reasoningEffort *string, startTime time.Time, ) (*ForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -225,18 +248,16 @@ func (s *GatewayService) handleCCBufferedFromAnthropic( continue } + // message_start carries the initial response structure and cache usage if event.Type == "message_start" && event.Message != nil { finalResp = event.Message + mergeAnthropicUsage(&usage, event.Message.Usage) } + + // message_delta carries final usage and stop_reason if event.Type == "message_delta" { if event.Usage != nil { - usage = ClaudeUsage{ - InputTokens: event.Usage.InputTokens, - OutputTokens: event.Usage.OutputTokens, - } - if event.Usage.CacheReadInputTokens > 0 { - usage.CacheReadInputTokens = event.Usage.CacheReadInputTokens - } + mergeAnthropicUsage(&usage, *event.Usage) } if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil { finalResp.StopReason = event.Delta.StopReason @@ -274,10 +295,13 @@ func (s *GatewayService) handleCCBufferedFromAnthropic( return nil, fmt.Errorf("upstream stream ended without response") } + // Update usage from accumulated delta if usage.InputTokens > 0 || usage.OutputTokens > 0 { finalResp.Usage = apicompat.AnthropicUsage{ - InputTokens: usage.InputTokens, - OutputTokens: usage.OutputTokens, + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheCreationInputTokens: usage.CacheCreationInputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, } } @@ -291,12 +315,13 @@ func (s *GatewayService) handleCCBufferedFromAnthropic( c.JSON(http.StatusOK, ccResp) return &ForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - UpstreamModel: mappedModel, - Stream: false, - Duration: time.Since(startTime), + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ReasoningEffort: reasoningEffort, + Stream: false, + Duration: time.Since(startTime), }, nil } @@ -307,6 +332,7 @@ func (s *GatewayService) handleCCStreamingFromAnthropic( c *gin.Context, originalModel string, mappedModel string, + reasoningEffort *string, startTime time.Time, includeUsage bool, ) (*ForwardResult, error) { @@ -341,13 +367,14 @@ func (s *GatewayService) handleCCStreamingFromAnthropic( resultWithUsage := func() *ForwardResult { return &ForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - UpstreamModel: mappedModel, - Stream: true, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ReasoningEffort: reasoningEffort, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, } } @@ -369,18 +396,13 @@ func (s *GatewayService) handleCCStreamingFromAnthropic( firstTokenMs = &ms } - // Extract usage + // Extract usage from message_delta if event.Type == "message_delta" && event.Usage != nil { - usage = ClaudeUsage{ - InputTokens: event.Usage.InputTokens, - OutputTokens: event.Usage.OutputTokens, - } - if event.Usage.CacheReadInputTokens > 0 { - usage.CacheReadInputTokens = event.Usage.CacheReadInputTokens - } + mergeAnthropicUsage(&usage, *event.Usage) } - if event.Type == "message_start" && event.Message != nil && event.Message.Usage.InputTokens > 0 { - usage.InputTokens = event.Message.Usage.InputTokens + // Also capture usage from message_start (carries cache fields) + if event.Type == "message_start" && event.Message != nil { + mergeAnthropicUsage(&usage, event.Message.Usage) } // Chain: Anthropic event → Responses events → CC chunks diff --git a/backend/internal/service/gateway_forward_as_chat_completions_test.go b/backend/internal/service/gateway_forward_as_chat_completions_test.go new file mode 100644 index 00000000..5003e5b3 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_chat_completions_test.go @@ -0,0 +1,109 @@ +//go:build unit + +package service + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractCCReasoningEffortFromBody(t *testing.T) { + t.Parallel() + + t.Run("nested reasoning.effort", func(t *testing.T) { + got := extractCCReasoningEffortFromBody([]byte(`{"reasoning":{"effort":"HIGH"}}`)) + require.NotNil(t, got) + require.Equal(t, "high", *got) + }) + + t.Run("flat reasoning_effort", func(t *testing.T) { + got := extractCCReasoningEffortFromBody([]byte(`{"reasoning_effort":"x-high"}`)) + require.NotNil(t, got) + require.Equal(t, "xhigh", *got) + }) + + t.Run("missing effort", func(t *testing.T) { + require.Nil(t, extractCCReasoningEffortFromBody([]byte(`{"model":"gpt-5"}`))) + }) +} + +func TestHandleCCBufferedFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + reasoningEffort := "high" + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_cc_buffered"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleCCBufferedFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 9, result.Usage.CacheReadInputTokens) + require.Equal(t, 3, result.Usage.CacheCreationInputTokens) + require.NotNil(t, result.ReasoningEffort) + require.Equal(t, "high", *result.ReasoningEffort) +} + +func TestHandleCCStreamingFromAnthropic_PreservesMessageStartCacheUsageAndReasoning(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + reasoningEffort := "medium" + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_cc_stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`, + ``, + `event: message_stop`, + `data: {"type":"message_stop"}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleCCStreamingFromAnthropic(resp, c, "gpt-5", "claude-sonnet-4.5", &reasoningEffort, time.Now(), true) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 20, result.Usage.InputTokens) + require.Equal(t, 8, result.Usage.OutputTokens) + require.Equal(t, 11, result.Usage.CacheReadInputTokens) + require.Equal(t, 4, result.Usage.CacheCreationInputTokens) + require.NotNil(t, result.ReasoningEffort) + require.Equal(t, "medium", *result.ReasoningEffort) + require.Contains(t, rec.Body.String(), `[DONE]`) +} diff --git a/backend/internal/service/gateway_forward_as_responses.go b/backend/internal/service/gateway_forward_as_responses.go index cf562f39..5dca57f9 100644 --- a/backend/internal/service/gateway_forward_as_responses.go +++ b/backend/internal/service/gateway_forward_as_responses.go @@ -17,6 +17,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" "go.uber.org/zap" ) @@ -56,6 +57,7 @@ func (s *GatewayService) ForwardAsResponses( // 4. Model mapping mappedModel := originalModel + reasoningEffort := ExtractResponsesReasoningEffortFromBody(body) if account.Type == AccountTypeAPIKey { mappedModel = account.GetMappedModel(originalModel) } @@ -172,14 +174,46 @@ func (s *GatewayService) ForwardAsResponses( var result *ForwardResult var handleErr error if clientStream { - result, handleErr = s.handleResponsesStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleResponsesStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime) } else { - result, handleErr = s.handleResponsesBufferedStreamingResponse(resp, c, originalModel, mappedModel, startTime) + result, handleErr = s.handleResponsesBufferedStreamingResponse(resp, c, originalModel, mappedModel, reasoningEffort, startTime) } return result, handleErr } +// ExtractResponsesReasoningEffortFromBody reads Responses API reasoning.effort +// and normalizes it for usage logging. +func ExtractResponsesReasoningEffortFromBody(body []byte) *string { + raw := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) + if raw == "" { + return nil + } + normalized := normalizeOpenAIReasoningEffort(raw) + if normalized == "" { + return nil + } + return &normalized +} + +func mergeAnthropicUsage(dst *ClaudeUsage, src apicompat.AnthropicUsage) { + if dst == nil { + return + } + if src.InputTokens > 0 { + dst.InputTokens = src.InputTokens + } + if src.OutputTokens > 0 { + dst.OutputTokens = src.OutputTokens + } + if src.CacheReadInputTokens > 0 { + dst.CacheReadInputTokens = src.CacheReadInputTokens + } + if src.CacheCreationInputTokens > 0 { + dst.CacheCreationInputTokens = src.CacheCreationInputTokens + } +} + // handleResponsesBufferedStreamingResponse reads all Anthropic SSE events from // the upstream streaming response, assembles them into a complete Anthropic // response, converts to Responses API JSON format, and writes it to the client. @@ -188,6 +222,7 @@ func (s *GatewayService) handleResponsesBufferedStreamingResponse( c *gin.Context, originalModel string, mappedModel string, + reasoningEffort *string, startTime time.Time, ) (*ForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -233,21 +268,13 @@ func (s *GatewayService) handleResponsesBufferedStreamingResponse( // message_start carries the initial response structure if event.Type == "message_start" && event.Message != nil { finalResp = event.Message + mergeAnthropicUsage(&usage, event.Message.Usage) } // message_delta carries final usage and stop_reason if event.Type == "message_delta" { if event.Usage != nil { - usage = ClaudeUsage{ - InputTokens: event.Usage.InputTokens, - OutputTokens: event.Usage.OutputTokens, - } - if event.Usage.CacheReadInputTokens > 0 { - usage.CacheReadInputTokens = event.Usage.CacheReadInputTokens - } - if event.Usage.CacheCreationInputTokens > 0 { - usage.CacheCreationInputTokens = event.Usage.CacheCreationInputTokens - } + mergeAnthropicUsage(&usage, *event.Usage) } if event.Delta != nil && event.Delta.StopReason != "" && finalResp != nil { finalResp.StopReason = event.Delta.StopReason @@ -307,12 +334,13 @@ func (s *GatewayService) handleResponsesBufferedStreamingResponse( c.JSON(http.StatusOK, responsesResp) return &ForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - UpstreamModel: mappedModel, - Stream: false, - Duration: time.Since(startTime), + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ReasoningEffort: reasoningEffort, + Stream: false, + Duration: time.Since(startTime), }, nil } @@ -323,6 +351,7 @@ func (s *GatewayService) handleResponsesStreamingResponse( c *gin.Context, originalModel string, mappedModel string, + reasoningEffort *string, startTime time.Time, ) (*ForwardResult, error) { requestID := resp.Header.Get("x-request-id") @@ -351,13 +380,14 @@ func (s *GatewayService) handleResponsesStreamingResponse( resultWithUsage := func() *ForwardResult { return &ForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - UpstreamModel: mappedModel, - Stream: true, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Usage: usage, + Model: originalModel, + UpstreamModel: mappedModel, + ReasoningEffort: reasoningEffort, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, } } @@ -371,22 +401,11 @@ func (s *GatewayService) handleResponsesStreamingResponse( // Extract usage from message_delta if event.Type == "message_delta" && event.Usage != nil { - usage = ClaudeUsage{ - InputTokens: event.Usage.InputTokens, - OutputTokens: event.Usage.OutputTokens, - } - if event.Usage.CacheReadInputTokens > 0 { - usage.CacheReadInputTokens = event.Usage.CacheReadInputTokens - } - if event.Usage.CacheCreationInputTokens > 0 { - usage.CacheCreationInputTokens = event.Usage.CacheCreationInputTokens - } + mergeAnthropicUsage(&usage, *event.Usage) } // Also capture usage from message_start if event.Type == "message_start" && event.Message != nil { - if event.Message.Usage.InputTokens > 0 { - usage.InputTokens = event.Message.Usage.InputTokens - } + mergeAnthropicUsage(&usage, event.Message.Usage) } // Convert to Responses events diff --git a/backend/internal/service/gateway_forward_as_responses_test.go b/backend/internal/service/gateway_forward_as_responses_test.go new file mode 100644 index 00000000..e48d8b22 --- /dev/null +++ b/backend/internal/service/gateway_forward_as_responses_test.go @@ -0,0 +1,94 @@ +//go:build unit + +package service + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractResponsesReasoningEffortFromBody(t *testing.T) { + t.Parallel() + + got := ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5","reasoning":{"effort":"HIGH"}}`)) + require.NotNil(t, got) + require.Equal(t, "high", *got) + + require.Nil(t, ExtractResponsesReasoningEffortFromBody([]byte(`{"model":"claude-sonnet-4.5"}`))) +} + +func TestHandleResponsesBufferedStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_buffered"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":12,"cache_read_input_tokens":9,"cache_creation_input_tokens":3}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":7}}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleResponsesBufferedStreamingResponse(resp, c, "claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 9, result.Usage.CacheReadInputTokens) + require.Equal(t, 3, result.Usage.CacheCreationInputTokens) + require.Contains(t, rec.Body.String(), `"cached_tokens":9`) +} + +func TestHandleResponsesStreamingResponse_PreservesMessageStartCacheUsage(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + resp := &http.Response{ + Header: http.Header{"x-request-id": []string{"rid_stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `event: message_start`, + `data: {"type":"message_start","message":{"id":"msg_2","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4.5","stop_reason":"","usage":{"input_tokens":20,"cache_read_input_tokens":11,"cache_creation_input_tokens":4}}}`, + ``, + `event: content_block_start`, + `data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"hello"}}`, + ``, + `event: message_delta`, + `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":8}}`, + ``, + `event: message_stop`, + `data: {"type":"message_stop"}`, + ``, + }, "\n"))), + } + + svc := &GatewayService{} + result, err := svc.handleResponsesStreamingResponse(resp, c, "claude-sonnet-4.5", "claude-sonnet-4.5", nil, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 20, result.Usage.InputTokens) + require.Equal(t, 8, result.Usage.OutputTokens) + require.Equal(t, 11, result.Usage.CacheReadInputTokens) + require.Equal(t, 4, result.Usage.CacheCreationInputTokens) + require.Contains(t, rec.Body.String(), `response.completed`) +}