fix(网关): 补齐非 Claude Code OAuth 兼容
This commit is contained in:
@@ -15,6 +15,12 @@ const (
|
|||||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||||
|
|
||||||
|
// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header
|
||||||
|
const MessageBetaHeaderNoTools = BetaOAuth + "," + BetaInterleavedThinking
|
||||||
|
|
||||||
|
// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header
|
||||||
|
const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
|
||||||
|
|
||||||
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
|
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
|
||||||
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
|
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
|
||||||
|
|
||||||
|
|||||||
@@ -364,6 +364,22 @@ func (a *Account) GetExtraString(key string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetClaudeUserID() string {
|
||||||
|
if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ var (
|
|||||||
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
||||||
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
|
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
|
||||||
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
|
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
|
||||||
|
toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`)
|
||||||
|
toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`)
|
||||||
|
opencodeTextRe = regexp.MustCompile(`(?i)opencode`)
|
||||||
|
|
||||||
claudeToolNameOverrides = map[string]string{
|
claudeToolNameOverrides = map[string]string{
|
||||||
"bash": "Bash",
|
"bash": "Bash",
|
||||||
@@ -451,6 +454,22 @@ func normalizeToolNameForClaude(name string, cache map[string]string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
|
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
|
||||||
|
if name == "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
stripped := stripToolPrefix(name)
|
||||||
|
if cache != nil {
|
||||||
|
if mapped, ok := cache[stripped]; ok {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mapped, ok := openCodeToolOverrides[stripped]; ok {
|
||||||
|
return mapped
|
||||||
|
}
|
||||||
|
return toSnakeCase(stripped)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
@@ -459,10 +478,63 @@ func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
|
|||||||
return mapped
|
return mapped
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if mapped, ok := openCodeToolOverrides[name]; ok {
|
return name
|
||||||
return mapped
|
}
|
||||||
|
|
||||||
|
func sanitizeOpenCodeText(text string) string {
|
||||||
|
if text == "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
text = strings.ReplaceAll(text, "OpenCode", "Claude Code")
|
||||||
|
text = opencodeTextRe.ReplaceAllString(text, "Claude")
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeToolDescription(description string) string {
|
||||||
|
if description == "" {
|
||||||
|
return description
|
||||||
|
}
|
||||||
|
description = toolDescAbsPathRe.ReplaceAllString(description, "[path]")
|
||||||
|
description = toolDescWinPathRe.ReplaceAllString(description, "[path]")
|
||||||
|
return sanitizeOpenCodeText(description)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeToolInputSchema(inputSchema any, cache map[string]string) {
|
||||||
|
schema, ok := inputSchema.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
properties, ok := schema["properties"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newProperties := make(map[string]any, len(properties))
|
||||||
|
for key, value := range properties {
|
||||||
|
snakeKey := toSnakeCase(key)
|
||||||
|
newProperties[snakeKey] = value
|
||||||
|
if snakeKey != key && cache != nil {
|
||||||
|
cache[snakeKey] = key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
schema["properties"] = newProperties
|
||||||
|
|
||||||
|
if required, ok := schema["required"].([]any); ok {
|
||||||
|
newRequired := make([]any, 0, len(required))
|
||||||
|
for _, item := range required {
|
||||||
|
name, ok := item.(string)
|
||||||
|
if !ok {
|
||||||
|
newRequired = append(newRequired, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
snakeName := toSnakeCase(name)
|
||||||
|
newRequired = append(newRequired, snakeName)
|
||||||
|
if snakeName != name && cache != nil {
|
||||||
|
cache[snakeName] = name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
schema["required"] = newRequired
|
||||||
}
|
}
|
||||||
return toSnakeCase(name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func stripCacheControlFromSystemBlocks(system any) bool {
|
func stripCacheControlFromSystemBlocks(system any) bool {
|
||||||
@@ -479,9 +551,6 @@ func stripCacheControlFromSystemBlocks(system any) bool {
|
|||||||
if _, exists := block["cache_control"]; !exists {
|
if _, exists := block["cache_control"]; !exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if text, ok := block["text"].(string); ok && text == claudeCodeSystemPrompt {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
delete(block, "cache_control")
|
delete(block, "cache_control")
|
||||||
changed = true
|
changed = true
|
||||||
}
|
}
|
||||||
@@ -499,6 +568,34 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
|||||||
|
|
||||||
toolNameMap := make(map[string]string)
|
toolNameMap := make(map[string]string)
|
||||||
|
|
||||||
|
if system, ok := req["system"]; ok {
|
||||||
|
switch v := system.(type) {
|
||||||
|
case string:
|
||||||
|
sanitized := sanitizeOpenCodeText(v)
|
||||||
|
if sanitized != v {
|
||||||
|
req["system"] = sanitized
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
for _, item := range v {
|
||||||
|
block, ok := item.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if blockType, _ := block["type"].(string); blockType != "text" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
text, ok := block["text"].(string)
|
||||||
|
if !ok || text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sanitized := sanitizeOpenCodeText(text)
|
||||||
|
if sanitized != text {
|
||||||
|
block["text"] = sanitized
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if rawModel, ok := req["model"].(string); ok {
|
if rawModel, ok := req["model"].(string); ok {
|
||||||
normalized := claude.NormalizeModelID(rawModel)
|
normalized := claude.NormalizeModelID(rawModel)
|
||||||
if normalized != rawModel {
|
if normalized != rawModel {
|
||||||
@@ -521,6 +618,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
|||||||
toolMap["name"] = normalized
|
toolMap["name"] = normalized
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if desc, ok := toolMap["description"].(string); ok {
|
||||||
|
sanitized := sanitizeToolDescription(desc)
|
||||||
|
if sanitized != desc {
|
||||||
|
toolMap["description"] = sanitized
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if schema, ok := toolMap["input_schema"]; ok {
|
||||||
|
normalizeToolInputSchema(schema, toolNameMap)
|
||||||
|
}
|
||||||
tools[idx] = toolMap
|
tools[idx] = toolMap
|
||||||
}
|
}
|
||||||
req["tools"] = tools
|
req["tools"] = tools
|
||||||
@@ -532,13 +638,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
|||||||
normalized = name
|
normalized = name
|
||||||
}
|
}
|
||||||
if toolMap, ok := value.(map[string]any); ok {
|
if toolMap, ok := value.(map[string]any); ok {
|
||||||
if toolName, ok := toolMap["name"].(string); ok {
|
toolMap["name"] = normalized
|
||||||
mappedName := normalizeToolNameForClaude(toolName, toolNameMap)
|
if desc, ok := toolMap["description"].(string); ok {
|
||||||
if mappedName != "" && mappedName != toolName {
|
sanitized := sanitizeToolDescription(desc)
|
||||||
toolMap["name"] = mappedName
|
if sanitized != desc {
|
||||||
|
toolMap["description"] = sanitized
|
||||||
}
|
}
|
||||||
} else if normalized != name {
|
}
|
||||||
toolMap["name"] = normalized
|
if schema, ok := toolMap["input_schema"]; ok {
|
||||||
|
normalizeToolInputSchema(schema, toolNameMap)
|
||||||
}
|
}
|
||||||
normalizedTools[normalized] = toolMap
|
normalizedTools[normalized] = toolMap
|
||||||
continue
|
continue
|
||||||
@@ -611,7 +719,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
|
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
|
||||||
if parsed == nil || fp == nil || fp.ClientID == "" {
|
if parsed == nil || account == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
if parsed.MetadataUserID != "" {
|
if parsed.MetadataUserID != "" {
|
||||||
@@ -621,13 +729,22 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
|
|||||||
if accountUUID == "" {
|
if accountUUID == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userID := strings.TrimSpace(account.GetClaudeUserID())
|
||||||
|
if userID == "" && fp != nil {
|
||||||
|
userID = fp.ClientID
|
||||||
|
}
|
||||||
|
if userID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
sessionHash := s.GenerateSessionHash(parsed)
|
sessionHash := s.GenerateSessionHash(parsed)
|
||||||
sessionID := uuid.NewString()
|
sessionID := uuid.NewString()
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
|
seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
|
||||||
sessionID = generateSessionUUID(seed)
|
sessionID = generateSessionUUID(seed)
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("user_%s_account_%s_session_%s", fp.ClientID, accountUUID, sessionID)
|
return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateSessionUUID(seed string) string {
|
func generateSessionUUID(seed string) string {
|
||||||
@@ -2213,7 +2330,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
|
|
||||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||||
if tokenType == "oauth" && mimicClaudeCode {
|
if tokenType == "oauth" && mimicClaudeCode {
|
||||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
if requestHasTools(body) {
|
||||||
|
req.Header.Set("anthropic-beta", claude.MessageBetaHeaderWithTools)
|
||||||
|
} else {
|
||||||
|
req.Header.Set("anthropic-beta", claude.MessageBetaHeaderNoTools)
|
||||||
|
}
|
||||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||||
if requestNeedsBetaFeatures(body) {
|
if requestNeedsBetaFeatures(body) {
|
||||||
@@ -2284,6 +2405,20 @@ func requestNeedsBetaFeatures(body []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func requestHasTools(body []byte) bool {
|
||||||
|
tools := gjson.GetBytes(body, "tools")
|
||||||
|
if !tools.Exists() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if tools.IsArray() {
|
||||||
|
return len(tools.Array()) > 0
|
||||||
|
}
|
||||||
|
if tools.IsObject() {
|
||||||
|
return len(tools.Map()) > 0
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func defaultAPIKeyBetaHeader(body []byte) string {
|
func defaultAPIKeyBetaHeader(body []byte) string {
|
||||||
modelID := gjson.GetBytes(body, "model").String()
|
modelID := gjson.GetBytes(body, "model").String()
|
||||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||||
@@ -2817,6 +2952,45 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
|
|||||||
return "data: " + string(newData)
|
return "data: " + string(newData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
changed := false
|
||||||
|
rewritten := make(map[string]any, len(v))
|
||||||
|
for key, item := range v {
|
||||||
|
newKey := normalizeParamNameForOpenCode(key, cache)
|
||||||
|
newItem, childChanged := rewriteParamKeysInValue(item, cache)
|
||||||
|
if childChanged {
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
if newKey != key {
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
rewritten[newKey] = newItem
|
||||||
|
}
|
||||||
|
if !changed {
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
return rewritten, true
|
||||||
|
case []any:
|
||||||
|
changed := false
|
||||||
|
rewritten := make([]any, len(v))
|
||||||
|
for idx, item := range v {
|
||||||
|
newItem, childChanged := rewriteParamKeysInValue(item, cache)
|
||||||
|
if childChanged {
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
rewritten[idx] = newItem
|
||||||
|
}
|
||||||
|
if !changed {
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
return rewritten, true
|
||||||
|
default:
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
|
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
@@ -2829,6 +3003,15 @@ func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
|
|||||||
changed = true
|
changed = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if input, ok := v["input"].(map[string]any); ok {
|
||||||
|
rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap)
|
||||||
|
if inputChanged {
|
||||||
|
if m, ok := rewrittenInput.(map[string]any); ok {
|
||||||
|
v["input"] = m
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
if rewriteToolNamesInValue(item, toolNameMap) {
|
if rewriteToolNamesInValue(item, toolNameMap) {
|
||||||
@@ -2877,6 +3060,15 @@ func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
|
|||||||
}
|
}
|
||||||
return strings.Replace(match, model, mapped, 1)
|
return strings.Replace(match, model, mapped, 1)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
for mapped, original := range toolNameMap {
|
||||||
|
if mapped == "" || original == "" || mapped == original {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":")
|
||||||
|
output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":")
|
||||||
|
}
|
||||||
|
|
||||||
return output
|
return output
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2889,22 +3081,11 @@ func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[
|
|||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
|
|
||||||
var event map[string]any
|
replaced := replaceToolNamesInText(data, toolNameMap)
|
||||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
if replaced == data {
|
||||||
replaced := replaceToolNamesInText(data, toolNameMap)
|
|
||||||
if replaced == data {
|
|
||||||
return line
|
|
||||||
}
|
|
||||||
return "data: " + replaced
|
|
||||||
}
|
|
||||||
if !rewriteToolNamesInValue(event, toolNameMap) {
|
|
||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
newData, err := json.Marshal(event)
|
return "data: " + replaced
|
||||||
if err != nil {
|
|
||||||
return line
|
|
||||||
}
|
|
||||||
return "data: " + string(newData)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||||
|
|||||||
Reference in New Issue
Block a user