diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 842242ca..f29da43f 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -779,6 +779,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } + // 检查是否为 Claude Code 客户端,设置到 context 中 + SetClaudeCodeClientContext(c, body) + setOpsRequestContext(c, "", false, body) parsedReq, err := service.ParseGatewayRequest(body) diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index d1a56a84..8b3441dc 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -9,11 +9,26 @@ const ( BetaClaudeCode = "claude-code-20250219" BetaInterleavedThinking = "interleaved-thinking-2025-05-14" BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14" + BetaTokenCounting = "token-counting-2024-11-01" ) // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming +// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header +// +// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic" +// Claude Code for non-Claude-Code clients, we must include the claude-code beta +// even if the request doesn't use tools, otherwise upstream may reject the +// request as a non-Claude-Code API request. +const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + +// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header +const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + +// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header +const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting + // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta) const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking @@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking // DefaultHeaders 是 Claude Code 客户端默认请求头。 var DefaultHeaders = map[string]string{ - "User-Agent": "claude-cli/2.0.62 (external, cli)", + // Keep these in sync with recent Claude CLI traffic to reduce the chance + // that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage. + "User-Agent": "claude-cli/2.1.22 (external, cli)", "X-Stainless-Lang": "js", - "X-Stainless-Package-Version": "0.52.0", + "X-Stainless-Package-Version": "0.70.0", "X-Stainless-OS": "Linux", - "X-Stainless-Arch": "x64", + "X-Stainless-Arch": "arm64", "X-Stainless-Runtime": "node", - "X-Stainless-Runtime-Version": "v22.14.0", + "X-Stainless-Runtime-Version": "v24.13.0", "X-Stainless-Retry-Count": "0", - "X-Stainless-Timeout": "60", + "X-Stainless-Timeout": "600", "X-App": "cli", "Anthropic-Dangerous-Direct-Browser-Access": "true", } @@ -79,3 +96,39 @@ func DefaultModelIDs() []string { // DefaultTestModel 测试时使用的默认模型 const DefaultTestModel = "claude-sonnet-4-5-20250929" + +// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射 +var ModelIDOverrides = map[string]string{ + "claude-sonnet-4-5": "claude-sonnet-4-5-20250929", + "claude-opus-4-5": "claude-opus-4-5-20251101", + "claude-haiku-4-5": "claude-haiku-4-5-20251001", +} + +// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名 +var ModelIDReverseOverrides = map[string]string{ + "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", + "claude-opus-4-5-20251101": "claude-opus-4-5", + "claude-haiku-4-5-20251001": "claude-haiku-4-5", +} + +// NormalizeModelID 根据 Claude OAuth 规则映射模型 +func NormalizeModelID(id string) string { + if id == "" { + return id + } + if mapped, ok := ModelIDOverrides[id]; ok { + return mapped + } + return id +} + +// DenormalizeModelID 将上游模型 ID 转换为短名 +func DenormalizeModelID(id string) string { + if id == "" { + return id + } + if mapped, ok := ModelIDReverseOverrides[id]; ok { + return mapped + } + return id +} diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 182e0161..7b958838 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string { return "" } +func (a *Account) GetClaudeUserID() string { + if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" { + return v + } + if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" { + return v + } + if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" { + return v + } + if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" { + return v + } + return "" +} + func (a *Account) IsCustomErrorCodesEnabled() bool { if a.Type != AccountTypeAPIKey || a.Credentials == nil { return false diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 46376c69..3290fe52 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) { "system": []map[string]any{ { "type": "text", - "text": "You are Claude Code, Anthropic's official CLI for Claude.", + "text": claudeCodeSystemPrompt, "cache_control": map[string]string{ "type": "ephemeral", }, diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go new file mode 100644 index 00000000..dd58c183 --- /dev/null +++ b/backend/internal/service/gateway_beta_test.go @@ -0,0 +1,23 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMergeAnthropicBeta(t *testing.T) { + got := mergeAnthropicBeta( + []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}, + "foo, oauth-2025-04-20,bar, foo", + ) + require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar", got) +} + +func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) { + got := mergeAnthropicBeta( + []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}, + "", + ) + require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got) +} diff --git a/backend/internal/service/gateway_oauth_metadata_test.go b/backend/internal/service/gateway_oauth_metadata_test.go new file mode 100644 index 00000000..ed6f1887 --- /dev/null +++ b/backend/internal/service/gateway_oauth_metadata_test.go @@ -0,0 +1,62 @@ +package service + +import ( + "regexp" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + Model: "claude-sonnet-4-5", + Stream: true, + MetadataUserID: "", + System: nil, + Messages: nil, + } + + account := &Account{ + ID: 123, + Type: AccountTypeOAuth, + Extra: map[string]any{}, // intentionally missing account_uuid / claude_user_id + } + + fp := &Fingerprint{ClientID: "deadbeef"} // should be used as user id in legacy format + + got := svc.buildOAuthMetadataUserID(parsed, account, fp) + require.NotEmpty(t, got) + + // Legacy format: user_{client}_account__session_{uuid} + re := regexp.MustCompile(`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`) + require.True(t, re.MatchString(got), "unexpected user_id format: %s", got) +} + +func TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + Model: "claude-sonnet-4-5", + Stream: true, + MetadataUserID: "", + } + + account := &Account{ + ID: 123, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "account_uuid": "acc-uuid", + "claude_user_id": "clientid123", + "anthropic_user_id": "", + }, + } + + got := svc.buildOAuthMetadataUserID(parsed, account, nil) + require.NotEmpty(t, got) + + // New format: user_{client}_account_{account_uuid}_session_{uuid} + re := regexp.MustCompile(`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`) + require.True(t, re.MatchString(got), "unexpected user_id format: %s", got) +} diff --git a/backend/internal/service/gateway_prompt_test.go b/backend/internal/service/gateway_prompt_test.go index b056f8fa..52c75d1d 100644 --- a/backend/internal/service/gateway_prompt_test.go +++ b/backend/internal/service/gateway_prompt_test.go @@ -2,6 +2,7 @@ package service import ( "encoding/json" + "strings" "testing" "github.com/stretchr/testify/require" @@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) { } func TestInjectClaudeCodePrompt(t *testing.T) { + claudePrefix := strings.TrimSpace(claudeCodeSystemPrompt) + tests := []struct { name string body string @@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) { system: "Custom prompt", wantSystemLen: 2, wantFirstText: claudeCodeSystemPrompt, - wantSecondText: "Custom prompt", + wantSecondText: claudePrefix + "\n\nCustom prompt", }, { name: "string system equals Claude Code prompt", @@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) { // Claude Code + Custom = 2 wantSystemLen: 2, wantFirstText: claudeCodeSystemPrompt, - wantSecondText: "Custom", + wantSecondText: claudePrefix + "\n\nCustom", }, { name: "array system with existing Claude Code prompt (should dedupe)", @@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) { // Claude Code at start + Other = 2 (deduped) wantSystemLen: 2, wantFirstText: claudeCodeSystemPrompt, - wantSecondText: "Other", + wantSecondText: claudePrefix + "\n\nOther", }, { name: "empty array", diff --git a/backend/internal/service/gateway_sanitize_test.go b/backend/internal/service/gateway_sanitize_test.go new file mode 100644 index 00000000..8fa971ca --- /dev/null +++ b/backend/internal/service/gateway_sanitize_test.go @@ -0,0 +1,21 @@ +package service + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) { + in := "You are OpenCode, the best coding agent on the planet." + got := sanitizeSystemText(in) + require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got) +} + +func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) { + in := "OpenCode and opencode are mentioned." + got := sanitizeToolDescription(in) + // We no longer rewrite tool descriptions; only redact obvious path leaks. + require.Equal(t, in, got) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 9125163a..f52cd2d8 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -20,12 +20,14 @@ import ( "strings" "sync/atomic" "time" + "unicode" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + "github.com/google/uuid" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -37,8 +39,15 @@ const ( claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" stickySessionTTL = time.Hour // 粘性会话TTL defaultMaxLineSize = 40 * 1024 * 1024 - claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." - maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 + // Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines) + // to match real Claude CLI traffic as closely as possible. When we need a visual + // separator between system blocks, we add "\n\n" at concatenation time. + claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." + maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 +) + +const ( + claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) func (s *GatewayService) debugModelRoutingEnabled() bool { @@ -46,6 +55,11 @@ func (s *GatewayService) debugModelRoutingEnabled() bool { return v == "1" || v == "true" || v == "yes" || v == "on" } +func (s *GatewayService) debugClaudeMimicEnabled() bool { + v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) + return v == "1" || v == "true" || v == "yes" || v == "on" +} + func shortSessionHash(sessionHash string) string { if sessionHash == "" { return "" @@ -56,12 +70,178 @@ func shortSessionHash(sessionHash string) string { return sessionHash[:8] } +func redactAuthHeaderValue(v string) string { + v = strings.TrimSpace(v) + if v == "" { + return "" + } + // Keep scheme for debugging, redact secret. + if strings.HasPrefix(strings.ToLower(v), "bearer ") { + return "Bearer [redacted]" + } + return "[redacted]" +} + +func safeHeaderValueForLog(key string, v string) string { + key = strings.ToLower(strings.TrimSpace(key)) + switch key { + case "authorization", "x-api-key": + return redactAuthHeaderValue(v) + default: + return strings.TrimSpace(v) + } +} + +func extractSystemPreviewFromBody(body []byte) string { + if len(body) == 0 { + return "" + } + sys := gjson.GetBytes(body, "system") + if !sys.Exists() { + return "" + } + + switch { + case sys.IsArray(): + for _, item := range sys.Array() { + if !item.IsObject() { + continue + } + if strings.EqualFold(item.Get("type").String(), "text") { + if t := item.Get("text").String(); strings.TrimSpace(t) != "" { + return t + } + } + } + return "" + case sys.Type == gjson.String: + return sys.String() + default: + return "" + } +} + +func buildClaudeMimicDebugLine(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) string { + if req == nil { + return "" + } + + // Only log a minimal fingerprint to avoid leaking user content. + interesting := []string{ + "user-agent", + "x-app", + "anthropic-dangerous-direct-browser-access", + "anthropic-version", + "anthropic-beta", + "x-stainless-lang", + "x-stainless-package-version", + "x-stainless-os", + "x-stainless-arch", + "x-stainless-runtime", + "x-stainless-runtime-version", + "x-stainless-retry-count", + "x-stainless-timeout", + "authorization", + "x-api-key", + "content-type", + "accept", + "x-stainless-helper-method", + } + + h := make([]string, 0, len(interesting)) + for _, k := range interesting { + if v := req.Header.Get(k); v != "" { + h = append(h, fmt.Sprintf("%s=%q", k, safeHeaderValueForLog(k, v))) + } + } + + metaUserID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()) + sysPreview := strings.TrimSpace(extractSystemPreviewFromBody(body)) + + // Truncate preview to keep logs sane. + if len(sysPreview) > 300 { + sysPreview = sysPreview[:300] + "..." + } + sysPreview = strings.ReplaceAll(sysPreview, "\n", "\\n") + sysPreview = strings.ReplaceAll(sysPreview, "\r", "\\r") + + aid := int64(0) + aname := "" + if account != nil { + aid = account.ID + aname = account.Name + } + + return fmt.Sprintf( + "url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}", + req.URL.String(), + aid, + aname, + tokenType, + mimicClaudeCode, + metaUserID, + sysPreview, + strings.Join(h, " "), + ) +} + +func logClaudeMimicDebug(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) { + line := buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode) + if line == "" { + return + } + log.Printf("[ClaudeMimicDebug] %s", line) +} + +func isClaudeCodeCredentialScopeError(msg string) bool { + m := strings.ToLower(strings.TrimSpace(msg)) + if m == "" { + return false + } + return strings.Contains(m, "only authorized for use with claude code") && + strings.Contains(m, "cannot be used for other api requests") +} + // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). var ( sseDataRe = regexp.MustCompile(`^data:\s*`) sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) + toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`) + toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`) + toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`) + toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`) + modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`) + toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`) + toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`) + + claudeToolNameOverrides = map[string]string{ + "bash": "Bash", + "read": "Read", + "edit": "Edit", + "write": "Write", + "task": "Task", + "glob": "Glob", + "grep": "Grep", + "webfetch": "WebFetch", + "websearch": "WebSearch", + "todowrite": "TodoWrite", + "question": "AskUserQuestion", + } + openCodeToolOverrides = map[string]string{ + "Bash": "bash", + "Read": "read", + "Edit": "edit", + "Write": "write", + "Task": "task", + "Glob": "glob", + "Grep": "grep", + "WebFetch": "webfetch", + "WebSearch": "websearch", + "TodoWrite": "todowrite", + "AskUserQuestion": "question", + } // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 @@ -418,6 +598,394 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte return newBody } +type claudeOAuthNormalizeOptions struct { + injectMetadata bool + metadataUserID string + stripSystemCacheControl bool +} + +func stripToolPrefix(value string) string { + if value == "" { + return value + } + return toolPrefixRe.ReplaceAllString(value, "") +} + +func toPascalCase(value string) string { + if value == "" { + return value + } + normalized := toolNameBoundaryRe.ReplaceAllString(value, " ") + tokens := make([]string, 0) + for _, token := range strings.Fields(normalized) { + expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2") + parts := strings.Fields(expanded) + if len(parts) > 0 { + tokens = append(tokens, parts...) + } + } + if len(tokens) == 0 { + return value + } + var builder strings.Builder + for _, token := range tokens { + lower := strings.ToLower(token) + if lower == "" { + continue + } + runes := []rune(lower) + runes[0] = unicode.ToUpper(runes[0]) + _, _ = builder.WriteString(string(runes)) + } + return builder.String() +} + +func toSnakeCase(value string) string { + if value == "" { + return value + } + output := toolNameCamelRe.ReplaceAllString(value, "$1_$2") + output = toolNameBoundaryRe.ReplaceAllString(output, "_") + output = strings.Trim(output, "_") + return strings.ToLower(output) +} + +func normalizeToolNameForClaude(name string, cache map[string]string) string { + if name == "" { + return name + } + stripped := stripToolPrefix(name) + mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)] + if !ok { + mapped = toPascalCase(stripped) + } + if mapped != "" && cache != nil && mapped != stripped { + cache[mapped] = stripped + } + if mapped == "" { + return stripped + } + return mapped +} + +func normalizeToolNameForOpenCode(name string, cache map[string]string) string { + if name == "" { + return name + } + stripped := stripToolPrefix(name) + if cache != nil { + if mapped, ok := cache[stripped]; ok { + return mapped + } + } + if mapped, ok := openCodeToolOverrides[stripped]; ok { + return mapped + } + return toSnakeCase(stripped) +} + +func normalizeParamNameForOpenCode(name string, cache map[string]string) string { + if name == "" { + return name + } + if cache != nil { + if mapped, ok := cache[name]; ok { + return mapped + } + } + return name +} + +// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present). +// We intentionally avoid broad keyword replacement in system prompts to prevent +// accidentally changing user-provided instructions. +func sanitizeSystemText(text string) string { + if text == "" { + return text + } + // Some clients include a fixed OpenCode identity sentence. Anthropic may treat + // this as a non-Claude-Code fingerprint, so rewrite it to the canonical + // Claude Code banner before generic "OpenCode"/"opencode" replacements. + text = strings.ReplaceAll( + text, + "You are OpenCode, the best coding agent on the planet.", + strings.TrimSpace(claudeCodeSystemPrompt), + ) + return text +} + +func sanitizeToolDescription(description string) string { + if description == "" { + return description + } + description = toolDescAbsPathRe.ReplaceAllString(description, "[path]") + description = toolDescWinPathRe.ReplaceAllString(description, "[path]") + // Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings). + // Tool names/skill names may rely on exact wording, and rewriting can be misleading. + return description +} + +func normalizeToolInputSchema(inputSchema any, cache map[string]string) { + schema, ok := inputSchema.(map[string]any) + if !ok { + return + } + properties, ok := schema["properties"].(map[string]any) + if !ok { + return + } + + newProperties := make(map[string]any, len(properties)) + for key, value := range properties { + snakeKey := toSnakeCase(key) + newProperties[snakeKey] = value + if snakeKey != key && cache != nil { + cache[snakeKey] = key + } + } + schema["properties"] = newProperties + + if required, ok := schema["required"].([]any); ok { + newRequired := make([]any, 0, len(required)) + for _, item := range required { + name, ok := item.(string) + if !ok { + newRequired = append(newRequired, item) + continue + } + snakeName := toSnakeCase(name) + newRequired = append(newRequired, snakeName) + if snakeName != name && cache != nil { + cache[snakeName] = name + } + } + schema["required"] = newRequired + } +} + +func stripCacheControlFromSystemBlocks(system any) bool { + blocks, ok := system.([]any) + if !ok { + return false + } + changed := false + for _, item := range blocks { + block, ok := item.(map[string]any) + if !ok { + continue + } + if _, exists := block["cache_control"]; !exists { + continue + } + delete(block, "cache_control") + changed = true + } + return changed +} + +func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) { + if len(body) == 0 { + return body, modelID, nil + } + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + return body, modelID, nil + } + + toolNameMap := make(map[string]string) + + if system, ok := req["system"]; ok { + switch v := system.(type) { + case string: + sanitized := sanitizeSystemText(v) + if sanitized != v { + req["system"] = sanitized + } + case []any: + for _, item := range v { + block, ok := item.(map[string]any) + if !ok { + continue + } + if blockType, _ := block["type"].(string); blockType != "text" { + continue + } + text, ok := block["text"].(string) + if !ok || text == "" { + continue + } + sanitized := sanitizeSystemText(text) + if sanitized != text { + block["text"] = sanitized + } + } + } + } + + if rawModel, ok := req["model"].(string); ok { + normalized := claude.NormalizeModelID(rawModel) + if normalized != rawModel { + req["model"] = normalized + modelID = normalized + } + } + + if rawTools, exists := req["tools"]; exists { + switch tools := rawTools.(type) { + case []any: + for idx, tool := range tools { + toolMap, ok := tool.(map[string]any) + if !ok { + continue + } + if name, ok := toolMap["name"].(string); ok { + normalized := normalizeToolNameForClaude(name, toolNameMap) + if normalized != "" && normalized != name { + toolMap["name"] = normalized + } + } + if desc, ok := toolMap["description"].(string); ok { + sanitized := sanitizeToolDescription(desc) + if sanitized != desc { + toolMap["description"] = sanitized + } + } + if schema, ok := toolMap["input_schema"]; ok { + normalizeToolInputSchema(schema, toolNameMap) + } + tools[idx] = toolMap + } + req["tools"] = tools + case map[string]any: + normalizedTools := make(map[string]any, len(tools)) + for name, value := range tools { + normalized := normalizeToolNameForClaude(name, toolNameMap) + if normalized == "" { + normalized = name + } + if toolMap, ok := value.(map[string]any); ok { + toolMap["name"] = normalized + if desc, ok := toolMap["description"].(string); ok { + sanitized := sanitizeToolDescription(desc) + if sanitized != desc { + toolMap["description"] = sanitized + } + } + if schema, ok := toolMap["input_schema"]; ok { + normalizeToolInputSchema(schema, toolNameMap) + } + normalizedTools[normalized] = toolMap + continue + } + normalizedTools[normalized] = value + } + req["tools"] = normalizedTools + } + } else { + req["tools"] = []any{} + } + + if messages, ok := req["messages"].([]any); ok { + for _, msg := range messages { + msgMap, ok := msg.(map[string]any) + if !ok { + continue + } + content, ok := msgMap["content"].([]any) + if !ok { + continue + } + for _, block := range content { + blockMap, ok := block.(map[string]any) + if !ok { + continue + } + if blockType, _ := blockMap["type"].(string); blockType != "tool_use" { + continue + } + if name, ok := blockMap["name"].(string); ok { + normalized := normalizeToolNameForClaude(name, toolNameMap) + if normalized != "" && normalized != name { + blockMap["name"] = normalized + } + } + } + } + } + + if opts.stripSystemCacheControl { + if system, ok := req["system"]; ok { + _ = stripCacheControlFromSystemBlocks(system) + } + } + + if opts.injectMetadata && opts.metadataUserID != "" { + metadata, ok := req["metadata"].(map[string]any) + if !ok { + metadata = map[string]any{} + req["metadata"] = metadata + } + if existing, ok := metadata["user_id"].(string); !ok || existing == "" { + metadata["user_id"] = opts.metadataUserID + } + } + + delete(req, "temperature") + delete(req, "tool_choice") + + newBody, err := json.Marshal(req) + if err != nil { + return body, modelID, toolNameMap + } + return newBody, modelID, toolNameMap +} + +func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string { + if parsed == nil || account == nil { + return "" + } + if parsed.MetadataUserID != "" { + return "" + } + + userID := strings.TrimSpace(account.GetClaudeUserID()) + if userID == "" && fp != nil { + userID = fp.ClientID + } + if userID == "" { + // Fall back to a random, well-formed client id so we can still satisfy + // Claude Code OAuth requirements when account metadata is incomplete. + userID = generateClientID() + } + + sessionHash := s.GenerateSessionHash(parsed) + sessionID := uuid.NewString() + if sessionHash != "" { + seed := fmt.Sprintf("%d::%s", account.ID, sessionHash) + sessionID = generateSessionUUID(seed) + } + + // Prefer the newer format that includes account_uuid (if present), + // otherwise fall back to the legacy Claude Code format. + accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) + if accountUUID != "" { + return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID) + } + return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID) +} + +func generateSessionUUID(seed string) string { + if seed == "" { + return uuid.NewString() + } + hash := sha256.Sum256([]byte(seed)) + bytes := hash[:16] + bytes[6] = (bytes[6] & 0x0f) | 0x40 + bytes[8] = (bytes[8] & 0x3f) | 0x80 + return fmt.Sprintf("%x-%x-%x-%x-%x", + bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16]) +} + // SelectAccount 选择账号(粘性会话+优先级) func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { return s.SelectAccountForModel(ctx, groupID, sessionHash, "") @@ -2021,6 +2589,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool { return claudeCliUserAgentRe.MatchString(userAgent) } +func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool { + if IsClaudeCodeClient(ctx) { + return true + } + if parsed == nil || c == nil { + return false + } + return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) +} + // systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词 // 使用前缀匹配支持多种变体(标准版、Agent SDK 版等) func systemIncludesClaudeCodePrompt(system any) bool { @@ -2057,6 +2635,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { "text": claudeCodeSystemPrompt, "cache_control": map[string]string{"type": "ephemeral"}, } + // Opencode plugin applies an extra safeguard: it not only prepends the Claude Code + // banner, it also prefixes the next system instruction with the same banner plus + // a blank line. This helps when upstream concatenates system instructions. + claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt) var newSystem []any @@ -2064,19 +2646,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { case nil: newSystem = []any{claudeCodeBlock} case string: - if v == "" || v == claudeCodeSystemPrompt { + // Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines. + if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) { newSystem = []any{claudeCodeBlock} } else { - newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": v}} + // Mirror opencode behavior: keep the banner as a separate system entry, + // but also prefix the next system text with the banner. + merged := v + if !strings.HasPrefix(v, claudeCodePrefix) { + merged = claudeCodePrefix + "\n\n" + v + } + newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": merged}} } case []any: newSystem = make([]any, 0, len(v)+1) newSystem = append(newSystem, claudeCodeBlock) + prefixedNext := false for _, item := range v { if m, ok := item.(map[string]any); ok { - if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt { + if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) { continue } + // Prefix the first subsequent text system block once. + if !prefixedNext { + if blockType, _ := m["type"].(string); blockType == "text" { + if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) { + m["text"] = claudeCodePrefix + "\n\n" + text + prefixedNext = true + } + } + } } newSystem = append(newSystem, item) } @@ -2280,21 +2879,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A body := parsed.Body reqModel := parsed.Model reqStream := parsed.Stream + originalModel := reqModel + var toolNameMap map[string]string - // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) - // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 - if account.IsOAuth() && - !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) && - !strings.Contains(strings.ToLower(reqModel), "haiku") && - !systemIncludesClaudeCodePrompt(parsed.System) { - body = injectClaudeCodePrompt(body, parsed.System) + isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode + + if shouldMimicClaudeCode { + // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) + // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 + if !strings.Contains(strings.ToLower(reqModel), "haiku") && + !systemIncludesClaudeCodePrompt(parsed.System) { + body = injectClaudeCodePrompt(body, parsed.System) + } + + normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} + if s.identityService != nil { + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) + if err == nil && fp != nil { + if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { + normalizeOpts.injectMetadata = true + normalizeOpts.metadataUserID = metadataUserID + } + } + } + + body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } // 强制执行 cache_control 块数量限制(最多 4 个) body = enforceCacheControlLimit(body) // 应用模型映射(仅对apikey类型账号) - originalModel := reqModel if account.Type == AccountTypeAPIKey { mappedModel := account.GetMappedModel(reqModel) if mappedModel != reqModel { @@ -2326,10 +2942,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel) // Capture upstream request body for ops retry of this attempt. c.Set(OpsUpstreamRequestBodyKey, string(body)) - + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if err != nil { return nil, err } @@ -2407,7 +3022,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // also downgrade tool_use/tool_result blocks to text. filteredBody := FilterThinkingBlocksForRetry(body) - retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) + retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if buildErr == nil { retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { @@ -2439,7 +3054,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) - retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel) + retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if buildErr2 == nil { retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr2 == nil { @@ -2664,7 +3279,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A var firstTokenMs *int var clientDisconnect bool if reqStream { - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel) + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode) if err != nil { if err.Error() == "have error in stream" { return nil, &UpstreamFailoverError{ @@ -2677,7 +3292,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A firstTokenMs = streamResult.firstTokenMs clientDisconnect = streamResult.clientDisconnect } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode) if err != nil { return nil, err } @@ -2694,7 +3309,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }, nil } -func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { +func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL if account.Type == AccountTypeAPIKey { @@ -2708,11 +3323,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } + clientHeaders := http.Header{} + if c != nil && c.Request != nil { + clientHeaders = c.Request.Header + } + // OAuth账号:应用统一指纹 var fingerprint *Fingerprint if account.IsOAuth() && s.identityService != nil { // 1. 获取或创建指纹(包含随机生成的ClientID) - fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) if err != nil { log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err) // 失败时降级为透传原始headers @@ -2743,7 +3363,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } // 白名单透传headers - for key, values := range c.Request.Header { + for key, values := range clientHeaders { lowerKey := strings.ToLower(key) if allowedHeaders[lowerKey] { for _, v := range values { @@ -2764,10 +3384,30 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex if req.Header.Get("anthropic-version") == "" { req.Header.Set("anthropic-version", "2023-06-01") } - - // 处理anthropic-beta header(OAuth账号需要特殊处理) if tokenType == "oauth" { - req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) + applyClaudeOAuthHeaderDefaults(req, reqStream) + } + + // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) + if tokenType == "oauth" { + if mimicClaudeCode { + // 非 Claude Code 客户端:按 opencode 的策略处理: + // - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app) + // - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在 + applyClaudeCodeMimicHeaders(req, reqStream) + + incomingBeta := req.Header.Get("anthropic-beta") + // Match real Claude CLI traffic (per mitmproxy reports): + // messages requests typically use only oauth + interleaved-thinking. + // Also drop claude-code beta if a downstream client added it. + requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} + drop := map[string]struct{}{claude.BetaClaudeCode: {}} + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) + } else { + // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta + clientBetaHeader := req.Header.Get("anthropic-beta") + req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader)) + } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) if requestNeedsBetaFeatures(body) { @@ -2777,6 +3417,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } + // Always capture a compact fingerprint line for later error diagnostics. + // We only print it when needed (or when the explicit debug flag is enabled). + if c != nil && tokenType == "oauth" { + c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) + } + if s.debugClaudeMimicEnabled() { + logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode) + } + return req, nil } @@ -2846,6 +3495,93 @@ func defaultAPIKeyBetaHeader(body []byte) string { return claude.APIKeyBetaHeader } +func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) { + if req == nil { + return + } + if req.Header.Get("accept") == "" { + req.Header.Set("accept", "application/json") + } + for key, value := range claude.DefaultHeaders { + if value == "" { + continue + } + if req.Header.Get(key) == "" { + req.Header.Set(key, value) + } + } + if isStream && req.Header.Get("x-stainless-helper-method") == "" { + req.Header.Set("x-stainless-helper-method", "stream") + } +} + +func mergeAnthropicBeta(required []string, incoming string) string { + seen := make(map[string]struct{}, len(required)+8) + out := make([]string, 0, len(required)+8) + + add := func(v string) { + v = strings.TrimSpace(v) + if v == "" { + return + } + if _, ok := seen[v]; ok { + return + } + seen[v] = struct{}{} + out = append(out, v) + } + + for _, r := range required { + add(r) + } + for _, p := range strings.Split(incoming, ",") { + add(p) + } + return strings.Join(out, ",") +} + +func mergeAnthropicBetaDropping(required []string, incoming string, drop map[string]struct{}) string { + merged := mergeAnthropicBeta(required, incoming) + if merged == "" || len(drop) == 0 { + return merged + } + out := make([]string, 0, 8) + for _, p := range strings.Split(merged, ",") { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if _, ok := drop[p]; ok { + continue + } + out = append(out, p) + } + return strings.Join(out, ",") +} + +// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. +// This mirrors opencode-anthropic-auth behavior: do not trust downstream +// headers when using Claude Code-scoped OAuth credentials. +func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) { + if req == nil { + return + } + // Start with the standard defaults (fill missing). + applyClaudeOAuthHeaderDefaults(req, isStream) + // Then force key headers to match Claude Code fingerprint regardless of what the client sent. + for key, value := range claude.DefaultHeaders { + if value == "" { + continue + } + req.Header.Set(key, value) + } + // Real Claude CLI uses Accept: application/json (even for streaming). + req.Header.Set("accept", "application/json") + if isStream { + req.Header.Set("x-stainless-helper-method", "stream") + } +} + func truncateForLog(b []byte, maxBytes int) string { if maxBytes <= 0 { maxBytes = 2048 @@ -2949,6 +3685,20 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + // Print a compact upstream request fingerprint when we hit the Claude Code OAuth + // credential scope error. This avoids requiring env-var tweaks in a fixed deploy. + if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { + if v, ok := c.Get(claudeMimicDebugInfoKey); ok { + if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { + log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + resp.StatusCode, + resp.Header.Get("x-request-id"), + line, + ) + } + } + } + // Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet. upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { @@ -3078,6 +3828,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { + if v, ok := c.Get(claudeMimicDebugInfoKey); ok { + if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { + log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + resp.StatusCode, + resp.Header.Get("x-request-id"), + line, + ) + } + } + } + upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes @@ -3130,7 +3893,7 @@ type streamingResult struct { clientDisconnect bool // 客户端是否在流式传输过程中断开 } -func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) { +func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) @@ -3225,6 +3988,171 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http needModelReplace := originalModel != mappedModel clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage + pendingEventLines := make([]string, 0, 4) + var toolInputBuffers map[int]string + if mimicClaudeCode { + toolInputBuffers = make(map[int]string) + } + + transformToolInputJSON := func(raw string) string { + if !mimicClaudeCode { + return raw + } + raw = strings.TrimSpace(raw) + if raw == "" { + return raw + } + + var parsed any + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + return replaceToolNamesInText(raw, toolNameMap) + } + + rewritten, changed := rewriteParamKeysInValue(parsed, toolNameMap) + if changed { + if bytes, err := json.Marshal(rewritten); err == nil { + return string(bytes) + } + } + return raw + } + + processSSEEvent := func(lines []string) ([]string, string, error) { + if len(lines) == 0 { + return nil, "", nil + } + + eventName := "" + dataLine := "" + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "event:") { + eventName = strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")) + continue + } + if dataLine == "" && sseDataRe.MatchString(trimmed) { + dataLine = sseDataRe.ReplaceAllString(trimmed, "") + } + } + + if eventName == "error" { + return nil, dataLine, errors.New("have error in stream") + } + + if dataLine == "" { + return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil + } + + if dataLine == "[DONE]" { + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, nil + } + + var event map[string]any + if err := json.Unmarshal([]byte(dataLine), &event); err != nil { + replaced := dataLine + if mimicClaudeCode { + replaced = replaceToolNamesInText(dataLine, toolNameMap) + } + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + replaced + "\n\n" + return []string{block}, replaced, nil + } + + eventType, _ := event["type"].(string) + if eventName == "" { + eventName = eventType + } + + if needModelReplace { + if msg, ok := event["message"].(map[string]any); ok { + if model, ok := msg["model"].(string); ok && model == mappedModel { + msg["model"] = originalModel + } + } + } + + if mimicClaudeCode && eventType == "content_block_delta" { + if delta, ok := event["delta"].(map[string]any); ok { + if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" { + if indexVal, ok := event["index"].(float64); ok { + index := int(indexVal) + if partial, ok := delta["partial_json"].(string); ok { + toolInputBuffers[index] += partial + } + } + return nil, dataLine, nil + } + } + } + + if mimicClaudeCode && eventType == "content_block_stop" { + if indexVal, ok := event["index"].(float64); ok { + index := int(indexVal) + if buffered := toolInputBuffers[index]; buffered != "" { + delete(toolInputBuffers, index) + + transformed := transformToolInputJSON(buffered) + synthetic := map[string]any{ + "type": "content_block_delta", + "index": index, + "delta": map[string]any{ + "type": "input_json_delta", + "partial_json": transformed, + }, + } + + synthBytes, synthErr := json.Marshal(synthetic) + if synthErr == nil { + synthBlock := "event: content_block_delta\n" + "data: " + string(synthBytes) + "\n\n" + + rewriteToolNamesInValue(event, toolNameMap) + stopBytes, stopErr := json.Marshal(event) + if stopErr == nil { + stopBlock := "" + if eventName != "" { + stopBlock = "event: " + eventName + "\n" + } + stopBlock += "data: " + string(stopBytes) + "\n\n" + return []string{synthBlock, stopBlock}, string(stopBytes), nil + } + } + } + } + } + + if mimicClaudeCode { + rewriteToolNamesInValue(event, toolNameMap) + } + newData, err := json.Marshal(event) + if err != nil { + replaced := dataLine + if mimicClaudeCode { + replaced = replaceToolNamesInText(dataLine, toolNameMap) + } + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + replaced + "\n\n" + return []string{block}, replaced, nil + } + + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + string(newData) + "\n\n" + return []string{block}, string(newData), nil + } + for { select { case ev, ok := <-events: @@ -3253,42 +4181,43 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) } line := ev.line - if line == "event: error" { - // 上游返回错误事件,如果客户端已断开仍返回已收集的 usage - if clientDisconnected { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + trimmed := strings.TrimSpace(line) + + if trimmed == "" { + if len(pendingEventLines) == 0 { + continue } - return nil, errors.New("have error in stream") + + outputBlocks, data, err := processSSEEvent(pendingEventLines) + pendingEventLines = pendingEventLines[:0] + if err != nil { + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + return nil, err + } + + for _, block := range outputBlocks { + if !clientDisconnected { + if _, werr := fmt.Fprint(w, block); werr != nil { + clientDisconnected = true + log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + break + } + flusher.Flush() + } + if data != "" { + if firstTokenMs == nil && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsage(data, usage) + } + } + continue } - // Extract data from SSE line (supports both "data: " and "data:" formats) - var data string - if sseDataRe.MatchString(line) { - data = sseDataRe.ReplaceAllString(line, "") - // 如果有模型映射,替换响应中的model字段 - if needModelReplace { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) - } - } - - // 写入客户端(统一处理 data 行和非 data 行) - if !clientDisconnected { - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") - } else { - flusher.Flush() - } - } - - // 无论客户端是否断开,都解析 usage(仅对 data 行) - if data != "" { - if firstTokenMs == nil && data != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - s.parseSSEUsage(data, usage) - } + pendingEventLines = append(pendingEventLines, line) case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) @@ -3312,43 +4241,124 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } -// replaceModelInSSELine 替换SSE数据行中的model字段 -func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { - if !sseDataRe.MatchString(line) { - return line +func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) { + switch v := value.(type) { + case map[string]any: + changed := false + rewritten := make(map[string]any, len(v)) + for key, item := range v { + newKey := normalizeParamNameForOpenCode(key, cache) + newItem, childChanged := rewriteParamKeysInValue(item, cache) + if childChanged { + changed = true + } + if newKey != key { + changed = true + } + rewritten[newKey] = newItem + } + if !changed { + return value, false + } + return rewritten, true + case []any: + changed := false + rewritten := make([]any, len(v)) + for idx, item := range v { + newItem, childChanged := rewriteParamKeysInValue(item, cache) + if childChanged { + changed = true + } + rewritten[idx] = newItem + } + if !changed { + return value, false + } + return rewritten, true + default: + return value, false } - data := sseDataRe.ReplaceAllString(line, "") - if data == "" || data == "[DONE]" { - return line +} + +func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool { + switch v := value.(type) { + case map[string]any: + changed := false + if blockType, _ := v["type"].(string); blockType == "tool_use" { + if name, ok := v["name"].(string); ok { + mapped := normalizeToolNameForOpenCode(name, toolNameMap) + if mapped != name { + v["name"] = mapped + changed = true + } + } + if input, ok := v["input"].(map[string]any); ok { + rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap) + if inputChanged { + if m, ok := rewrittenInput.(map[string]any); ok { + v["input"] = m + changed = true + } + } + } + } + for _, item := range v { + if rewriteToolNamesInValue(item, toolNameMap) { + changed = true + } + } + return changed + case []any: + changed := false + for _, item := range v { + if rewriteToolNamesInValue(item, toolNameMap) { + changed = true + } + } + return changed + default: + return false + } +} + +func replaceToolNamesInText(text string, toolNameMap map[string]string) string { + if text == "" { + return text + } + output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string { + submatches := toolNameFieldRe.FindStringSubmatch(match) + if len(submatches) < 2 { + return match + } + name := submatches[1] + mapped := normalizeToolNameForOpenCode(name, toolNameMap) + if mapped == name { + return match + } + return strings.Replace(match, name, mapped, 1) + }) + output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string { + submatches := modelFieldRe.FindStringSubmatch(match) + if len(submatches) < 2 { + return match + } + model := submatches[1] + mapped := claude.DenormalizeModelID(model) + if mapped == model { + return match + } + return strings.Replace(match, model, mapped, 1) + }) + + for mapped, original := range toolNameMap { + if mapped == "" || original == "" || mapped == original { + continue + } + output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":") + output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":") } - var event map[string]any - if err := json.Unmarshal([]byte(data), &event); err != nil { - return line - } - - // 只替换 message_start 事件中的 message.model - if event["type"] != "message_start" { - return line - } - - msg, ok := event["message"].(map[string]any) - if !ok { - return line - } - - model, ok := msg["model"].(string) - if !ok || model != fromModel { - return line - } - - msg["model"] = toModel - newData, err := json.Marshal(event) - if err != nil { - return line - } - - return "data: " + string(newData) + return output } func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { @@ -3394,7 +4404,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { } } -func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { +func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) @@ -3415,6 +4425,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } + if mimicClaudeCode { + body = s.replaceToolNamesInResponseBody(body, toolNameMap) + } responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) @@ -3452,6 +4465,28 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo return newBody } +func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte { + if len(body) == 0 { + return body + } + var resp map[string]any + if err := json.Unmarshal(body, &resp); err != nil { + replaced := replaceToolNamesInText(string(body), toolNameMap) + if replaced == string(body) { + return body + } + return []byte(replaced) + } + if !rewriteToolNamesInValue(resp, toolNameMap) { + return body + } + newBody, err := json.Marshal(resp) + if err != nil { + return body + } + return newBody +} + // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { Result *ForwardResult @@ -3773,6 +4808,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, body := parsed.Body reqModel := parsed.Model + isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode + + if shouldMimicClaudeCode { + normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} + body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + } + // Antigravity 账户不支持 count_tokens 转发,直接返回空值 if account.Platform == PlatformAntigravity { c.JSON(http.StatusOK, gin.H{"input_tokens": 0}) @@ -3799,7 +4842,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 构建上游请求 - upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel) + upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel, shouldMimicClaudeCode) if err != nil { s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") return err @@ -3832,7 +4875,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) filteredBody := FilterThinkingBlocksForRetry(body) - retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) + retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode) if buildErr == nil { retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { @@ -3897,7 +4940,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // buildCountTokensRequest 构建 count_tokens 上游请求 -func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { +func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) { // 确定目标 URL targetURL := claudeAPICountTokensURL if account.Type == AccountTypeAPIKey { @@ -3911,10 +4954,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } + clientHeaders := http.Header{} + if c != nil && c.Request != nil { + clientHeaders = c.Request.Header + } + // OAuth 账号:应用统一指纹和重写 userID // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 if account.IsOAuth() && s.identityService != nil { - fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) if err == nil { accountUUID := account.GetExtraString("account_uuid") if accountUUID != "" && fp.ClientID != "" { @@ -3938,7 +4986,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } // 白名单透传 headers - for key, values := range c.Request.Header { + for key, values := range clientHeaders { lowerKey := strings.ToLower(key) if allowedHeaders[lowerKey] { for _, v := range values { @@ -3949,7 +4997,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:应用指纹到请求头 if account.IsOAuth() && s.identityService != nil { - fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) + fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) if fp != nil { s.identityService.ApplyFingerprint(req, fp) } @@ -3962,10 +5010,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if req.Header.Get("anthropic-version") == "" { req.Header.Set("anthropic-version", "2023-06-01") } + if tokenType == "oauth" { + applyClaudeOAuthHeaderDefaults(req, false) + } // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { - req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) + if mimicClaudeCode { + applyClaudeCodeMimicHeaders(req, false) + + incomingBeta := req.Header.Get("anthropic-beta") + requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} + req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta)) + } else { + clientBetaHeader := req.Header.Get("anthropic-beta") + if clientBetaHeader == "" { + req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader) + } else { + beta := s.getBetaHeader(modelID, clientBetaHeader) + if !strings.Contains(beta, claude.BetaTokenCounting) { + beta = beta + "," + claude.BetaTokenCounting + } + req.Header.Set("anthropic-beta", beta) + } + } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { // API-key:与 messages 同步的按需 beta 注入(默认关闭) if requestNeedsBetaFeatures(body) { @@ -3975,6 +5043,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } + if c != nil && tokenType == "oauth" { + c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) + } + if s.debugClaudeMimicEnabled() { + logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode) + } + return req, nil } diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index e2e723b0..a620ac4d 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -26,13 +26,13 @@ var ( // 默认指纹值(当客户端未提供时使用) var defaultFingerprint = Fingerprint{ - UserAgent: "claude-cli/2.0.62 (external, cli)", + UserAgent: "claude-cli/2.1.22 (external, cli)", StainlessLang: "js", - StainlessPackageVersion: "0.52.0", + StainlessPackageVersion: "0.70.0", StainlessOS: "Linux", - StainlessArch: "x64", + StainlessArch: "arm64", StainlessRuntime: "node", - StainlessRuntimeVersion: "v22.14.0", + StainlessRuntimeVersion: "v24.13.0", } // Fingerprint represents account fingerprint data @@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string { } // parseUserAgentVersion 解析user-agent版本号 -// 例如:claude-cli/2.0.62 -> (2, 0, 62) +// 例如:claude-cli/2.1.2 -> (2, 1, 2) func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) { // 匹配 xxx/x.y.z 格式 matches := userAgentVersionRegex.FindStringSubmatch(ua) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 289a13af..b1866dee 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1260,15 +1260,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率 lastDataAt := time.Now() - // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) + // 仅发送一次错误事件,避免多次写入导致协议混乱。 + // 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema; + // 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。 errorEventSent := false + clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage sendErrorEvent := func(reason string) { - if errorEventSent { + if errorEventSent || clientDisconnected { return } errorEventSent = true - _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) - flusher.Flush() + payload := map[string]any{ + "type": "error", + "sequence_number": 0, + "error": map[string]any{ + "type": "upstream_error", + "message": reason, + "code": reason, + }, + } + if b, err := json.Marshal(payload); err == nil { + _, _ = fmt.Fprintf(w, "data: %s\n\n", b) + flusher.Flush() + } } needModelReplace := originalModel != mappedModel @@ -1280,6 +1294,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } if ev.err != nil { + // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 + // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + log.Printf("Context canceled during streaming, returning collected usage") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } + // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage + if clientDisconnected { + log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err) + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } if errors.Is(ev.err, bufio.ErrTooLong) { log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) sendErrorEvent("response_too_large") @@ -1303,15 +1328,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp // Correct Codex tool calls if needed (apply_patch -> edit, etc.) if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { + data = correctedData line = "data: " + correctedData } - // Forward line - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + // 写入客户端(客户端断开后继续 drain 上游) + if !clientDisconnected { + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + clientDisconnected = true + log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + } else { + flusher.Flush() + } } - flusher.Flush() // Record first token time if firstTokenMs == nil && data != "" && data != "[DONE]" { @@ -1321,11 +1350,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp s.parseSSEUsage(data, usage) } else { // Forward non-data lines as-is - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + if !clientDisconnected { + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + clientDisconnected = true + log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + } else { + flusher.Flush() + } } - flusher.Flush() } case <-intervalCh: @@ -1333,6 +1365,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if time.Since(lastRead) < streamInterval { continue } + if clientDisconnected { + log.Printf("Upstream timeout after client disconnect, returning collected usage") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 if s.rateLimitService != nil { @@ -1342,11 +1378,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") case <-keepaliveCh: + if clientDisconnected { + continue + } if time.Since(lastDataAt) < keepaliveInterval { continue } if _, err := fmt.Fprint(w, ":\n\n"); err != nil { - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + clientDisconnected = true + log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + continue } flusher.Flush() } diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 1912e244..ae69a986 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -59,6 +59,25 @@ type stubConcurrencyCache struct { skipDefaultLoad bool } +type cancelReadCloser struct{} + +func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled } +func (c cancelReadCloser) Close() error { return nil } + +type failingGinWriter struct { + gin.ResponseWriter + failAfter int + writes int +} + +func (w *failingGinWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { if c.acquireResults != nil { if result, ok := c.acquireResults[accountID]; ok { @@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) { if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") { t.Fatalf("expected stream timeout error, got %v", err) } - if !strings.Contains(rec.Body.String(), "stream_timeout") { - t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String()) + if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "stream_timeout") { + t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String()) + } +} + +func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: cancelReadCloser{}, + Header: http.Header{}, + } + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") { + t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String()) + } +} + +func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n")) + }() + + result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") + _ = pr.Close() + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if result == nil || result.usage == nil { + t.Fatalf("expected usage result") + } + if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 { + t.Fatalf("unexpected usage: %+v", *result.usage) + } + if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") { + t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String()) } } @@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) { if !errors.Is(err, bufio.ErrTooLong) { t.Fatalf("expected ErrTooLong, got %v", err) } - if !strings.Contains(rec.Body.String(), "response_too_large") { - t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String()) + if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "response_too_large") { + t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String()) } }