fix(网关): 区分 Claude Code OAuth 适配
This commit is contained in:
@@ -707,6 +707,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||||
|
SetClaudeCodeClientContext(c, body)
|
||||||
|
|
||||||
setOpsRequestContext(c, "", false, body)
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
parsedReq, err := service.ParseGatewayRequest(body)
|
parsedReq, err := service.ParseGatewayRequest(body)
|
||||||
|
|||||||
@@ -9,11 +9,15 @@ const (
|
|||||||
BetaClaudeCode = "claude-code-20250219"
|
BetaClaudeCode = "claude-code-20250219"
|
||||||
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
|
||||||
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
|
||||||
|
BetaTokenCounting = "token-counting-2024-11-01"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
|
||||||
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||||
|
|
||||||
|
// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header
|
||||||
|
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
|
||||||
|
|
||||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||||
|
|
||||||
|
|||||||
@@ -65,6 +65,8 @@ var (
|
|||||||
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
|
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
|
||||||
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
|
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
|
||||||
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
||||||
|
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
|
||||||
|
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
|
||||||
|
|
||||||
claudeToolNameOverrides = map[string]string{
|
claudeToolNameOverrides = map[string]string{
|
||||||
"bash": "Bash",
|
"bash": "Bash",
|
||||||
@@ -1941,6 +1943,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
|
|||||||
return claudeCliUserAgentRe.MatchString(userAgent)
|
return claudeCliUserAgentRe.MatchString(userAgent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
|
||||||
|
if IsClaudeCodeClient(ctx) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if parsed == nil || c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
|
||||||
|
}
|
||||||
|
|
||||||
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
|
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
|
||||||
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
|
// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等)
|
||||||
func systemIncludesClaudeCodePrompt(system any) bool {
|
func systemIncludesClaudeCodePrompt(system any) bool {
|
||||||
@@ -2203,11 +2215,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
originalModel := reqModel
|
originalModel := reqModel
|
||||||
var toolNameMap map[string]string
|
var toolNameMap map[string]string
|
||||||
|
|
||||||
if account.IsOAuth() {
|
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
|
||||||
|
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||||
|
|
||||||
|
if shouldMimicClaudeCode {
|
||||||
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
|
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
|
||||||
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
|
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
|
||||||
if !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
|
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
|
||||||
!strings.Contains(strings.ToLower(reqModel), "haiku") &&
|
|
||||||
!systemIncludesClaudeCodePrompt(parsed.System) {
|
!systemIncludesClaudeCodePrompt(parsed.System) {
|
||||||
body = injectClaudeCodePrompt(body, parsed.System)
|
body = injectClaudeCodePrompt(body, parsed.System)
|
||||||
}
|
}
|
||||||
@@ -2257,7 +2271,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
retryStart := time.Now()
|
retryStart := time.Now()
|
||||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream)
|
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||||
// Capture upstream request body for ops retry of this attempt.
|
// Capture upstream request body for ops retry of this attempt.
|
||||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -2337,7 +2351,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
// also downgrade tool_use/tool_result blocks to text.
|
// also downgrade tool_use/tool_result blocks to text.
|
||||||
|
|
||||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream)
|
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||||
if buildErr == nil {
|
if buildErr == nil {
|
||||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||||
if retryErr == nil {
|
if retryErr == nil {
|
||||||
@@ -2369,7 +2383,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
|
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
|
||||||
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
|
||||||
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
||||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream)
|
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||||
if buildErr2 == nil {
|
if buildErr2 == nil {
|
||||||
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
|
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
|
||||||
if retryErr2 == nil {
|
if retryErr2 == nil {
|
||||||
@@ -2586,7 +2600,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
var firstTokenMs *int
|
var firstTokenMs *int
|
||||||
var clientDisconnect bool
|
var clientDisconnect bool
|
||||||
if reqStream {
|
if reqStream {
|
||||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap)
|
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() == "have error in stream" {
|
if err.Error() == "have error in stream" {
|
||||||
return nil, &UpstreamFailoverError{
|
return nil, &UpstreamFailoverError{
|
||||||
@@ -2599,7 +2613,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
firstTokenMs = streamResult.firstTokenMs
|
firstTokenMs = streamResult.firstTokenMs
|
||||||
clientDisconnect = streamResult.clientDisconnect
|
clientDisconnect = streamResult.clientDisconnect
|
||||||
} else {
|
} else {
|
||||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap)
|
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -2616,7 +2630,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool) (*http.Request, error) {
|
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
|
||||||
// 确定目标URL
|
// 确定目标URL
|
||||||
targetURL := claudeAPIURL
|
targetURL := claudeAPIURL
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Type == AccountTypeAPIKey {
|
||||||
@@ -2632,7 +2646,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
|
|
||||||
// OAuth账号:应用统一指纹
|
// OAuth账号:应用统一指纹
|
||||||
var fingerprint *Fingerprint
|
var fingerprint *Fingerprint
|
||||||
if account.IsOAuth() && s.identityService != nil {
|
if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
|
||||||
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
||||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -2685,12 +2699,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
if req.Header.Get("anthropic-version") == "" {
|
if req.Header.Get("anthropic-version") == "" {
|
||||||
req.Header.Set("anthropic-version", "2023-06-01")
|
req.Header.Set("anthropic-version", "2023-06-01")
|
||||||
}
|
}
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" && mimicClaudeCode {
|
||||||
applyClaudeOAuthHeaderDefaults(req, reqStream)
|
applyClaudeOAuthHeaderDefaults(req, reqStream)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" && mimicClaudeCode {
|
||||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||||
@@ -3070,7 +3084,7 @@ type streamingResult struct {
|
|||||||
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
clientDisconnect bool // 客户端是否在流式传输过程中断开
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string) (*streamingResult, error) {
|
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) {
|
||||||
// 更新5h窗口状态
|
// 更新5h窗口状态
|
||||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||||
|
|
||||||
@@ -3163,7 +3177,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
}
|
}
|
||||||
|
|
||||||
needModelReplace := originalModel != mappedModel
|
needModelReplace := originalModel != mappedModel
|
||||||
rewriteTools := account.IsOAuth()
|
rewriteTools := mimicClaudeCode
|
||||||
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -3327,6 +3341,37 @@ func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
|
||||||
|
if text == "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string {
|
||||||
|
submatches := toolNameFieldRe.FindStringSubmatch(match)
|
||||||
|
if len(submatches) < 2 {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
name := submatches[1]
|
||||||
|
mapped := normalizeToolNameForOpenCode(name, toolNameMap)
|
||||||
|
if mapped == name {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
return strings.Replace(match, name, mapped, 1)
|
||||||
|
})
|
||||||
|
output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string {
|
||||||
|
submatches := modelFieldRe.FindStringSubmatch(match)
|
||||||
|
if len(submatches) < 2 {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
model := submatches[1]
|
||||||
|
mapped := claude.DenormalizeModelID(model)
|
||||||
|
if mapped == model {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
return strings.Replace(match, model, mapped, 1)
|
||||||
|
})
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[string]string) string {
|
func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[string]string) string {
|
||||||
if !sseDataRe.MatchString(line) {
|
if !sseDataRe.MatchString(line) {
|
||||||
return line
|
return line
|
||||||
@@ -3338,7 +3383,11 @@ func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[
|
|||||||
|
|
||||||
var event map[string]any
|
var event map[string]any
|
||||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||||
return line
|
replaced := replaceToolNamesInText(data, toolNameMap)
|
||||||
|
if replaced == data {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
return "data: " + replaced
|
||||||
}
|
}
|
||||||
if !rewriteToolNamesInValue(event, toolNameMap) {
|
if !rewriteToolNamesInValue(event, toolNameMap) {
|
||||||
return line
|
return line
|
||||||
@@ -3391,7 +3440,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string) (*ClaudeUsage, error) {
|
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) {
|
||||||
// 更新5h窗口状态
|
// 更新5h窗口状态
|
||||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||||
|
|
||||||
@@ -3412,7 +3461,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
|||||||
if originalModel != mappedModel {
|
if originalModel != mappedModel {
|
||||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||||
}
|
}
|
||||||
if account.IsOAuth() {
|
if mimicClaudeCode {
|
||||||
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
|
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3458,7 +3507,11 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap
|
|||||||
}
|
}
|
||||||
var resp map[string]any
|
var resp map[string]any
|
||||||
if err := json.Unmarshal(body, &resp); err != nil {
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
return body
|
replaced := replaceToolNamesInText(string(body), toolNameMap)
|
||||||
|
if replaced == string(body) {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return []byte(replaced)
|
||||||
}
|
}
|
||||||
if !rewriteToolNamesInValue(resp, toolNameMap) {
|
if !rewriteToolNamesInValue(resp, toolNameMap) {
|
||||||
return body
|
return body
|
||||||
@@ -3635,7 +3688,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
body := parsed.Body
|
body := parsed.Body
|
||||||
reqModel := parsed.Model
|
reqModel := parsed.Model
|
||||||
|
|
||||||
if account.IsOAuth() {
|
isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
|
||||||
|
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||||
|
|
||||||
|
if shouldMimicClaudeCode {
|
||||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
||||||
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||||
}
|
}
|
||||||
@@ -3666,7 +3722,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 构建上游请求
|
// 构建上游请求
|
||||||
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
|
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel, shouldMimicClaudeCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
||||||
return err
|
return err
|
||||||
@@ -3699,7 +3755,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
|
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
|
||||||
|
|
||||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||||
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode)
|
||||||
if buildErr == nil {
|
if buildErr == nil {
|
||||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||||
if retryErr == nil {
|
if retryErr == nil {
|
||||||
@@ -3764,7 +3820,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// buildCountTokensRequest 构建 count_tokens 上游请求
|
// buildCountTokensRequest 构建 count_tokens 上游请求
|
||||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) {
|
||||||
// 确定目标 URL
|
// 确定目标 URL
|
||||||
targetURL := claudeAPICountTokensURL
|
targetURL := claudeAPICountTokensURL
|
||||||
if account.Type == AccountTypeAPIKey {
|
if account.Type == AccountTypeAPIKey {
|
||||||
@@ -3779,7 +3835,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OAuth 账号:应用统一指纹和重写 userID
|
// OAuth 账号:应用统一指纹和重写 userID
|
||||||
if account.IsOAuth() && s.identityService != nil {
|
if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
|
||||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
accountUUID := account.GetExtraString("account_uuid")
|
accountUUID := account.GetExtraString("account_uuid")
|
||||||
@@ -3814,7 +3870,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OAuth 账号:应用指纹到请求头
|
// OAuth 账号:应用指纹到请求头
|
||||||
if account.IsOAuth() && s.identityService != nil {
|
if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
|
||||||
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||||
if fp != nil {
|
if fp != nil {
|
||||||
s.identityService.ApplyFingerprint(req, fp)
|
s.identityService.ApplyFingerprint(req, fp)
|
||||||
@@ -3828,13 +3884,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
if req.Header.Get("anthropic-version") == "" {
|
if req.Header.Get("anthropic-version") == "" {
|
||||||
req.Header.Set("anthropic-version", "2023-06-01")
|
req.Header.Set("anthropic-version", "2023-06-01")
|
||||||
}
|
}
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" && mimicClaudeCode {
|
||||||
applyClaudeOAuthHeaderDefaults(req, false)
|
applyClaudeOAuthHeaderDefaults(req, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OAuth 账号:处理 anthropic-beta header
|
// OAuth 账号:处理 anthropic-beta header
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" && mimicClaudeCode {
|
||||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
|
||||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||||
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
||||||
if requestNeedsBetaFeatures(body) {
|
if requestNeedsBetaFeatures(body) {
|
||||||
|
|||||||
Reference in New Issue
Block a user