fix(网关): 对齐 Claude OAuth 请求适配
This commit is contained in:
@@ -25,15 +25,15 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||
|
||||
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
||||
var DefaultHeaders = map[string]string{
|
||||
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
||||
"User-Agent": "claude-cli/2.1.2 (external, cli)",
|
||||
"X-Stainless-Lang": "js",
|
||||
"X-Stainless-Package-Version": "0.52.0",
|
||||
"X-Stainless-Package-Version": "0.70.0",
|
||||
"X-Stainless-OS": "Linux",
|
||||
"X-Stainless-Arch": "x64",
|
||||
"X-Stainless-Runtime": "node",
|
||||
"X-Stainless-Runtime-Version": "v22.14.0",
|
||||
"X-Stainless-Runtime-Version": "v24.3.0",
|
||||
"X-Stainless-Retry-Count": "0",
|
||||
"X-Stainless-Timeout": "60",
|
||||
"X-Stainless-Timeout": "600",
|
||||
"X-App": "cli",
|
||||
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
||||
}
|
||||
@@ -79,3 +79,39 @@ func DefaultModelIDs() []string {
|
||||
|
||||
// DefaultTestModel 测试时使用的默认模型
|
||||
const DefaultTestModel = "claude-sonnet-4-5-20250929"
|
||||
|
||||
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
|
||||
var ModelIDOverrides = map[string]string{
|
||||
"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
|
||||
"claude-opus-4-5": "claude-opus-4-5-20251101",
|
||||
"claude-haiku-4-5": "claude-haiku-4-5-20251001",
|
||||
}
|
||||
|
||||
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
|
||||
var ModelIDReverseOverrides = map[string]string{
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-5",
|
||||
"claude-haiku-4-5-20251001": "claude-haiku-4-5",
|
||||
}
|
||||
|
||||
// NormalizeModelID 根据 Claude OAuth 规则映射模型
|
||||
func NormalizeModelID(id string) string {
|
||||
if id == "" {
|
||||
return id
|
||||
}
|
||||
if mapped, ok := ModelIDOverrides[id]; ok {
|
||||
return mapped
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// DenormalizeModelID 将上游模型 ID 转换为短名
|
||||
func DenormalizeModelID(id string) string {
|
||||
if id == "" {
|
||||
return id
|
||||
}
|
||||
if mapped, ok := ModelIDReverseOverrides[id]; ok {
|
||||
return mapped
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
@@ -18,12 +18,14 @@ import (
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
@@ -60,6 +62,36 @@ var (
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
|
||||
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
|
||||
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
|
||||
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
||||
|
||||
claudeToolNameOverrides = map[string]string{
|
||||
"bash": "Bash",
|
||||
"read": "Read",
|
||||
"edit": "Edit",
|
||||
"write": "Write",
|
||||
"task": "Task",
|
||||
"glob": "Glob",
|
||||
"grep": "Grep",
|
||||
"webfetch": "WebFetch",
|
||||
"websearch": "WebSearch",
|
||||
"todowrite": "TodoWrite",
|
||||
"question": "AskUserQuestion",
|
||||
}
|
||||
openCodeToolOverrides = map[string]string{
|
||||
"Bash": "bash",
|
||||
"Read": "read",
|
||||
"Edit": "edit",
|
||||
"Write": "write",
|
||||
"Task": "task",
|
||||
"Glob": "glob",
|
||||
"Grep": "grep",
|
||||
"WebFetch": "webfetch",
|
||||
"WebSearch": "websearch",
|
||||
"TodoWrite": "todowrite",
|
||||
"AskUserQuestion": "question",
|
||||
}
|
||||
|
||||
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
||||
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
|
||||
@@ -365,6 +397,268 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
|
||||
return newBody
|
||||
}
|
||||
|
||||
type claudeOAuthNormalizeOptions struct {
|
||||
injectMetadata bool
|
||||
metadataUserID string
|
||||
stripSystemCacheControl bool
|
||||
}
|
||||
|
||||
func stripToolPrefix(value string) string {
|
||||
if value == "" {
|
||||
return value
|
||||
}
|
||||
return toolPrefixRe.ReplaceAllString(value, "")
|
||||
}
|
||||
|
||||
func toPascalCase(value string) string {
|
||||
if value == "" {
|
||||
return value
|
||||
}
|
||||
normalized := toolNameBoundaryRe.ReplaceAllString(value, " ")
|
||||
tokens := make([]string, 0)
|
||||
for _, token := range strings.Fields(normalized) {
|
||||
expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2")
|
||||
parts := strings.Fields(expanded)
|
||||
if len(parts) > 0 {
|
||||
tokens = append(tokens, parts...)
|
||||
}
|
||||
}
|
||||
if len(tokens) == 0 {
|
||||
return value
|
||||
}
|
||||
var builder strings.Builder
|
||||
for _, token := range tokens {
|
||||
lower := strings.ToLower(token)
|
||||
if lower == "" {
|
||||
continue
|
||||
}
|
||||
runes := []rune(lower)
|
||||
runes[0] = unicode.ToUpper(runes[0])
|
||||
builder.WriteString(string(runes))
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func toSnakeCase(value string) string {
|
||||
if value == "" {
|
||||
return value
|
||||
}
|
||||
output := toolNameCamelRe.ReplaceAllString(value, "$1_$2")
|
||||
output = toolNameBoundaryRe.ReplaceAllString(output, "_")
|
||||
output = strings.Trim(output, "_")
|
||||
return strings.ToLower(output)
|
||||
}
|
||||
|
||||
func normalizeToolNameForClaude(name string, cache map[string]string) string {
|
||||
if name == "" {
|
||||
return name
|
||||
}
|
||||
stripped := stripToolPrefix(name)
|
||||
mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
|
||||
if !ok {
|
||||
mapped = toPascalCase(stripped)
|
||||
}
|
||||
if mapped != "" && cache != nil && mapped != stripped {
|
||||
cache[mapped] = stripped
|
||||
}
|
||||
if mapped == "" {
|
||||
return stripped
|
||||
}
|
||||
return mapped
|
||||
}
|
||||
|
||||
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
|
||||
if name == "" {
|
||||
return name
|
||||
}
|
||||
if cache != nil {
|
||||
if mapped, ok := cache[name]; ok {
|
||||
return mapped
|
||||
}
|
||||
}
|
||||
if mapped, ok := openCodeToolOverrides[name]; ok {
|
||||
return mapped
|
||||
}
|
||||
return toSnakeCase(name)
|
||||
}
|
||||
|
||||
func stripCacheControlFromSystemBlocks(system any) bool {
|
||||
blocks, ok := system.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
changed := false
|
||||
for _, item := range blocks {
|
||||
block, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := block["cache_control"]; !exists {
|
||||
continue
|
||||
}
|
||||
if text, ok := block["text"].(string); ok && text == claudeCodeSystemPrompt {
|
||||
continue
|
||||
}
|
||||
delete(block, "cache_control")
|
||||
changed = true
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) {
|
||||
if len(body) == 0 {
|
||||
return body, modelID, nil
|
||||
}
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body, modelID, nil
|
||||
}
|
||||
|
||||
toolNameMap := make(map[string]string)
|
||||
|
||||
if rawModel, ok := req["model"].(string); ok {
|
||||
normalized := claude.NormalizeModelID(rawModel)
|
||||
if normalized != rawModel {
|
||||
req["model"] = normalized
|
||||
modelID = normalized
|
||||
}
|
||||
}
|
||||
|
||||
if rawTools, exists := req["tools"]; exists {
|
||||
switch tools := rawTools.(type) {
|
||||
case []any:
|
||||
for idx, tool := range tools {
|
||||
toolMap, ok := tool.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if name, ok := toolMap["name"].(string); ok {
|
||||
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||
if normalized != "" && normalized != name {
|
||||
toolMap["name"] = normalized
|
||||
}
|
||||
}
|
||||
tools[idx] = toolMap
|
||||
}
|
||||
req["tools"] = tools
|
||||
case map[string]any:
|
||||
normalizedTools := make(map[string]any, len(tools))
|
||||
for name, value := range tools {
|
||||
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||
if normalized == "" {
|
||||
normalized = name
|
||||
}
|
||||
if toolMap, ok := value.(map[string]any); ok {
|
||||
if toolName, ok := toolMap["name"].(string); ok {
|
||||
mappedName := normalizeToolNameForClaude(toolName, toolNameMap)
|
||||
if mappedName != "" && mappedName != toolName {
|
||||
toolMap["name"] = mappedName
|
||||
}
|
||||
} else if normalized != name {
|
||||
toolMap["name"] = normalized
|
||||
}
|
||||
normalizedTools[normalized] = toolMap
|
||||
continue
|
||||
}
|
||||
normalizedTools[normalized] = value
|
||||
}
|
||||
req["tools"] = normalizedTools
|
||||
}
|
||||
} else {
|
||||
req["tools"] = []any{}
|
||||
}
|
||||
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if blockType, _ := blockMap["type"].(string); blockType != "tool_use" {
|
||||
continue
|
||||
}
|
||||
if name, ok := blockMap["name"].(string); ok {
|
||||
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||
if normalized != "" && normalized != name {
|
||||
blockMap["name"] = normalized
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if opts.stripSystemCacheControl {
|
||||
if system, ok := req["system"]; ok {
|
||||
_ = stripCacheControlFromSystemBlocks(system)
|
||||
}
|
||||
}
|
||||
|
||||
if opts.injectMetadata && opts.metadataUserID != "" {
|
||||
metadata, ok := req["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
metadata = map[string]any{}
|
||||
req["metadata"] = metadata
|
||||
}
|
||||
if existing, ok := metadata["user_id"].(string); !ok || existing == "" {
|
||||
metadata["user_id"] = opts.metadataUserID
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := req["temperature"]; ok {
|
||||
delete(req, "temperature")
|
||||
}
|
||||
if _, ok := req["tool_choice"]; ok {
|
||||
delete(req, "tool_choice")
|
||||
}
|
||||
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body, modelID, toolNameMap
|
||||
}
|
||||
return newBody, modelID, toolNameMap
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
|
||||
if parsed == nil || fp == nil || fp.ClientID == "" {
|
||||
return ""
|
||||
}
|
||||
if parsed.MetadataUserID != "" {
|
||||
return ""
|
||||
}
|
||||
accountUUID := account.GetExtraString("account_uuid")
|
||||
if accountUUID == "" {
|
||||
return ""
|
||||
}
|
||||
sessionHash := s.GenerateSessionHash(parsed)
|
||||
sessionID := uuid.NewString()
|
||||
if sessionHash != "" {
|
||||
seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
|
||||
sessionID = generateSessionUUID(seed)
|
||||
}
|
||||
return fmt.Sprintf("user_%s_account_%s_session_%s", fp.ClientID, accountUUID, sessionID)
|
||||
}
|
||||
|
||||
func generateSessionUUID(seed string) string {
|
||||
if seed == "" {
|
||||
return uuid.NewString()
|
||||
}
|
||||
hash := sha256.Sum256([]byte(seed))
|
||||
bytes := hash[:16]
|
||||
bytes[6] = (bytes[6] & 0x0f) | 0x40
|
||||
bytes[8] = (bytes[8] & 0x3f) | 0x80
|
||||
return fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||
bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
|
||||
}
|
||||
|
||||
// SelectAccount 选择账号(粘性会话+优先级)
|
||||
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
|
||||
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
||||
@@ -1906,21 +2200,36 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
reqStream := parsed.Stream
|
||||
originalModel := reqModel
|
||||
var toolNameMap map[string]string
|
||||
|
||||
if account.IsOAuth() {
|
||||
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
|
||||
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
|
||||
if account.IsOAuth() &&
|
||||
!isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
|
||||
if !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
|
||||
!strings.Contains(strings.ToLower(reqModel), "haiku") &&
|
||||
!systemIncludesClaudeCodePrompt(parsed.System) {
|
||||
body = injectClaudeCodePrompt(body, parsed.System)
|
||||
}
|
||||
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
||||
if s.identityService != nil {
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if err == nil && fp != nil {
|
||||
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
|
||||
normalizeOpts.injectMetadata = true
|
||||
normalizeOpts.metadataUserID = metadataUserID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
}
|
||||
|
||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
body = enforceCacheControlLimit(body)
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := reqModel
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
@@ -1948,10 +2257,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
retryStart := time.Now()
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream)
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2029,7 +2337,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// also downgrade tool_use/tool_result blocks to text.
|
||||
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr == nil {
|
||||
@@ -2061,7 +2369,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
|
||||
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
||||
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream)
|
||||
if buildErr2 == nil {
|
||||
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr2 == nil {
|
||||
@@ -2278,7 +2586,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap)
|
||||
if err != nil {
|
||||
if err.Error() == "have error in stream" {
|
||||
return nil, &UpstreamFailoverError{
|
||||
@@ -2291,7 +2599,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
clientDisconnect = streamResult.clientDisconnect
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2308,7 +2616,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
@@ -2377,6 +2685,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
if tokenType == "oauth" {
|
||||
applyClaudeOAuthHeaderDefaults(req, reqStream)
|
||||
}
|
||||
|
||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||
if tokenType == "oauth" {
|
||||
@@ -2459,6 +2770,26 @@ func defaultAPIKeyBetaHeader(body []byte) string {
|
||||
return claude.APIKeyBetaHeader
|
||||
}
|
||||
|
||||
func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
if req.Header.Get("accept") == "" {
|
||||
req.Header.Set("accept", "application/json")
|
||||
}
|
||||
for key, value := range claude.DefaultHeaders {
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if req.Header.Get(key) == "" {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
if isStream && req.Header.Get("x-stainless-helper-method") == "" {
|
||||
req.Header.Set("x-stainless-helper-method", "stream")
|
||||
}
|
||||
}
|
||||
|
||||
func truncateForLog(b []byte, maxBytes int) string {
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
@@ -2739,7 +3070,7 @@ type streamingResult struct {
|
||||
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string) (*streamingResult, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
@@ -2832,6 +3163,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
rewriteTools := account.IsOAuth()
|
||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||
|
||||
for {
|
||||
@@ -2873,11 +3205,14 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
var data string
|
||||
if sseDataRe.MatchString(line) {
|
||||
data = sseDataRe.ReplaceAllString(line, "")
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
if rewriteTools {
|
||||
line = s.replaceToolNamesInSSELine(line, toolNameMap)
|
||||
}
|
||||
data = sseDataRe.ReplaceAllString(line, "")
|
||||
}
|
||||
|
||||
// 写入客户端(统一处理 data 行和非 data 行)
|
||||
@@ -2960,6 +3295,61 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
|
||||
return "data: " + string(newData)
|
||||
}
|
||||
|
||||
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
changed := false
|
||||
if blockType, _ := v["type"].(string); blockType == "tool_use" {
|
||||
if name, ok := v["name"].(string); ok {
|
||||
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
|
||||
if mapped != name {
|
||||
v["name"] = mapped
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, item := range v {
|
||||
if rewriteToolNamesInValue(item, toolNameMap) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
case []any:
|
||||
changed := false
|
||||
for _, item := range v {
|
||||
if rewriteToolNamesInValue(item, toolNameMap) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[string]string) string {
|
||||
if !sseDataRe.MatchString(line) {
|
||||
return line
|
||||
}
|
||||
data := sseDataRe.ReplaceAllString(line, "")
|
||||
if data == "" || data == "[DONE]" {
|
||||
return line
|
||||
}
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
return line
|
||||
}
|
||||
if !rewriteToolNamesInValue(event, toolNameMap) {
|
||||
return line
|
||||
}
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
}
|
||||
|
||||
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
// 解析message_start获取input tokens(标准Claude API格式)
|
||||
var msgStart struct {
|
||||
@@ -3001,7 +3391,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string) (*ClaudeUsage, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
@@ -3022,6 +3412,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
if account.IsOAuth() {
|
||||
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
|
||||
}
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
@@ -3059,6 +3452,24 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
||||
return newBody
|
||||
}
|
||||
|
||||
func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body
|
||||
}
|
||||
if !rewriteToolNamesInValue(resp, toolNameMap) {
|
||||
return body
|
||||
}
|
||||
newBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
@@ -3224,6 +3635,11 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
|
||||
if account.IsOAuth() {
|
||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
||||
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
}
|
||||
|
||||
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
|
||||
if account.Platform == PlatformAntigravity {
|
||||
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
|
||||
@@ -3412,6 +3828,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
if tokenType == "oauth" {
|
||||
applyClaudeOAuthHeaderDefaults(req, false)
|
||||
}
|
||||
|
||||
// OAuth 账号:处理 anthropic-beta header
|
||||
if tokenType == "oauth" {
|
||||
|
||||
@@ -24,13 +24,13 @@ var (
|
||||
|
||||
// 默认指纹值(当客户端未提供时使用)
|
||||
var defaultFingerprint = Fingerprint{
|
||||
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
||||
UserAgent: "claude-cli/2.1.2 (external, cli)",
|
||||
StainlessLang: "js",
|
||||
StainlessPackageVersion: "0.52.0",
|
||||
StainlessPackageVersion: "0.70.0",
|
||||
StainlessOS: "Linux",
|
||||
StainlessArch: "x64",
|
||||
StainlessRuntime: "node",
|
||||
StainlessRuntimeVersion: "v22.14.0",
|
||||
StainlessRuntimeVersion: "v24.3.0",
|
||||
}
|
||||
|
||||
// Fingerprint represents account fingerprint data
|
||||
@@ -230,7 +230,7 @@ func generateUUIDFromSeed(seed string) string {
|
||||
}
|
||||
|
||||
// parseUserAgentVersion 解析user-agent版本号
|
||||
// 例如:claude-cli/2.0.62 -> (2, 0, 62)
|
||||
// 例如:claude-cli/2.1.2 -> (2, 1, 2)
|
||||
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
|
||||
// 匹配 xxx/x.y.z 格式
|
||||
matches := userAgentVersionRegex.FindStringSubmatch(ua)
|
||||
|
||||
Reference in New Issue
Block a user