Merge pull request #316 from cyhhao/fix/claude-oauth-compat
fix(网关): 完善 Claude OAuth/Claude Code 兼容
This commit is contained in:
@@ -779,6 +779,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||||
|
SetClaudeCodeClientContext(c, body)
|
||||||
|
|
||||||
setOpsRequestContext(c, "", false, body)
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
parsedReq, err := service.ParseGatewayRequest(body)
|
parsedReq, err := service.ParseGatewayRequest(body)
|
||||||
|
|||||||
@@ -9,11 +9,26 @@ const (
|
|||||||
BetaClaudeCode = "claude-code-20250219"
|
BetaClaudeCode = "claude-code-20250219"
|
||||||
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
||||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||||
|
BetaTokenCounting = "token-counting-2024-11-01"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||||
|
|
||||||
|
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
|
||||||
|
//
|
||||||
|
// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic"
|
||||||
|
// Claude Code for non-Claude-Code clients, we must include the claude-code beta
|
||||||
|
// even if the request doesn't use tools, otherwise upstream may reject the
|
||||||
|
// request as a non-Claude-Code API request.
|
||||||
|
const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
|
||||||
|
|
||||||
|
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
|
||||||
|
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
|
||||||
|
|
||||||
|
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
|
||||||
|
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
|
||||||
|
|
||||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||||
|
|
||||||
@@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
|
|||||||
|
|
||||||
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
||||||
var DefaultHeaders = map[string]string{
|
var DefaultHeaders = map[string]string{
|
||||||
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
// Keep these in sync with recent Claude CLI traffic to reduce the chance
|
||||||
|
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
|
||||||
|
"User-Agent": "claude-cli/2.1.22 (external, cli)",
|
||||||
"X-Stainless-Lang": "js",
|
"X-Stainless-Lang": "js",
|
||||||
"X-Stainless-Package-Version": "0.52.0",
|
"X-Stainless-Package-Version": "0.70.0",
|
||||||
"X-Stainless-OS": "Linux",
|
"X-Stainless-OS": "Linux",
|
||||||
"X-Stainless-Arch": "x64",
|
"X-Stainless-Arch": "arm64",
|
||||||
"X-Stainless-Runtime": "node",
|
"X-Stainless-Runtime": "node",
|
||||||
"X-Stainless-Runtime-Version": "v22.14.0",
|
"X-Stainless-Runtime-Version": "v24.13.0",
|
||||||
"X-Stainless-Retry-Count": "0",
|
"X-Stainless-Retry-Count": "0",
|
||||||
"X-Stainless-Timeout": "60",
|
"X-Stainless-Timeout": "600",
|
||||||
"X-App": "cli",
|
"X-App": "cli",
|
||||||
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
||||||
}
|
}
|
||||||
@@ -79,3 +96,39 @@ func DefaultModelIDs() []string {
|
|||||||
|
|
||||||
// DefaultTestModel 测试时使用的默认模型
|
// DefaultTestModel 测试时使用的默认模型
|
||||||
const DefaultTestModel = "claude-sonnet-4-5-20250929"
|
const DefaultTestModel = "claude-sonnet-4-5-20250929"
|
||||||
|
|
||||||
|
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
|
||||||
|
var ModelIDOverrides = map[string]string{
|
||||||
|
"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
|
||||||
|
"claude-opus-4-5": "claude-opus-4-5-20251101",
|
||||||
|
"claude-haiku-4-5": "claude-haiku-4-5-20251001",
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
|
||||||
|
var ModelIDReverseOverrides = map[string]string{
|
||||||
|
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||||
|
"claude-opus-4-5-20251101": "claude-opus-4-5",
|
||||||
|
"claude-haiku-4-5-20251001": "claude-haiku-4-5",
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeModelID 根据 Claude OAuth 规则映射模型
|
||||||
|
func NormalizeModelID(id string) string {
|
||||||
|
if id == "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
if mapped, ok := ModelIDOverrides[id]; ok {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// DenormalizeModelID 将上游模型 ID 转换为短名
|
||||||
|
func DenormalizeModelID(id string) string {
|
||||||
|
if id == "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
if mapped, ok := ModelIDReverseOverrides[id]; ok {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|||||||
@@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetClaudeUserID() string {
|
||||||
|
if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) {
|
|||||||
"system": []map[string]any{
|
"system": []map[string]any{
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
"text": claudeCodeSystemPrompt,
|
||||||
"cache_control": map[string]string{
|
"cache_control": map[string]string{
|
||||||
"type": "ephemeral",
|
"type": "ephemeral",
|
||||||
},
|
},
|
||||||
|
|||||||
23
backend/internal/service/gateway_beta_test.go
Normal file
23
backend/internal/service/gateway_beta_test.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMergeAnthropicBeta(t *testing.T) {
|
||||||
|
got := mergeAnthropicBeta(
|
||||||
|
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
|
||||||
|
"foo, oauth-2025-04-20,bar, foo",
|
||||||
|
)
|
||||||
|
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
|
||||||
|
got := mergeAnthropicBeta(
|
||||||
|
[]string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"},
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
|
||||||
|
}
|
||||||
62
backend/internal/service/gateway_oauth_metadata_test.go
Normal file
62
backend/internal/service/gateway_oauth_metadata_test.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Model: "claude-sonnet-4-5",
|
||||||
|
Stream: true,
|
||||||
|
MetadataUserID: "",
|
||||||
|
System: nil,
|
||||||
|
Messages: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 123,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{}, // intentionally missing account_uuid / claude_user_id
|
||||||
|
}
|
||||||
|
|
||||||
|
fp := &Fingerprint{ClientID: "deadbeef"} // should be used as user id in legacy format
|
||||||
|
|
||||||
|
got := svc.buildOAuthMetadataUserID(parsed, account, fp)
|
||||||
|
require.NotEmpty(t, got)
|
||||||
|
|
||||||
|
// Legacy format: user_{client}_account__session_{uuid}
|
||||||
|
re := regexp.MustCompile(`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`)
|
||||||
|
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent(t *testing.T) {
|
||||||
|
svc := &GatewayService{}
|
||||||
|
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Model: "claude-sonnet-4-5",
|
||||||
|
Stream: true,
|
||||||
|
MetadataUserID: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 123,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"account_uuid": "acc-uuid",
|
||||||
|
"claude_user_id": "clientid123",
|
||||||
|
"anthropic_user_id": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got := svc.buildOAuthMetadataUserID(parsed, account, nil)
|
||||||
|
require.NotEmpty(t, got)
|
||||||
|
|
||||||
|
// New format: user_{client}_account_{account_uuid}_session_{uuid}
|
||||||
|
re := regexp.MustCompile(`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`)
|
||||||
|
require.True(t, re.MatchString(got), "unexpected user_id format: %s", got)
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestInjectClaudeCodePrompt(t *testing.T) {
|
func TestInjectClaudeCodePrompt(t *testing.T) {
|
||||||
|
claudePrefix := strings.TrimSpace(claudeCodeSystemPrompt)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
body string
|
body string
|
||||||
@@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
|
|||||||
system: "Custom prompt",
|
system: "Custom prompt",
|
||||||
wantSystemLen: 2,
|
wantSystemLen: 2,
|
||||||
wantFirstText: claudeCodeSystemPrompt,
|
wantFirstText: claudeCodeSystemPrompt,
|
||||||
wantSecondText: "Custom prompt",
|
wantSecondText: claudePrefix + "\n\nCustom prompt",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "string system equals Claude Code prompt",
|
name: "string system equals Claude Code prompt",
|
||||||
@@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
|
|||||||
// Claude Code + Custom = 2
|
// Claude Code + Custom = 2
|
||||||
wantSystemLen: 2,
|
wantSystemLen: 2,
|
||||||
wantFirstText: claudeCodeSystemPrompt,
|
wantFirstText: claudeCodeSystemPrompt,
|
||||||
wantSecondText: "Custom",
|
wantSecondText: claudePrefix + "\n\nCustom",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "array system with existing Claude Code prompt (should dedupe)",
|
name: "array system with existing Claude Code prompt (should dedupe)",
|
||||||
@@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) {
|
|||||||
// Claude Code at start + Other = 2 (deduped)
|
// Claude Code at start + Other = 2 (deduped)
|
||||||
wantSystemLen: 2,
|
wantSystemLen: 2,
|
||||||
wantFirstText: claudeCodeSystemPrompt,
|
wantFirstText: claudeCodeSystemPrompt,
|
||||||
wantSecondText: "Other",
|
wantSecondText: claudePrefix + "\n\nOther",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty array",
|
name: "empty array",
|
||||||
|
|||||||
21
backend/internal/service/gateway_sanitize_test.go
Normal file
21
backend/internal/service/gateway_sanitize_test.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) {
|
||||||
|
in := "You are OpenCode, the best coding agent on the planet."
|
||||||
|
got := sanitizeSystemText(in)
|
||||||
|
require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) {
|
||||||
|
in := "OpenCode and opencode are mentioned."
|
||||||
|
got := sanitizeToolDescription(in)
|
||||||
|
// We no longer rewrite tool descriptions; only redact obvious path leaks.
|
||||||
|
require.Equal(t, in, got)
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -26,13 +26,13 @@ var (
|
|||||||
|
|
||||||
// 默认指纹值(当客户端未提供时使用)
|
// 默认指纹值(当客户端未提供时使用)
|
||||||
var defaultFingerprint = Fingerprint{
|
var defaultFingerprint = Fingerprint{
|
||||||
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
UserAgent: "claude-cli/2.1.22 (external, cli)",
|
||||||
StainlessLang: "js",
|
StainlessLang: "js",
|
||||||
StainlessPackageVersion: "0.52.0",
|
StainlessPackageVersion: "0.70.0",
|
||||||
StainlessOS: "Linux",
|
StainlessOS: "Linux",
|
||||||
StainlessArch: "x64",
|
StainlessArch: "arm64",
|
||||||
StainlessRuntime: "node",
|
StainlessRuntime: "node",
|
||||||
StainlessRuntimeVersion: "v22.14.0",
|
StainlessRuntimeVersion: "v24.13.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fingerprint represents account fingerprint data
|
// Fingerprint represents account fingerprint data
|
||||||
@@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parseUserAgentVersion 解析user-agent版本号
|
// parseUserAgentVersion 解析user-agent版本号
|
||||||
// 例如:claude-cli/2.0.62 -> (2, 0, 62)
|
// 例如:claude-cli/2.1.2 -> (2, 1, 2)
|
||||||
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
|
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
|
||||||
// 匹配 xxx/x.y.z 格式
|
// 匹配 xxx/x.y.z 格式
|
||||||
matches := userAgentVersionRegex.FindStringSubmatch(ua)
|
matches := userAgentVersionRegex.FindStringSubmatch(ua)
|
||||||
|
|||||||
@@ -1260,15 +1260,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
|
// 记录上次收到上游数据的时间,用于控制 keepalive 发送频率
|
||||||
lastDataAt := time.Now()
|
lastDataAt := time.Now()
|
||||||
|
|
||||||
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
|
// 仅发送一次错误事件,避免多次写入导致协议混乱。
|
||||||
|
// 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema;
|
||||||
|
// 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。
|
||||||
errorEventSent := false
|
errorEventSent := false
|
||||||
|
clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage
|
||||||
sendErrorEvent := func(reason string) {
|
sendErrorEvent := func(reason string) {
|
||||||
if errorEventSent {
|
if errorEventSent || clientDisconnected {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
errorEventSent = true
|
errorEventSent = true
|
||||||
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
|
payload := map[string]any{
|
||||||
flusher.Flush()
|
"type": "error",
|
||||||
|
"sequence_number": 0,
|
||||||
|
"error": map[string]any{
|
||||||
|
"type": "upstream_error",
|
||||||
|
"message": reason,
|
||||||
|
"code": reason,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if b, err := json.Marshal(payload); err == nil {
|
||||||
|
_, _ = fmt.Fprintf(w, "data: %s\n\n", b)
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
needModelReplace := originalModel != mappedModel
|
needModelReplace := originalModel != mappedModel
|
||||||
@@ -1280,6 +1294,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||||
}
|
}
|
||||||
if ev.err != nil {
|
if ev.err != nil {
|
||||||
|
// 客户端断开/取消请求时,上游读取往往会返回 context canceled。
|
||||||
|
// /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。
|
||||||
|
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||||
|
log.Printf("Context canceled during streaming, returning collected usage")
|
||||||
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||||
|
}
|
||||||
|
// 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage
|
||||||
|
if clientDisconnected {
|
||||||
|
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
|
||||||
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||||
|
}
|
||||||
if errors.Is(ev.err, bufio.ErrTooLong) {
|
if errors.Is(ev.err, bufio.ErrTooLong) {
|
||||||
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
|
||||||
sendErrorEvent("response_too_large")
|
sendErrorEvent("response_too_large")
|
||||||
@@ -1303,15 +1328,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
|
|
||||||
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
|
||||||
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
|
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
|
||||||
|
data = correctedData
|
||||||
line = "data: " + correctedData
|
line = "data: " + correctedData
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward line
|
// 写入客户端(客户端断开后继续 drain 上游)
|
||||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
if !clientDisconnected {
|
||||||
sendErrorEvent("write_failed")
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
clientDisconnected = true
|
||||||
|
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||||
|
} else {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
|
||||||
|
|
||||||
// Record first token time
|
// Record first token time
|
||||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||||
@@ -1321,11 +1350,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
s.parseSSEUsage(data, usage)
|
s.parseSSEUsage(data, usage)
|
||||||
} else {
|
} else {
|
||||||
// Forward non-data lines as-is
|
// Forward non-data lines as-is
|
||||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
if !clientDisconnected {
|
||||||
sendErrorEvent("write_failed")
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
clientDisconnected = true
|
||||||
|
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||||
|
} else {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-intervalCh:
|
case <-intervalCh:
|
||||||
@@ -1333,6 +1365,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
if time.Since(lastRead) < streamInterval {
|
if time.Since(lastRead) < streamInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if clientDisconnected {
|
||||||
|
log.Printf("Upstream timeout after client disconnect, returning collected usage")
|
||||||
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||||
|
}
|
||||||
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
|
||||||
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
// 处理流超时,可能标记账户为临时不可调度或错误状态
|
||||||
if s.rateLimitService != nil {
|
if s.rateLimitService != nil {
|
||||||
@@ -1342,11 +1378,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||||
|
|
||||||
case <-keepaliveCh:
|
case <-keepaliveCh:
|
||||||
|
if clientDisconnected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if time.Since(lastDataAt) < keepaliveInterval {
|
if time.Since(lastDataAt) < keepaliveInterval {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
|
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
|
||||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
clientDisconnected = true
|
||||||
|
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,6 +59,25 @@ type stubConcurrencyCache struct {
|
|||||||
skipDefaultLoad bool
|
skipDefaultLoad bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type cancelReadCloser struct{}
|
||||||
|
|
||||||
|
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
|
||||||
|
func (c cancelReadCloser) Close() error { return nil }
|
||||||
|
|
||||||
|
type failingGinWriter struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
failAfter int
|
||||||
|
writes int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *failingGinWriter) Write(p []byte) (int, error) {
|
||||||
|
if w.writes >= w.failAfter {
|
||||||
|
return 0, errors.New("write failed")
|
||||||
|
}
|
||||||
|
w.writes++
|
||||||
|
return w.ResponseWriter.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
if c.acquireResults != nil {
|
if c.acquireResults != nil {
|
||||||
if result, ok := c.acquireResults[accountID]; ok {
|
if result, ok := c.acquireResults[accountID]; ok {
|
||||||
@@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
|
|||||||
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
|
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
|
||||||
t.Fatalf("expected stream timeout error, got %v", err)
|
t.Fatalf("expected stream timeout error, got %v", err)
|
||||||
}
|
}
|
||||||
if !strings.Contains(rec.Body.String(), "stream_timeout") {
|
if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "stream_timeout") {
|
||||||
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
|
t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 0,
|
||||||
|
StreamKeepaliveInterval: 0,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
|
||||||
|
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: cancelReadCloser{},
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error, got %v", err)
|
||||||
|
}
|
||||||
|
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
|
||||||
|
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{
|
||||||
|
StreamDataIntervalTimeout: 0,
|
||||||
|
StreamKeepaliveInterval: 0,
|
||||||
|
MaxLineSize: defaultMaxLineSize,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &OpenAIGatewayService{cfg: cfg}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0}
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: pr,
|
||||||
|
Header: http.Header{},
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() { _ = pw.Close() }()
|
||||||
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
|
||||||
|
_, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
|
||||||
|
}()
|
||||||
|
|
||||||
|
result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||||
|
_ = pr.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error, got %v", err)
|
||||||
|
}
|
||||||
|
if result == nil || result.usage == nil {
|
||||||
|
t.Fatalf("expected usage result")
|
||||||
|
}
|
||||||
|
if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 {
|
||||||
|
t.Fatalf("unexpected usage: %+v", *result.usage)
|
||||||
|
}
|
||||||
|
if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") {
|
||||||
|
t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) {
|
|||||||
if !errors.Is(err, bufio.ErrTooLong) {
|
if !errors.Is(err, bufio.ErrTooLong) {
|
||||||
t.Fatalf("expected ErrTooLong, got %v", err)
|
t.Fatalf("expected ErrTooLong, got %v", err)
|
||||||
}
|
}
|
||||||
if !strings.Contains(rec.Body.String(), "response_too_large") {
|
if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "response_too_large") {
|
||||||
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
|
t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user