fix(网关): 补齐非 Claude Code OAuth 兼容

This commit is contained in:
cyhhao
2026-01-16 00:41:29 +08:00
parent b8c48fb477
commit 0962ba43c0
3 changed files with 232 additions and 29 deletions

View File

@@ -15,6 +15,12 @@ const (
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
const MessageBetaHeaderNoTools = BetaOAuth + "," + BetaInterleavedThinking
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting

View File

@@ -381,6 +381,22 @@ func (a *Account) GetExtraString(key string) string {
return ""
}
func (a *Account) GetClaudeUserID() string {
if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
return v
}
if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
return v
}
return ""
}
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false

View File

@@ -67,6 +67,9 @@ var (
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`)
toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`)
opencodeTextRe = regexp.MustCompile(`(?i)opencode`)
claudeToolNameOverrides = map[string]string{
"bash": "Bash",
@@ -470,6 +473,22 @@ func normalizeToolNameForClaude(name string, cache map[string]string) string {
}
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
if name == "" {
return name
}
stripped := stripToolPrefix(name)
if cache != nil {
if mapped, ok := cache[stripped]; ok {
return mapped
}
}
if mapped, ok := openCodeToolOverrides[stripped]; ok {
return mapped
}
return toSnakeCase(stripped)
}
func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
if name == "" {
return name
}
@@ -478,10 +497,63 @@ func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
return mapped
}
}
if mapped, ok := openCodeToolOverrides[name]; ok {
return mapped
return name
}
func sanitizeOpenCodeText(text string) string {
if text == "" {
return text
}
text = strings.ReplaceAll(text, "OpenCode", "Claude Code")
text = opencodeTextRe.ReplaceAllString(text, "Claude")
return text
}
func sanitizeToolDescription(description string) string {
if description == "" {
return description
}
description = toolDescAbsPathRe.ReplaceAllString(description, "[path]")
description = toolDescWinPathRe.ReplaceAllString(description, "[path]")
return sanitizeOpenCodeText(description)
}
func normalizeToolInputSchema(inputSchema any, cache map[string]string) {
schema, ok := inputSchema.(map[string]any)
if !ok {
return
}
properties, ok := schema["properties"].(map[string]any)
if !ok {
return
}
newProperties := make(map[string]any, len(properties))
for key, value := range properties {
snakeKey := toSnakeCase(key)
newProperties[snakeKey] = value
if snakeKey != key && cache != nil {
cache[snakeKey] = key
}
}
schema["properties"] = newProperties
if required, ok := schema["required"].([]any); ok {
newRequired := make([]any, 0, len(required))
for _, item := range required {
name, ok := item.(string)
if !ok {
newRequired = append(newRequired, item)
continue
}
snakeName := toSnakeCase(name)
newRequired = append(newRequired, snakeName)
if snakeName != name && cache != nil {
cache[snakeName] = name
}
}
schema["required"] = newRequired
}
return toSnakeCase(name)
}
func stripCacheControlFromSystemBlocks(system any) bool {
@@ -498,9 +570,6 @@ func stripCacheControlFromSystemBlocks(system any) bool {
if _, exists := block["cache_control"]; !exists {
continue
}
if text, ok := block["text"].(string); ok && text == claudeCodeSystemPrompt {
continue
}
delete(block, "cache_control")
changed = true
}
@@ -518,6 +587,34 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
toolNameMap := make(map[string]string)
if system, ok := req["system"]; ok {
switch v := system.(type) {
case string:
sanitized := sanitizeOpenCodeText(v)
if sanitized != v {
req["system"] = sanitized
}
case []any:
for _, item := range v {
block, ok := item.(map[string]any)
if !ok {
continue
}
if blockType, _ := block["type"].(string); blockType != "text" {
continue
}
text, ok := block["text"].(string)
if !ok || text == "" {
continue
}
sanitized := sanitizeOpenCodeText(text)
if sanitized != text {
block["text"] = sanitized
}
}
}
}
if rawModel, ok := req["model"].(string); ok {
normalized := claude.NormalizeModelID(rawModel)
if normalized != rawModel {
@@ -540,6 +637,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
toolMap["name"] = normalized
}
}
if desc, ok := toolMap["description"].(string); ok {
sanitized := sanitizeToolDescription(desc)
if sanitized != desc {
toolMap["description"] = sanitized
}
}
if schema, ok := toolMap["input_schema"]; ok {
normalizeToolInputSchema(schema, toolNameMap)
}
tools[idx] = toolMap
}
req["tools"] = tools
@@ -551,13 +657,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalized = name
}
if toolMap, ok := value.(map[string]any); ok {
if toolName, ok := toolMap["name"].(string); ok {
mappedName := normalizeToolNameForClaude(toolName, toolNameMap)
if mappedName != "" && mappedName != toolName {
toolMap["name"] = mappedName
toolMap["name"] = normalized
if desc, ok := toolMap["description"].(string); ok {
sanitized := sanitizeToolDescription(desc)
if sanitized != desc {
toolMap["description"] = sanitized
}
} else if normalized != name {
toolMap["name"] = normalized
}
if schema, ok := toolMap["input_schema"]; ok {
normalizeToolInputSchema(schema, toolNameMap)
}
normalizedTools[normalized] = toolMap
continue
@@ -630,7 +738,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
if parsed == nil || fp == nil || fp.ClientID == "" {
if parsed == nil || account == nil {
return ""
}
if parsed.MetadataUserID != "" {
@@ -640,13 +748,22 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
if accountUUID == "" {
return ""
}
userID := strings.TrimSpace(account.GetClaudeUserID())
if userID == "" && fp != nil {
userID = fp.ClientID
}
if userID == "" {
return ""
}
sessionHash := s.GenerateSessionHash(parsed)
sessionID := uuid.NewString()
if sessionHash != "" {
seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
sessionID = generateSessionUUID(seed)
}
return fmt.Sprintf("user_%s_account_%s_session_%s", fp.ClientID, accountUUID, sessionID)
return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
}
func generateSessionUUID(seed string) string {
@@ -2705,7 +2822,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" && mimicClaudeCode {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
if requestHasTools(body) {
req.Header.Set("anthropic-beta", claude.MessageBetaHeaderWithTools)
} else {
req.Header.Set("anthropic-beta", claude.MessageBetaHeaderNoTools)
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
@@ -2776,6 +2897,20 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false
}
func requestHasTools(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if !tools.Exists() {
return false
}
if tools.IsArray() {
return len(tools.Array()) > 0
}
if tools.IsObject() {
return len(tools.Map()) > 0
}
return false
}
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
@@ -3309,6 +3444,45 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
return "data: " + string(newData)
}
func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) {
switch v := value.(type) {
case map[string]any:
changed := false
rewritten := make(map[string]any, len(v))
for key, item := range v {
newKey := normalizeParamNameForOpenCode(key, cache)
newItem, childChanged := rewriteParamKeysInValue(item, cache)
if childChanged {
changed = true
}
if newKey != key {
changed = true
}
rewritten[newKey] = newItem
}
if !changed {
return value, false
}
return rewritten, true
case []any:
changed := false
rewritten := make([]any, len(v))
for idx, item := range v {
newItem, childChanged := rewriteParamKeysInValue(item, cache)
if childChanged {
changed = true
}
rewritten[idx] = newItem
}
if !changed {
return value, false
}
return rewritten, true
default:
return value, false
}
}
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
switch v := value.(type) {
case map[string]any:
@@ -3321,6 +3495,15 @@ func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
changed = true
}
}
if input, ok := v["input"].(map[string]any); ok {
rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap)
if inputChanged {
if m, ok := rewrittenInput.(map[string]any); ok {
v["input"] = m
changed = true
}
}
}
}
for _, item := range v {
if rewriteToolNamesInValue(item, toolNameMap) {
@@ -3369,6 +3552,15 @@ func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
}
return strings.Replace(match, model, mapped, 1)
})
for mapped, original := range toolNameMap {
if mapped == "" || original == "" || mapped == original {
continue
}
output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":")
output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":")
}
return output
}
@@ -3381,22 +3573,11 @@ func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[
return line
}
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
replaced := replaceToolNamesInText(data, toolNameMap)
if replaced == data {
return line
}
return "data: " + replaced
}
if !rewriteToolNamesInValue(event, toolNameMap) {
replaced := replaceToolNamesInText(data, toolNameMap)
if replaced == data {
return line
}
newData, err := json.Marshal(event)
if err != nil {
return line
}
return "data: " + string(newData)
return "data: " + replaced
}
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {