fix(网关): 补齐非 Claude Code OAuth 兼容
This commit is contained in:
@@ -15,6 +15,12 @@ const (
|
||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
|
||||
const MessageBetaHeaderNoTools = BetaOAuth + "," + BetaInterleavedThinking
|
||||
|
||||
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
|
||||
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
|
||||
|
||||
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
|
||||
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
|
||||
|
||||
|
||||
@@ -381,6 +381,22 @@ func (a *Account) GetExtraString(key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *Account) GetClaudeUserID() string {
|
||||
if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
return false
|
||||
|
||||
@@ -67,6 +67,9 @@ var (
|
||||
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
||||
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
|
||||
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
|
||||
toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`)
|
||||
toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`)
|
||||
opencodeTextRe = regexp.MustCompile(`(?i)opencode`)
|
||||
|
||||
claudeToolNameOverrides = map[string]string{
|
||||
"bash": "Bash",
|
||||
@@ -470,6 +473,22 @@ func normalizeToolNameForClaude(name string, cache map[string]string) string {
|
||||
}
|
||||
|
||||
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
|
||||
if name == "" {
|
||||
return name
|
||||
}
|
||||
stripped := stripToolPrefix(name)
|
||||
if cache != nil {
|
||||
if mapped, ok := cache[stripped]; ok {
|
||||
return mapped
|
||||
}
|
||||
}
|
||||
if mapped, ok := openCodeToolOverrides[stripped]; ok {
|
||||
return mapped
|
||||
}
|
||||
return toSnakeCase(stripped)
|
||||
}
|
||||
|
||||
func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
|
||||
if name == "" {
|
||||
return name
|
||||
}
|
||||
@@ -478,10 +497,63 @@ func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
|
||||
return mapped
|
||||
}
|
||||
}
|
||||
if mapped, ok := openCodeToolOverrides[name]; ok {
|
||||
return mapped
|
||||
return name
|
||||
}
|
||||
|
||||
func sanitizeOpenCodeText(text string) string {
|
||||
if text == "" {
|
||||
return text
|
||||
}
|
||||
text = strings.ReplaceAll(text, "OpenCode", "Claude Code")
|
||||
text = opencodeTextRe.ReplaceAllString(text, "Claude")
|
||||
return text
|
||||
}
|
||||
|
||||
func sanitizeToolDescription(description string) string {
|
||||
if description == "" {
|
||||
return description
|
||||
}
|
||||
description = toolDescAbsPathRe.ReplaceAllString(description, "[path]")
|
||||
description = toolDescWinPathRe.ReplaceAllString(description, "[path]")
|
||||
return sanitizeOpenCodeText(description)
|
||||
}
|
||||
|
||||
func normalizeToolInputSchema(inputSchema any, cache map[string]string) {
|
||||
schema, ok := inputSchema.(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
properties, ok := schema["properties"].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
newProperties := make(map[string]any, len(properties))
|
||||
for key, value := range properties {
|
||||
snakeKey := toSnakeCase(key)
|
||||
newProperties[snakeKey] = value
|
||||
if snakeKey != key && cache != nil {
|
||||
cache[snakeKey] = key
|
||||
}
|
||||
}
|
||||
schema["properties"] = newProperties
|
||||
|
||||
if required, ok := schema["required"].([]any); ok {
|
||||
newRequired := make([]any, 0, len(required))
|
||||
for _, item := range required {
|
||||
name, ok := item.(string)
|
||||
if !ok {
|
||||
newRequired = append(newRequired, item)
|
||||
continue
|
||||
}
|
||||
snakeName := toSnakeCase(name)
|
||||
newRequired = append(newRequired, snakeName)
|
||||
if snakeName != name && cache != nil {
|
||||
cache[snakeName] = name
|
||||
}
|
||||
}
|
||||
schema["required"] = newRequired
|
||||
}
|
||||
return toSnakeCase(name)
|
||||
}
|
||||
|
||||
func stripCacheControlFromSystemBlocks(system any) bool {
|
||||
@@ -498,9 +570,6 @@ func stripCacheControlFromSystemBlocks(system any) bool {
|
||||
if _, exists := block["cache_control"]; !exists {
|
||||
continue
|
||||
}
|
||||
if text, ok := block["text"].(string); ok && text == claudeCodeSystemPrompt {
|
||||
continue
|
||||
}
|
||||
delete(block, "cache_control")
|
||||
changed = true
|
||||
}
|
||||
@@ -518,6 +587,34 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
|
||||
toolNameMap := make(map[string]string)
|
||||
|
||||
if system, ok := req["system"]; ok {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
sanitized := sanitizeOpenCodeText(v)
|
||||
if sanitized != v {
|
||||
req["system"] = sanitized
|
||||
}
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
block, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if blockType, _ := block["type"].(string); blockType != "text" {
|
||||
continue
|
||||
}
|
||||
text, ok := block["text"].(string)
|
||||
if !ok || text == "" {
|
||||
continue
|
||||
}
|
||||
sanitized := sanitizeOpenCodeText(text)
|
||||
if sanitized != text {
|
||||
block["text"] = sanitized
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rawModel, ok := req["model"].(string); ok {
|
||||
normalized := claude.NormalizeModelID(rawModel)
|
||||
if normalized != rawModel {
|
||||
@@ -540,6 +637,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
toolMap["name"] = normalized
|
||||
}
|
||||
}
|
||||
if desc, ok := toolMap["description"].(string); ok {
|
||||
sanitized := sanitizeToolDescription(desc)
|
||||
if sanitized != desc {
|
||||
toolMap["description"] = sanitized
|
||||
}
|
||||
}
|
||||
if schema, ok := toolMap["input_schema"]; ok {
|
||||
normalizeToolInputSchema(schema, toolNameMap)
|
||||
}
|
||||
tools[idx] = toolMap
|
||||
}
|
||||
req["tools"] = tools
|
||||
@@ -551,13 +657,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
normalized = name
|
||||
}
|
||||
if toolMap, ok := value.(map[string]any); ok {
|
||||
if toolName, ok := toolMap["name"].(string); ok {
|
||||
mappedName := normalizeToolNameForClaude(toolName, toolNameMap)
|
||||
if mappedName != "" && mappedName != toolName {
|
||||
toolMap["name"] = mappedName
|
||||
toolMap["name"] = normalized
|
||||
if desc, ok := toolMap["description"].(string); ok {
|
||||
sanitized := sanitizeToolDescription(desc)
|
||||
if sanitized != desc {
|
||||
toolMap["description"] = sanitized
|
||||
}
|
||||
} else if normalized != name {
|
||||
toolMap["name"] = normalized
|
||||
}
|
||||
if schema, ok := toolMap["input_schema"]; ok {
|
||||
normalizeToolInputSchema(schema, toolNameMap)
|
||||
}
|
||||
normalizedTools[normalized] = toolMap
|
||||
continue
|
||||
@@ -630,7 +738,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
|
||||
if parsed == nil || fp == nil || fp.ClientID == "" {
|
||||
if parsed == nil || account == nil {
|
||||
return ""
|
||||
}
|
||||
if parsed.MetadataUserID != "" {
|
||||
@@ -640,13 +748,22 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
|
||||
if accountUUID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(account.GetClaudeUserID())
|
||||
if userID == "" && fp != nil {
|
||||
userID = fp.ClientID
|
||||
}
|
||||
if userID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
sessionHash := s.GenerateSessionHash(parsed)
|
||||
sessionID := uuid.NewString()
|
||||
if sessionHash != "" {
|
||||
seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
|
||||
sessionID = generateSessionUUID(seed)
|
||||
}
|
||||
return fmt.Sprintf("user_%s_account_%s_session_%s", fp.ClientID, accountUUID, sessionID)
|
||||
return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
|
||||
}
|
||||
|
||||
func generateSessionUUID(seed string) string {
|
||||
@@ -2705,7 +2822,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
|
||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||
if tokenType == "oauth" && mimicClaudeCode {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
if requestHasTools(body) {
|
||||
req.Header.Set("anthropic-beta", claude.MessageBetaHeaderWithTools)
|
||||
} else {
|
||||
req.Header.Set("anthropic-beta", claude.MessageBetaHeaderNoTools)
|
||||
}
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
@@ -2776,6 +2897,20 @@ func requestNeedsBetaFeatures(body []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func requestHasTools(body []byte) bool {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.Exists() {
|
||||
return false
|
||||
}
|
||||
if tools.IsArray() {
|
||||
return len(tools.Array()) > 0
|
||||
}
|
||||
if tools.IsObject() {
|
||||
return len(tools.Map()) > 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func defaultAPIKeyBetaHeader(body []byte) string {
|
||||
modelID := gjson.GetBytes(body, "model").String()
|
||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
@@ -3309,6 +3444,45 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
|
||||
return "data: " + string(newData)
|
||||
}
|
||||
|
||||
func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
changed := false
|
||||
rewritten := make(map[string]any, len(v))
|
||||
for key, item := range v {
|
||||
newKey := normalizeParamNameForOpenCode(key, cache)
|
||||
newItem, childChanged := rewriteParamKeysInValue(item, cache)
|
||||
if childChanged {
|
||||
changed = true
|
||||
}
|
||||
if newKey != key {
|
||||
changed = true
|
||||
}
|
||||
rewritten[newKey] = newItem
|
||||
}
|
||||
if !changed {
|
||||
return value, false
|
||||
}
|
||||
return rewritten, true
|
||||
case []any:
|
||||
changed := false
|
||||
rewritten := make([]any, len(v))
|
||||
for idx, item := range v {
|
||||
newItem, childChanged := rewriteParamKeysInValue(item, cache)
|
||||
if childChanged {
|
||||
changed = true
|
||||
}
|
||||
rewritten[idx] = newItem
|
||||
}
|
||||
if !changed {
|
||||
return value, false
|
||||
}
|
||||
return rewritten, true
|
||||
default:
|
||||
return value, false
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
|
||||
switch v := value.(type) {
|
||||
case map[string]any:
|
||||
@@ -3321,6 +3495,15 @@ func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if input, ok := v["input"].(map[string]any); ok {
|
||||
rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap)
|
||||
if inputChanged {
|
||||
if m, ok := rewrittenInput.(map[string]any); ok {
|
||||
v["input"] = m
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, item := range v {
|
||||
if rewriteToolNamesInValue(item, toolNameMap) {
|
||||
@@ -3369,6 +3552,15 @@ func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
|
||||
}
|
||||
return strings.Replace(match, model, mapped, 1)
|
||||
})
|
||||
|
||||
for mapped, original := range toolNameMap {
|
||||
if mapped == "" || original == "" || mapped == original {
|
||||
continue
|
||||
}
|
||||
output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":")
|
||||
output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":")
|
||||
}
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
@@ -3381,22 +3573,11 @@ func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[
|
||||
return line
|
||||
}
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
replaced := replaceToolNamesInText(data, toolNameMap)
|
||||
if replaced == data {
|
||||
return line
|
||||
}
|
||||
return "data: " + replaced
|
||||
}
|
||||
if !rewriteToolNamesInValue(event, toolNameMap) {
|
||||
replaced := replaceToolNamesInText(data, toolNameMap)
|
||||
if replaced == data {
|
||||
return line
|
||||
}
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
return "data: " + replaced
|
||||
}
|
||||
|
||||
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
|
||||
Reference in New Issue
Block a user