feat(anthropic): 支持 API Key 自动透传并优化透传链路性能
- 新增 Anthropic API Key 自动透传开关与后端透传分支(仅替换认证) - 账号编辑页新增自动透传开关,默认关闭 - 优化透传性能:SSE usage 解析 gjson 快路径、减少请求体重复拷贝、优化流式写回与非流式 usage 解析 - 补充单元测试与 benchmark,确保 Claude OAuth 路径不受影响
This commit is contained in:
@@ -3041,6 +3041,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return nil, fmt.Errorf("parse request: empty request")
|
||||
}
|
||||
|
||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime)
|
||||
}
|
||||
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
reqStream := parsed.Stream
|
||||
@@ -3120,14 +3124,14 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// 调试日志:记录即将转发的账号信息
|
||||
logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
|
||||
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
|
||||
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
|
||||
setOpsUpstreamRequestBody(c, body)
|
||||
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
retryStart := time.Now()
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -3491,6 +3495,538 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
reqModel string,
|
||||
reqStream bool,
|
||||
startTime time.Time,
|
||||
) (*ForwardResult, error) {
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tokenType != "apikey" {
|
||||
return nil, fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType)
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v",
|
||||
account.ID, account.Name, reqModel, reqStream)
|
||||
|
||||
if c != nil {
|
||||
c.Set("anthropic_passthrough", true)
|
||||
}
|
||||
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
|
||||
setOpsUpstreamRequestBody(c, body)
|
||||
|
||||
var resp *http.Response
|
||||
retryStart := time.Now()
|
||||
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
|
||||
upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Passthrough: true,
|
||||
Kind: "request_error",
|
||||
Message: safeErr,
|
||||
})
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream request failed",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
|
||||
}
|
||||
|
||||
// 透传分支禁止 400 请求体降级重试(该重试会改写请求体)
|
||||
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
if attempt < maxRetryAttempts {
|
||||
elapsed := time.Since(retryStart)
|
||||
if elapsed >= maxRetryElapsed {
|
||||
break
|
||||
}
|
||||
|
||||
delay := retryBackoffDelay(attempt)
|
||||
remaining := maxRetryElapsed - elapsed
|
||||
if delay > remaining {
|
||||
delay = remaining
|
||||
}
|
||||
if delay <= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Passthrough: true,
|
||||
Kind: "retry",
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
Detail: func() string {
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
logger.LegacyPrintf("service.gateway", "Anthropic passthrough account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)",
|
||||
account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed)
|
||||
if err := sleepWithContext(ctx, delay); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
if resp == nil || resp.Body == nil {
|
||||
return nil, errors.New("upstream request failed: empty response")
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
|
||||
|
||||
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Passthrough: true,
|
||||
Kind: "retry_exhausted_failover",
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
Detail: func() string {
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Passthrough: true,
|
||||
Kind: "failover",
|
||||
Message: extractUpstreamErrorMessage(respBody),
|
||||
Detail: func() string {
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
var clientDisconnect bool
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
clientDisconnect = streamResult.clientDisconnect
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if usage == nil {
|
||||
usage = &ClaudeUsage{}
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: reqModel,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ClientDisconnect: clientDisconnect,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
token string,
|
||||
) (*http.Request, error) {
|
||||
targetURL := claudeAPIURL
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL != "" {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c != nil && c.Request != nil {
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(strings.TrimSpace(key))
|
||||
if !allowedHeaders[lowerKey] {
|
||||
continue
|
||||
}
|
||||
for _, v := range values {
|
||||
req.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 覆盖入站鉴权残留,并注入上游认证
|
||||
req.Header.Del("authorization")
|
||||
req.Header.Del("x-api-key")
|
||||
req.Header.Del("x-goog-api-key")
|
||||
req.Header.Del("cookie")
|
||||
req.Header.Set("x-api-key", token)
|
||||
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
startTime time.Time,
|
||||
) (*streamingResult, error) {
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
}
|
||||
|
||||
writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
|
||||
|
||||
contentType := strings.TrimSpace(resp.Header.Get("Content-Type"))
|
||||
if contentType == "" {
|
||||
contentType = "text/event-stream"
|
||||
}
|
||||
c.Header("Content-Type", contentType)
|
||||
if c.Writer.Header().Get("Cache-Control") == "" {
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
}
|
||||
if c.Writer.Header().Get("Connection") == "" {
|
||||
c.Header("Connection", "keep-alive")
|
||||
}
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
if v := resp.Header.Get("x-request-id"); v != "" {
|
||||
c.Header("x-request-id", v)
|
||||
}
|
||||
|
||||
w := c.Writer
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
maxLineSize := defaultMaxLineSize
|
||||
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
|
||||
maxLineSize = s.cfg.Gateway.MaxLineSize
|
||||
}
|
||||
scanBuf := getSSEScannerBuf64K()
|
||||
scanner.Buffer(scanBuf[:0], maxLineSize)
|
||||
defer putSSEScannerBuf64K(scanBuf)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if data, ok := extractAnthropicSSEDataLine(line); ok {
|
||||
trimmed := strings.TrimSpace(data)
|
||||
if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsagePassthrough(data, usage)
|
||||
}
|
||||
|
||||
if !clientDisconnected {
|
||||
if _, err := io.WriteString(w, line); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
||||
} else if _, err := io.WriteString(w, "\n"); err != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
|
||||
} else if line == "" {
|
||||
// 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
if !clientDisconnected {
|
||||
// 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if clientDisconnected {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v",
|
||||
account.ID, resp.Header.Get("x-request-id"), err, ctx.Err())
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
if errors.Is(err, bufio.ErrTooLong) {
|
||||
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
|
||||
func extractAnthropicSSEDataLine(line string) (string, bool) {
|
||||
if !strings.HasPrefix(line, "data:") {
|
||||
return "", false
|
||||
}
|
||||
start := len("data:")
|
||||
for start < len(line) {
|
||||
if line[start] != ' ' && line[start] != '\t' {
|
||||
break
|
||||
}
|
||||
start++
|
||||
}
|
||||
return line[start:], true
|
||||
}
|
||||
|
||||
func (s *GatewayService) parseSSEUsagePassthrough(data string, usage *ClaudeUsage) {
|
||||
if usage == nil || data == "" || data == "[DONE]" {
|
||||
return
|
||||
}
|
||||
|
||||
parsed := gjson.Parse(data)
|
||||
switch parsed.Get("type").String() {
|
||||
case "message_start":
|
||||
msgUsage := parsed.Get("message.usage")
|
||||
if msgUsage.Exists() {
|
||||
usage.InputTokens = int(msgUsage.Get("input_tokens").Int())
|
||||
usage.CacheCreationInputTokens = int(msgUsage.Get("cache_creation_input_tokens").Int())
|
||||
usage.CacheReadInputTokens = int(msgUsage.Get("cache_read_input_tokens").Int())
|
||||
|
||||
// 保持与通用解析一致:message_start 允许覆盖 5m/1h 明细(包括 0)。
|
||||
cc5m := msgUsage.Get("cache_creation.ephemeral_5m_input_tokens")
|
||||
cc1h := msgUsage.Get("cache_creation.ephemeral_1h_input_tokens")
|
||||
if cc5m.Exists() || cc1h.Exists() {
|
||||
usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||
usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||
}
|
||||
}
|
||||
case "message_delta":
|
||||
deltaUsage := parsed.Get("usage")
|
||||
if deltaUsage.Exists() {
|
||||
if v := deltaUsage.Get("input_tokens").Int(); v > 0 {
|
||||
usage.InputTokens = int(v)
|
||||
}
|
||||
if v := deltaUsage.Get("output_tokens").Int(); v > 0 {
|
||||
usage.OutputTokens = int(v)
|
||||
}
|
||||
if v := deltaUsage.Get("cache_creation_input_tokens").Int(); v > 0 {
|
||||
usage.CacheCreationInputTokens = int(v)
|
||||
}
|
||||
if v := deltaUsage.Get("cache_read_input_tokens").Int(); v > 0 {
|
||||
usage.CacheReadInputTokens = int(v)
|
||||
}
|
||||
|
||||
cc5m := deltaUsage.Get("cache_creation.ephemeral_5m_input_tokens")
|
||||
cc1h := deltaUsage.Get("cache_creation.ephemeral_1h_input_tokens")
|
||||
if cc5m.Exists() && cc5m.Int() > 0 {
|
||||
usage.CacheCreation5mTokens = int(cc5m.Int())
|
||||
}
|
||||
if cc1h.Exists() && cc1h.Int() > 0 {
|
||||
usage.CacheCreation1hTokens = int(cc1h.Int())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if usage.CacheReadInputTokens == 0 {
|
||||
if cached := parsed.Get("message.usage.cached_tokens").Int(); cached > 0 {
|
||||
usage.CacheReadInputTokens = int(cached)
|
||||
}
|
||||
if cached := parsed.Get("usage.cached_tokens").Int(); usage.CacheReadInputTokens == 0 && cached > 0 {
|
||||
usage.CacheReadInputTokens = int(cached)
|
||||
}
|
||||
}
|
||||
if usage.CacheCreationInputTokens == 0 {
|
||||
cc5m := parsed.Get("message.usage.cache_creation.ephemeral_5m_input_tokens").Int()
|
||||
cc1h := parsed.Get("message.usage.cache_creation.ephemeral_1h_input_tokens").Int()
|
||||
if cc5m == 0 && cc1h == 0 {
|
||||
cc5m = parsed.Get("usage.cache_creation.ephemeral_5m_input_tokens").Int()
|
||||
cc1h = parsed.Get("usage.cache_creation.ephemeral_1h_input_tokens").Int()
|
||||
}
|
||||
total := cc5m + cc1h
|
||||
if total > 0 {
|
||||
usage.CacheCreationInputTokens = int(total)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseClaudeUsageFromResponseBody(body []byte) *ClaudeUsage {
|
||||
usage := &ClaudeUsage{}
|
||||
if len(body) == 0 {
|
||||
return usage
|
||||
}
|
||||
|
||||
parsed := gjson.ParseBytes(body)
|
||||
usageNode := parsed.Get("usage")
|
||||
if !usageNode.Exists() {
|
||||
return usage
|
||||
}
|
||||
|
||||
usage.InputTokens = int(usageNode.Get("input_tokens").Int())
|
||||
usage.OutputTokens = int(usageNode.Get("output_tokens").Int())
|
||||
usage.CacheCreationInputTokens = int(usageNode.Get("cache_creation_input_tokens").Int())
|
||||
usage.CacheReadInputTokens = int(usageNode.Get("cache_read_input_tokens").Int())
|
||||
|
||||
cc5m := usageNode.Get("cache_creation.ephemeral_5m_input_tokens").Int()
|
||||
cc1h := usageNode.Get("cache_creation.ephemeral_1h_input_tokens").Int()
|
||||
if cc5m > 0 || cc1h > 0 {
|
||||
usage.CacheCreation5mTokens = int(cc5m)
|
||||
usage.CacheCreation1hTokens = int(cc1h)
|
||||
}
|
||||
if usage.CacheCreationInputTokens == 0 && (cc5m > 0 || cc1h > 0) {
|
||||
usage.CacheCreationInputTokens = int(cc5m + cc1h)
|
||||
}
|
||||
if usage.CacheReadInputTokens == 0 {
|
||||
if cached := usageNode.Get("cached_tokens").Int(); cached > 0 {
|
||||
usage.CacheReadInputTokens = int(cached)
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
) (*ClaudeUsage, error) {
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
}
|
||||
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
usage := parseClaudeUsageFromResponseBody(body)
|
||||
|
||||
writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
|
||||
contentType := strings.TrimSpace(resp.Header.Get("Content-Type"))
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, body)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, cfg *config.Config) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
if cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(dst, src, cfg.Security.ResponseHeaders)
|
||||
return
|
||||
}
|
||||
if v := strings.TrimSpace(src.Get("Content-Type")); v != "" {
|
||||
dst.Set("Content-Type", v)
|
||||
}
|
||||
if v := strings.TrimSpace(src.Get("x-request-id")); v != "" {
|
||||
dst.Set("x-request-id", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
@@ -5082,6 +5618,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
return fmt.Errorf("parse request: empty request")
|
||||
}
|
||||
|
||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body)
|
||||
}
|
||||
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
|
||||
@@ -5241,6 +5781,158 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token")
|
||||
return err
|
||||
}
|
||||
if tokenType != "apikey" {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Invalid account token type")
|
||||
return fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType)
|
||||
}
|
||||
|
||||
upstreamReq, err := s.buildCountTokensRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
||||
return err
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "")
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: 0,
|
||||
Passthrough: true,
|
||||
Kind: "request_error",
|
||||
Message: sanitizeUpstreamErrorMessage(err.Error()),
|
||||
})
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
return fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
|
||||
maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
return err
|
||||
}
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
upstreamDetail := ""
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Passthrough: true,
|
||||
Kind: "http_error",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
|
||||
errMsg := "Upstream request failed"
|
||||
switch resp.StatusCode {
|
||||
case 429:
|
||||
errMsg = "Rate limit exceeded"
|
||||
case 529:
|
||||
errMsg = "Service overloaded"
|
||||
}
|
||||
s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg)
|
||||
if upstreamMsg == "" {
|
||||
return fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
|
||||
}
|
||||
|
||||
writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg)
|
||||
contentType := strings.TrimSpace(resp.Header.Get("Content-Type"))
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
c.Data(resp.StatusCode, contentType, respBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough(
|
||||
ctx context.Context,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
body []byte,
|
||||
token string,
|
||||
) (*http.Request, error) {
|
||||
targetURL := claudeAPICountTokensURL
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL != "" {
|
||||
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c != nil && c.Request != nil {
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(strings.TrimSpace(key))
|
||||
if !allowedHeaders[lowerKey] {
|
||||
continue
|
||||
}
|
||||
for _, v := range values {
|
||||
req.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Del("authorization")
|
||||
req.Header.Del("x-api-key")
|
||||
req.Header.Del("x-goog-api-key")
|
||||
req.Header.Del("cookie")
|
||||
req.Header.Set("x-api-key", token)
|
||||
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// buildCountTokensRequest 构建 count_tokens 上游请求
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) {
|
||||
// 确定目标 URL
|
||||
|
||||
Reference in New Issue
Block a user