fix: preserve raw chat completions usage billing
This commit is contained in:
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -97,6 +98,13 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
|
||||
return nil, policyErr
|
||||
}
|
||||
upstreamBody = updatedBody
|
||||
if clientStream {
|
||||
var usageErr error
|
||||
upstreamBody, usageErr = ensureOpenAIChatStreamUsage(upstreamBody)
|
||||
if usageErr != nil {
|
||||
return nil, fmt.Errorf("enable stream usage: %w", usageErr)
|
||||
}
|
||||
}
|
||||
|
||||
logger.L().Debug("openai chat_completions raw: forwarding without protocol conversion",
|
||||
zap.Int64("account_id", account.ID),
|
||||
@@ -121,7 +129,9 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
|
||||
}
|
||||
targetURL := buildOpenAIChatCompletionsURL(validatedURL)
|
||||
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
|
||||
upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
|
||||
upstreamReq, err := http.NewRequestWithContext(upstreamCtx, http.MethodPost, targetURL, bytes.NewReader(upstreamBody))
|
||||
releaseUpstreamCtx()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build upstream request: %w", err)
|
||||
}
|
||||
@@ -219,8 +229,8 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
|
||||
// 末尾 [DONE] 之前的 chunk 中的 usage 字段,按 OpenAI CC 协议)。
|
||||
//
|
||||
// usage 字段仅在客户端请求 stream_options.include_usage=true 时出现于上游响应中。
|
||||
// 本函数不检查客户端的请求 flag——上游会自行处理,我们仅在上游响应
|
||||
// chunk 中出现 usage 时提取。
|
||||
// 网关会对上游强制打开 include_usage 以保证计费完整,并原样向下游透传 usage,
|
||||
// 让级联代理或下游计费系统也能拿到完整用量。
|
||||
func (s *OpenAIGatewayService) streamRawChatCompletions(
|
||||
c *gin.Context,
|
||||
resp *http.Response,
|
||||
@@ -251,36 +261,41 @@ func (s *OpenAIGatewayService) streamRawChatCompletions(
|
||||
|
||||
var usage OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
// Direct passthrough: write each line + blank line separator
|
||||
if _, werr := c.Writer.WriteString(line + "\n"); werr != nil {
|
||||
logger.L().Debug("openai chat_completions raw: client write failed",
|
||||
zap.Error(werr),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
break
|
||||
if payload, ok := extractOpenAISSEDataLine(line); ok {
|
||||
trimmedPayload := strings.TrimSpace(payload)
|
||||
if trimmedPayload != "[DONE]" {
|
||||
usageOnlyChunk := isOpenAIChatUsageOnlyStreamChunk(payload)
|
||||
if u := extractCCStreamUsage(payload); u != nil {
|
||||
usage = *u
|
||||
}
|
||||
if firstTokenMs == nil && !usageOnlyChunk {
|
||||
elapsed := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &elapsed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !clientDisconnected {
|
||||
if _, werr := c.Writer.WriteString(line + "\n"); werr != nil {
|
||||
clientDisconnected = true
|
||||
logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing",
|
||||
zap.Error(werr),
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
if line == "" {
|
||||
c.Writer.Flush()
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
continue
|
||||
}
|
||||
c.Writer.Flush()
|
||||
|
||||
// Track first token timing on first non-empty data line
|
||||
if firstTokenMs == nil && strings.HasPrefix(line, "data: ") && line != "data: [DONE]" {
|
||||
elapsed := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &elapsed
|
||||
}
|
||||
|
||||
// Extract usage from any chunk that carries it (CC streams typically put
|
||||
// usage in the final chunk before [DONE], but may also appear elsewhere).
|
||||
if strings.HasPrefix(line, "data: ") && line != "data: [DONE]" {
|
||||
payload := line[6:]
|
||||
if u := extractCCStreamUsage(payload); u != nil {
|
||||
usage = *u
|
||||
}
|
||||
if !clientDisconnected {
|
||||
c.Writer.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -307,6 +322,45 @@ func (s *OpenAIGatewayService) streamRawChatCompletions(
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ensureOpenAIChatStreamUsage 确保 raw Chat Completions 流式请求会让上游返回 usage。
|
||||
// usage 也会继续向下游透传,支持级联代理和下游计费系统。
|
||||
func ensureOpenAIChatStreamUsage(body []byte) ([]byte, error) {
|
||||
updated, err := sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func isOpenAIChatUsageOnlyStreamChunk(payload string) bool {
|
||||
if strings.TrimSpace(payload) == "" {
|
||||
return false
|
||||
}
|
||||
if !gjson.Get(payload, "usage").Exists() {
|
||||
return false
|
||||
}
|
||||
choices := gjson.Get(payload, "choices")
|
||||
return choices.Exists() && choices.IsArray() && len(choices.Array()) == 0
|
||||
}
|
||||
|
||||
// extractCCStreamUsage 从单个 CC 流式 chunk 的 payload 中提取 usage 字段。
|
||||
// CC 协议中 usage 仅出现在末尾 chunk(且仅当 include_usage 生效时),
|
||||
// 但上游可能在多个 chunk 中重复——总是用最新值。
|
||||
func extractCCStreamUsage(payload string) *OpenAIUsage {
|
||||
usageResult := gjson.Get(payload, "usage")
|
||||
if !usageResult.Exists() || !usageResult.IsObject() {
|
||||
return nil
|
||||
}
|
||||
u := OpenAIUsage{
|
||||
InputTokens: int(gjson.Get(payload, "usage.prompt_tokens").Int()),
|
||||
OutputTokens: int(gjson.Get(payload, "usage.completion_tokens").Int()),
|
||||
}
|
||||
if cached := gjson.Get(payload, "usage.prompt_tokens_details.cached_tokens"); cached.Exists() {
|
||||
u.CacheReadInputTokens = int(cached.Int())
|
||||
}
|
||||
return &u
|
||||
}
|
||||
|
||||
// bufferRawChatCompletions 透传上游 CC 非流式 JSON 响应。
|
||||
func (s *OpenAIGatewayService) bufferRawChatCompletions(
|
||||
c *gin.Context,
|
||||
@@ -320,9 +374,11 @@ func (s *OpenAIGatewayService) bufferRawChatCompletions(
|
||||
) (*OpenAIForwardResult, error) {
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 32<<20))
|
||||
respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
|
||||
if err != nil {
|
||||
writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response")
|
||||
if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
writeChatCompletionsError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response")
|
||||
}
|
||||
return nil, fmt.Errorf("read upstream body: %w", err)
|
||||
}
|
||||
|
||||
@@ -362,24 +418,6 @@ func (s *OpenAIGatewayService) bufferRawChatCompletions(
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractCCStreamUsage 从单个 CC 流式 chunk 的 payload 中提取 usage 字段。
|
||||
// CC 协议中 usage 仅出现在末尾 chunk(且仅当客户端请求 stream_options.include_usage
|
||||
// 时),但上游可能在多个 chunk 中重复——总是用最新值。
|
||||
func extractCCStreamUsage(payload string) *OpenAIUsage {
|
||||
usageResult := gjson.Get(payload, "usage")
|
||||
if !usageResult.Exists() || !usageResult.IsObject() {
|
||||
return nil
|
||||
}
|
||||
u := OpenAIUsage{
|
||||
InputTokens: int(gjson.Get(payload, "usage.prompt_tokens").Int()),
|
||||
OutputTokens: int(gjson.Get(payload, "usage.completion_tokens").Int()),
|
||||
}
|
||||
if cached := gjson.Get(payload, "usage.prompt_tokens_details.cached_tokens"); cached.Exists() {
|
||||
u.CacheReadInputTokens = int(cached.Int())
|
||||
}
|
||||
return &u
|
||||
}
|
||||
|
||||
// buildOpenAIChatCompletionsURL 拼接上游 Chat Completions 端点 URL。
|
||||
//
|
||||
// - base 已是 /chat/completions:原样返回
|
||||
|
||||
@@ -3,9 +3,19 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildOpenAIChatCompletionsURL(t *testing.T) {
|
||||
@@ -65,3 +75,186 @@ func TestBuildOpenAIResponsesURL_ProbeURL(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDownstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
|
||||
"",
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":4,"total_tokens":13,"prompt_tokens_details":{"cached_tokens":3}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_usage"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := rawChatCompletionsTestAccount()
|
||||
|
||||
result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 9, result.Usage.InputTokens)
|
||||
require.Equal(t, 4, result.Usage.OutputTokens)
|
||||
require.Equal(t, 3, result.Usage.CacheReadInputTokens)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
|
||||
require.Contains(t, rec.Body.String(), `"usage"`)
|
||||
require.Contains(t, rec.Body.String(), "data: [DONE]")
|
||||
}
|
||||
|
||||
func TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Writer = &openAIChatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
|
||||
"",
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":17,"completion_tokens":8,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":6}}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_disconnect"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := rawChatCompletionsTestAccount()
|
||||
|
||||
result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, 17, result.Usage.InputTokens)
|
||||
require.Equal(t, 8, result.Usage.OutputTokens)
|
||||
require.Equal(t, 6, result.Usage.CacheReadInputTokens)
|
||||
require.True(t, gjson.GetBytes(upstream.lastBody, "stream_options.include_usage").Bool())
|
||||
}
|
||||
|
||||
func TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
reqCtx, cancel := context.WithCancel(context.Background())
|
||||
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)).WithContext(reqCtx)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
cancel()
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","model":"gpt-5.4","choices":[],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_raw_ctx"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: rawChatCompletionsTestConfig(),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
account := rawChatCompletionsTestAccount()
|
||||
|
||||
result, err := svc.forwardAsRawChatCompletions(reqCtx, c, account, body, "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.NoError(t, upstream.lastReq.Context().Err())
|
||||
}
|
||||
|
||||
func TestIsOpenAIChatUsageOnlyStreamChunk(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.True(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
|
||||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[{"index":0}],"usage":{"prompt_tokens":1,"completion_tokens":2}}`))
|
||||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(`{"choices":[]}`))
|
||||
require.False(t, isOpenAIChatUsageOnlyStreamChunk(``))
|
||||
}
|
||||
|
||||
func TestEnsureOpenAIChatStreamUsage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body, err := ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4"}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool())
|
||||
|
||||
body, err = ensureOpenAIChatStreamUsage([]byte(`{"model":"gpt-5.4","stream_options":{"include_usage":false}}`))
|
||||
require.NoError(t, err)
|
||||
require.True(t, gjson.GetBytes(body, "stream_options.include_usage").Bool())
|
||||
}
|
||||
|
||||
func TestBufferRawChatCompletions_RejectsOversizedResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader("toolong")),
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: rawChatCompletionsTestConfig()}
|
||||
svc.cfg.Gateway.UpstreamResponseReadMaxBytes = 3
|
||||
|
||||
result, err := svc.bufferRawChatCompletions(c, resp, "gpt-5.4", "gpt-5.4", "gpt-5.4", nil, nil, time.Now())
|
||||
require.ErrorIs(t, err, ErrUpstreamResponseBodyTooLarge)
|
||||
require.Nil(t, result)
|
||||
require.Equal(t, http.StatusBadGateway, rec.Code)
|
||||
}
|
||||
|
||||
func rawChatCompletionsTestConfig() *config.Config {
|
||||
return &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
Enabled: false,
|
||||
AllowInsecureHTTP: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func rawChatCompletionsTestAccount() *Account {
|
||||
return &Account{
|
||||
ID: 101,
|
||||
Name: "raw-openai-apikey",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": "http://upstream.example",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user