fix(网关): 对齐 Claude OAuth 请求适配
This commit is contained in:
@@ -25,15 +25,15 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
|
|||||||
|
|
||||||
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
// DefaultHeaders 是 Claude Code 客户端默认请求头。
|
||||||
var DefaultHeaders = map[string]string{
|
var DefaultHeaders = map[string]string{
|
||||||
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
"User-Agent": "claude-cli/2.1.2 (external, cli)",
|
||||||
"X-Stainless-Lang": "js",
|
"X-Stainless-Lang": "js",
|
||||||
"X-Stainless-Package-Version": "0.52.0",
|
"X-Stainless-Package-Version": "0.70.0",
|
||||||
"X-Stainless-OS": "Linux",
|
"X-Stainless-OS": "Linux",
|
||||||
"X-Stainless-Arch": "x64",
|
"X-Stainless-Arch": "x64",
|
||||||
"X-Stainless-Runtime": "node",
|
"X-Stainless-Runtime": "node",
|
||||||
"X-Stainless-Runtime-Version": "v22.14.0",
|
"X-Stainless-Runtime-Version": "v24.3.0",
|
||||||
"X-Stainless-Retry-Count": "0",
|
"X-Stainless-Retry-Count": "0",
|
||||||
"X-Stainless-Timeout": "60",
|
"X-Stainless-Timeout": "600",
|
||||||
"X-App": "cli",
|
"X-App": "cli",
|
||||||
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
"Anthropic-Dangerous-Direct-Browser-Access": "true",
|
||||||
}
|
}
|
||||||
@@ -79,3 +79,39 @@ func DefaultModelIDs() []string {
|
|||||||
|
|
||||||
// DefaultTestModel 测试时使用的默认模型
|
// DefaultTestModel 测试时使用的默认模型
|
||||||
const DefaultTestModel = "claude-sonnet-4-5-20250929"
|
const DefaultTestModel = "claude-sonnet-4-5-20250929"
|
||||||
|
|
||||||
|
// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射
|
||||||
|
var ModelIDOverrides = map[string]string{
|
||||||
|
"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
|
||||||
|
"claude-opus-4-5": "claude-opus-4-5-20251101",
|
||||||
|
"claude-haiku-4-5": "claude-haiku-4-5-20251001",
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名
|
||||||
|
var ModelIDReverseOverrides = map[string]string{
|
||||||
|
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
|
||||||
|
"claude-opus-4-5-20251101": "claude-opus-4-5",
|
||||||
|
"claude-haiku-4-5-20251001": "claude-haiku-4-5",
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeModelID 根据 Claude OAuth 规则映射模型
|
||||||
|
func NormalizeModelID(id string) string {
|
||||||
|
if id == "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
if mapped, ok := ModelIDOverrides[id]; ok {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// DenormalizeModelID 将上游模型 ID 转换为短名
|
||||||
|
func DenormalizeModelID(id string) string {
|
||||||
|
if id == "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
if mapped, ok := ModelIDReverseOverrides[id]; ok {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,12 +18,14 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
|
|
||||||
@@ -60,6 +62,36 @@ var (
|
|||||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||||
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||||
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
|
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
|
||||||
|
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
|
||||||
|
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
|
||||||
|
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
||||||
|
|
||||||
|
claudeToolNameOverrides = map[string]string{
|
||||||
|
"bash": "Bash",
|
||||||
|
"read": "Read",
|
||||||
|
"edit": "Edit",
|
||||||
|
"write": "Write",
|
||||||
|
"task": "Task",
|
||||||
|
"glob": "Glob",
|
||||||
|
"grep": "Grep",
|
||||||
|
"webfetch": "WebFetch",
|
||||||
|
"websearch": "WebSearch",
|
||||||
|
"todowrite": "TodoWrite",
|
||||||
|
"question": "AskUserQuestion",
|
||||||
|
}
|
||||||
|
openCodeToolOverrides = map[string]string{
|
||||||
|
"Bash": "bash",
|
||||||
|
"Read": "read",
|
||||||
|
"Edit": "edit",
|
||||||
|
"Write": "write",
|
||||||
|
"Task": "task",
|
||||||
|
"Glob": "glob",
|
||||||
|
"Grep": "grep",
|
||||||
|
"WebFetch": "webfetch",
|
||||||
|
"WebSearch": "websearch",
|
||||||
|
"TodoWrite": "todowrite",
|
||||||
|
"AskUserQuestion": "question",
|
||||||
|
}
|
||||||
|
|
||||||
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
|
||||||
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
|
// 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等
|
||||||
@@ -365,6 +397,268 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
|
|||||||
return newBody
|
return newBody
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type claudeOAuthNormalizeOptions struct {
|
||||||
|
injectMetadata bool
|
||||||
|
metadataUserID string
|
||||||
|
stripSystemCacheControl bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripToolPrefix(value string) string {
|
||||||
|
if value == "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return toolPrefixRe.ReplaceAllString(value, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func toPascalCase(value string) string {
|
||||||
|
if value == "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
normalized := toolNameBoundaryRe.ReplaceAllString(value, " ")
|
||||||
|
tokens := make([]string, 0)
|
||||||
|
for _, token := range strings.Fields(normalized) {
|
||||||
|
expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2")
|
||||||
|
parts := strings.Fields(expanded)
|
||||||
|
if len(parts) > 0 {
|
||||||
|
tokens = append(tokens, parts...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
var builder strings.Builder
|
||||||
|
for _, token := range tokens {
|
||||||
|
lower := strings.ToLower(token)
|
||||||
|
if lower == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
runes := []rune(lower)
|
||||||
|
runes[0] = unicode.ToUpper(runes[0])
|
||||||
|
builder.WriteString(string(runes))
|
||||||
|
}
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func toSnakeCase(value string) string {
|
||||||
|
if value == "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
output := toolNameCamelRe.ReplaceAllString(value, "$1_$2")
|
||||||
|
output = toolNameBoundaryRe.ReplaceAllString(output, "_")
|
||||||
|
output = strings.Trim(output, "_")
|
||||||
|
return strings.ToLower(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeToolNameForClaude(name string, cache map[string]string) string {
|
||||||
|
if name == "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
stripped := stripToolPrefix(name)
|
||||||
|
mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
|
||||||
|
if !ok {
|
||||||
|
mapped = toPascalCase(stripped)
|
||||||
|
}
|
||||||
|
if mapped != "" && cache != nil && mapped != stripped {
|
||||||
|
cache[mapped] = stripped
|
||||||
|
}
|
||||||
|
if mapped == "" {
|
||||||
|
return stripped
|
||||||
|
}
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
|
||||||
|
if name == "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
if cache != nil {
|
||||||
|
if mapped, ok := cache[name]; ok {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mapped, ok := openCodeToolOverrides[name]; ok {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
return toSnakeCase(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripCacheControlFromSystemBlocks(system any) bool {
|
||||||
|
blocks, ok := system.([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
changed := false
|
||||||
|
for _, item := range blocks {
|
||||||
|
block, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := block["cache_control"]; !exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if text, ok := block["text"].(string); ok && text == claudeCodeSystemPrompt {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(block, "cache_control")
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
return changed
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return body, modelID, nil
|
||||||
|
}
|
||||||
|
var req map[string]any
|
||||||
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
return body, modelID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
toolNameMap := make(map[string]string)
|
||||||
|
|
||||||
|
if rawModel, ok := req["model"].(string); ok {
|
||||||
|
normalized := claude.NormalizeModelID(rawModel)
|
||||||
|
if normalized != rawModel {
|
||||||
|
req["model"] = normalized
|
||||||
|
modelID = normalized
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rawTools, exists := req["tools"]; exists {
|
||||||
|
switch tools := rawTools.(type) {
|
||||||
|
case []any:
|
||||||
|
for idx, tool := range tools {
|
||||||
|
toolMap, ok := tool.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if name, ok := toolMap["name"].(string); ok {
|
||||||
|
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||||
|
if normalized != "" && normalized != name {
|
||||||
|
toolMap["name"] = normalized
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tools[idx] = toolMap
|
||||||
|
}
|
||||||
|
req["tools"] = tools
|
||||||
|
case map[string]any:
|
||||||
|
normalizedTools := make(map[string]any, len(tools))
|
||||||
|
for name, value := range tools {
|
||||||
|
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||||
|
if normalized == "" {
|
||||||
|
normalized = name
|
||||||
|
}
|
||||||
|
if toolMap, ok := value.(map[string]any); ok {
|
||||||
|
if toolName, ok := toolMap["name"].(string); ok {
|
||||||
|
mappedName := normalizeToolNameForClaude(toolName, toolNameMap)
|
||||||
|
if mappedName != "" && mappedName != toolName {
|
||||||
|
toolMap["name"] = mappedName
|
||||||
|
}
|
||||||
|
} else if normalized != name {
|
||||||
|
toolMap["name"] = normalized
|
||||||
|
}
|
||||||
|
normalizedTools[normalized] = toolMap
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
normalizedTools[normalized] = value
|
||||||
|
}
|
||||||
|
req["tools"] = normalizedTools
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
req["tools"] = []any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if messages, ok := req["messages"].([]any); ok {
|
||||||
|
for _, msg := range messages {
|
||||||
|
msgMap, ok := msg.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content, ok := msgMap["content"].([]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, block := range content {
|
||||||
|
blockMap, ok := block.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if blockType, _ := blockMap["type"].(string); blockType != "tool_use" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if name, ok := blockMap["name"].(string); ok {
|
||||||
|
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||||
|
if normalized != "" && normalized != name {
|
||||||
|
blockMap["name"] = normalized
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.stripSystemCacheControl {
|
||||||
|
if system, ok := req["system"]; ok {
|
||||||
|
_ = stripCacheControlFromSystemBlocks(system)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.injectMetadata && opts.metadataUserID != "" {
|
||||||
|
metadata, ok := req["metadata"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
metadata = map[string]any{}
|
||||||
|
req["metadata"] = metadata
|
||||||
|
}
|
||||||
|
if existing, ok := metadata["user_id"].(string); !ok || existing == "" {
|
||||||
|
metadata["user_id"] = opts.metadataUserID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := req["temperature"]; ok {
|
||||||
|
delete(req, "temperature")
|
||||||
|
}
|
||||||
|
if _, ok := req["tool_choice"]; ok {
|
||||||
|
delete(req, "tool_choice")
|
||||||
|
}
|
||||||
|
|
||||||
|
newBody, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return body, modelID, toolNameMap
|
||||||
|
}
|
||||||
|
return newBody, modelID, toolNameMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
|
||||||
|
if parsed == nil || fp == nil || fp.ClientID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if parsed.MetadataUserID != "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
accountUUID := account.GetExtraString("account_uuid")
|
||||||
|
if accountUUID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
sessionHash := s.GenerateSessionHash(parsed)
|
||||||
|
sessionID := uuid.NewString()
|
||||||
|
if sessionHash != "" {
|
||||||
|
seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
|
||||||
|
sessionID = generateSessionUUID(seed)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("user_%s_account_%s_session_%s", fp.ClientID, accountUUID, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSessionUUID(seed string) string {
|
||||||
|
if seed == "" {
|
||||||
|
return uuid.NewString()
|
||||||
|
}
|
||||||
|
hash := sha256.Sum256([]byte(seed))
|
||||||
|
bytes := hash[:16]
|
||||||
|
bytes[6] = (bytes[6] & 0x0f) | 0x40
|
||||||
|
bytes[8] = (bytes[8] & 0x3f) | 0x80
|
||||||
|
return fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||||
|
bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
|
||||||
|
}
|
||||||
|
|
||||||
// SelectAccount 选择账号(粘性会话+优先级)
|
// SelectAccount 选择账号(粘性会话+优先级)
|
||||||
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
|
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
|
||||||
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
||||||
@@ -1906,21 +2200,36 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
body := parsed.Body
|
body := parsed.Body
|
||||||
reqModel := parsed.Model
|
reqModel := parsed.Model
|
||||||
reqStream := parsed.Stream
|
reqStream := parsed.Stream
|
||||||
|
originalModel := reqModel
|
||||||
|
var toolNameMap map[string]string
|
||||||
|
|
||||||
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
|
if account.IsOAuth() {
|
||||||
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
|
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
|
||||||
if account.IsOAuth() &&
|
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
|
||||||
!isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
|
if !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
|
||||||
!strings.Contains(strings.ToLower(reqModel), "haiku") &&
|
!strings.Contains(strings.ToLower(reqModel), "haiku") &&
|
||||||
!systemIncludesClaudeCodePrompt(parsed.System) {
|
!systemIncludesClaudeCodePrompt(parsed.System) {
|
||||||
body = injectClaudeCodePrompt(body, parsed.System)
|
body = injectClaudeCodePrompt(body, parsed.System)
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
||||||
|
if s.identityService != nil {
|
||||||
|
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||||
|
if err == nil && fp != nil {
|
||||||
|
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
|
||||||
|
normalizeOpts.injectMetadata = true
|
||||||
|
normalizeOpts.metadataUserID = metadataUserID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||||
body = enforceCacheControlLimit(body)
|
body = enforceCacheControlLimit(body)
|
||||||
|
|
||||||
// 应用模型映射(仅对apikey类型账号)
|
// 应用模型映射(仅对apikey类型账号)
|
||||||
originalModel := reqModel
|
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Type == AccountTypeAPIKey {
|
||||||
mappedModel := account.GetMappedModel(reqModel)
|
mappedModel := account.GetMappedModel(reqModel)
|
||||||
if mappedModel != reqModel {
|
if mappedModel != reqModel {
|
||||||
@@ -1948,10 +2257,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
retryStart := time.Now()
|
retryStart := time.Now()
|
||||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
|
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream)
|
||||||
// Capture upstream request body for ops retry of this attempt.
|
// Capture upstream request body for ops retry of this attempt.
|
||||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -2029,7 +2337,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
// also downgrade tool_use/tool_result blocks to text.
|
// also downgrade tool_use/tool_result blocks to text.
|
||||||
|
|
||||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream)
|
||||||
if buildErr == nil {
|
if buildErr == nil {
|
||||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||||
if retryErr == nil {
|
if retryErr == nil {
|
||||||
@@ -2061,7 +2369,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
|
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
|
||||||
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
||||||
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
||||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
|
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream)
|
||||||
if buildErr2 == nil {
|
if buildErr2 == nil {
|
||||||
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
|
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
|
||||||
if retryErr2 == nil {
|
if retryErr2 == nil {
|
||||||
@@ -2278,7 +2586,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
var clientDisconnect bool
|
var clientDisconnect bool
|
||||||
if reqStream {
|
if reqStream {
|
||||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
|
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() == "have error in stream" {
|
if err.Error() == "have error in stream" {
|
||||||
return nil, &UpstreamFailoverError{
|
return nil, &UpstreamFailoverError{
|
||||||
@@ -2291,7 +2599,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
firstTokenMs = streamResult.firstTokenMs
|
firstTokenMs = streamResult.firstTokenMs
|
||||||
clientDisconnect = streamResult.clientDisconnect
|
clientDisconnect = streamResult.clientDisconnect
|
||||||
} else {
|
} else {
|
||||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
|
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -2308,7 +2616,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool) (*http.Request, error) {
|
||||||
// 确定目标URL
|
// 确定目标URL
|
||||||
targetURL := claudeAPIURL
|
targetURL := claudeAPIURL
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Type == AccountTypeAPIKey {
|
||||||
@@ -2377,6 +2685,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
if req.Header.Get("anthropic-version") == "" {
|
if req.Header.Get("anthropic-version") == "" {
|
||||||
req.Header.Set("anthropic-version", "2023-06-01")
|
req.Header.Set("anthropic-version", "2023-06-01")
|
||||||
}
|
}
|
||||||
|
if tokenType == "oauth" {
|
||||||
|
applyClaudeOAuthHeaderDefaults(req, reqStream)
|
||||||
|
}
|
||||||
|
|
||||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" {
|
||||||
@@ -2459,6 +2770,26 @@ func defaultAPIKeyBetaHeader(body []byte) string {
|
|||||||
return claude.APIKeyBetaHeader
|
return claude.APIKeyBetaHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) {
|
||||||
|
if req == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Header.Get("accept") == "" {
|
||||||
|
req.Header.Set("accept", "application/json")
|
||||||
|
}
|
||||||
|
for key, value := range claude.DefaultHeaders {
|
||||||
|
if value == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if req.Header.Get(key) == "" {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isStream && req.Header.Get("x-stainless-helper-method") == "" {
|
||||||
|
req.Header.Set("x-stainless-helper-method", "stream")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func truncateForLog(b []byte, maxBytes int) string {
|
func truncateForLog(b []byte, maxBytes int) string {
|
||||||
if maxBytes <= 0 {
|
if maxBytes <= 0 {
|
||||||
maxBytes = 2048
|
maxBytes = 2048
|
||||||
@@ -2739,7 +3070,7 @@ type streamingResult struct {
|
|||||||
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
|
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string) (*streamingResult, error) {
|
||||||
// 更新5h窗口状态
|
// 更新5h窗口状态
|
||||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||||
|
|
||||||
@@ -2832,6 +3163,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
}
|
}
|
||||||
|
|
||||||
needModelReplace := originalModel != mappedModel
|
needModelReplace := originalModel != mappedModel
|
||||||
|
rewriteTools := account.IsOAuth()
|
||||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -2873,11 +3205,14 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||||
var data string
|
var data string
|
||||||
if sseDataRe.MatchString(line) {
|
if sseDataRe.MatchString(line) {
|
||||||
data = sseDataRe.ReplaceAllString(line, "")
|
|
||||||
// 如果有模型映射,替换响应中的model字段
|
// 如果有模型映射,替换响应中的model字段
|
||||||
if needModelReplace {
|
if needModelReplace {
|
||||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||||
}
|
}
|
||||||
|
if rewriteTools {
|
||||||
|
line = s.replaceToolNamesInSSELine(line, toolNameMap)
|
||||||
|
}
|
||||||
|
data = sseDataRe.ReplaceAllString(line, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 写入客户端(统一处理 data 行和非 data 行)
|
// 写入客户端(统一处理 data 行和非 data 行)
|
||||||
@@ -2960,6 +3295,61 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
|
|||||||
return "data: " + string(newData)
|
return "data: " + string(newData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
changed := false
|
||||||
|
if blockType, _ := v["type"].(string); blockType == "tool_use" {
|
||||||
|
if name, ok := v["name"].(string); ok {
|
||||||
|
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
|
||||||
|
if mapped != name {
|
||||||
|
v["name"] = mapped
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, item := range v {
|
||||||
|
if rewriteToolNamesInValue(item, toolNameMap) {
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return changed
|
||||||
|
case []any:
|
||||||
|
changed := false
|
||||||
|
for _, item := range v {
|
||||||
|
if rewriteToolNamesInValue(item, toolNameMap) {
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return changed
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[string]string) string {
|
||||||
|
if !sseDataRe.MatchString(line) {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
data := sseDataRe.ReplaceAllString(line, "")
|
||||||
|
if data == "" || data == "[DONE]" {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
|
||||||
|
var event map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
if !rewriteToolNamesInValue(event, toolNameMap) {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
newData, err := json.Marshal(event)
|
||||||
|
if err != nil {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
return "data: " + string(newData)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||||
// 解析message_start获取input tokens(标准Claude API格式)
|
// 解析message_start获取input tokens(标准Claude API格式)
|
||||||
var msgStart struct {
|
var msgStart struct {
|
||||||
@@ -3001,7 +3391,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
|
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string) (*ClaudeUsage, error) {
|
||||||
// 更新5h窗口状态
|
// 更新5h窗口状态
|
||||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||||
|
|
||||||
@@ -3022,6 +3412,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
|||||||
if originalModel != mappedModel {
|
if originalModel != mappedModel {
|
||||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||||
}
|
}
|
||||||
|
if account.IsOAuth() {
|
||||||
|
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
|
||||||
|
}
|
||||||
|
|
||||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||||
|
|
||||||
@@ -3059,6 +3452,24 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
|||||||
return newBody
|
return newBody
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
if !rewriteToolNamesInValue(resp, toolNameMap) {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
newBody, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return newBody
|
||||||
|
}
|
||||||
|
|
||||||
// RecordUsageInput 记录使用量的输入参数
|
// RecordUsageInput 记录使用量的输入参数
|
||||||
type RecordUsageInput struct {
|
type RecordUsageInput struct {
|
||||||
Result *ForwardResult
|
Result *ForwardResult
|
||||||
@@ -3224,6 +3635,11 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
body := parsed.Body
|
body := parsed.Body
|
||||||
reqModel := parsed.Model
|
reqModel := parsed.Model
|
||||||
|
|
||||||
|
if account.IsOAuth() {
|
||||||
|
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
||||||
|
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||||
|
}
|
||||||
|
|
||||||
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
|
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
|
||||||
if account.Platform == PlatformAntigravity {
|
if account.Platform == PlatformAntigravity {
|
||||||
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
|
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
|
||||||
@@ -3412,6 +3828,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
if req.Header.Get("anthropic-version") == "" {
|
if req.Header.Get("anthropic-version") == "" {
|
||||||
req.Header.Set("anthropic-version", "2023-06-01")
|
req.Header.Set("anthropic-version", "2023-06-01")
|
||||||
}
|
}
|
||||||
|
if tokenType == "oauth" {
|
||||||
|
applyClaudeOAuthHeaderDefaults(req, false)
|
||||||
|
}
|
||||||
|
|
||||||
// OAuth 账号:处理 anthropic-beta header
|
// OAuth 账号:处理 anthropic-beta header
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" {
|
||||||
|
|||||||
@@ -24,13 +24,13 @@ var (
|
|||||||
|
|
||||||
// 默认指纹值(当客户端未提供时使用)
|
// 默认指纹值(当客户端未提供时使用)
|
||||||
var defaultFingerprint = Fingerprint{
|
var defaultFingerprint = Fingerprint{
|
||||||
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
UserAgent: "claude-cli/2.1.2 (external, cli)",
|
||||||
StainlessLang: "js",
|
StainlessLang: "js",
|
||||||
StainlessPackageVersion: "0.52.0",
|
StainlessPackageVersion: "0.70.0",
|
||||||
StainlessOS: "Linux",
|
StainlessOS: "Linux",
|
||||||
StainlessArch: "x64",
|
StainlessArch: "x64",
|
||||||
StainlessRuntime: "node",
|
StainlessRuntime: "node",
|
||||||
StainlessRuntimeVersion: "v22.14.0",
|
StainlessRuntimeVersion: "v24.3.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fingerprint represents account fingerprint data
|
// Fingerprint represents account fingerprint data
|
||||||
@@ -230,7 +230,7 @@ func generateUUIDFromSeed(seed string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parseUserAgentVersion 解析user-agent版本号
|
// parseUserAgentVersion 解析user-agent版本号
|
||||||
// 例如:claude-cli/2.0.62 -> (2, 0, 62)
|
// 例如:claude-cli/2.1.2 -> (2, 1, 2)
|
||||||
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
|
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
|
||||||
// 匹配 xxx/x.y.z 格式
|
// 匹配 xxx/x.y.z 格式
|
||||||
matches := userAgentVersionRegex.FindStringSubmatch(ua)
|
matches := userAgentVersionRegex.FindStringSubmatch(ua)
|
||||||
|
|||||||
Reference in New Issue
Block a user