Merge pull request #657 from alfadb/fix/count-tokens-404-passthrough
fix(gateway): count_tokens 不支持时返回 404 而非伪造的 200
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -262,44 +263,44 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
||||
require.Empty(t, rec.Header().Get("Set-Cookie"))
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokensFallbackOn404(t *testing.T) {
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
respBody string
|
||||
wantFallback bool
|
||||
name string
|
||||
statusCode int
|
||||
respBody string
|
||||
wantPassthrough bool
|
||||
}{
|
||||
{
|
||||
name: "404 endpoint not found triggers fallback",
|
||||
statusCode: http.StatusNotFound,
|
||||
respBody: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`,
|
||||
wantFallback: true,
|
||||
name: "404 endpoint not found passes through as 404",
|
||||
statusCode: http.StatusNotFound,
|
||||
respBody: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`,
|
||||
wantPassthrough: true,
|
||||
},
|
||||
{
|
||||
name: "404 generic not found triggers fallback",
|
||||
statusCode: http.StatusNotFound,
|
||||
respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
|
||||
wantFallback: true,
|
||||
name: "404 generic not found passes through as 404",
|
||||
statusCode: http.StatusNotFound,
|
||||
respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
|
||||
wantPassthrough: true,
|
||||
},
|
||||
{
|
||||
name: "400 Invalid URL does not fallback",
|
||||
statusCode: http.StatusBadRequest,
|
||||
respBody: `{"error":{"message":"Invalid URL (POST /v1/messages/count_tokens)","type":"invalid_request_error"}}`,
|
||||
wantFallback: false,
|
||||
name: "400 Invalid URL does not passthrough",
|
||||
statusCode: http.StatusBadRequest,
|
||||
respBody: `{"error":{"message":"Invalid URL (POST /v1/messages/count_tokens)","type":"invalid_request_error"}}`,
|
||||
wantPassthrough: false,
|
||||
},
|
||||
{
|
||||
name: "400 model error does not fallback",
|
||||
statusCode: http.StatusBadRequest,
|
||||
respBody: `{"error":{"message":"model not found: claude-unknown","type":"invalid_request_error"}}`,
|
||||
wantFallback: false,
|
||||
name: "400 model error does not passthrough",
|
||||
statusCode: http.StatusBadRequest,
|
||||
respBody: `{"error":{"message":"model not found: claude-unknown","type":"invalid_request_error"}}`,
|
||||
wantPassthrough: false,
|
||||
},
|
||||
{
|
||||
name: "500 internal error does not fallback",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
respBody: `{"error":{"message":"internal error","type":"api_error"}}`,
|
||||
wantFallback: false,
|
||||
name: "500 internal error does not passthrough",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
respBody: `{"error":{"message":"internal error","type":"api_error"}}`,
|
||||
wantPassthrough: false,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -345,10 +346,16 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokensFallbackOn404(t *t
|
||||
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
|
||||
if tt.wantFallback {
|
||||
if tt.wantPassthrough {
|
||||
// 返回 nil(不记录为错误),HTTP 状态码 404 + Anthropic 错误体
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.JSONEq(t, `{"input_tokens":0}`, rec.Body.String())
|
||||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||||
var errResp map[string]any
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &errResp))
|
||||
require.Equal(t, "error", errResp["type"])
|
||||
errObj, ok := errResp["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "not_found_error", errObj["type"])
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
require.Equal(t, tt.statusCode, rec.Code)
|
||||
|
||||
@@ -6015,9 +6015,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
|
||||
}
|
||||
|
||||
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
|
||||
// Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。
|
||||
// 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。
|
||||
if account.Platform == PlatformAntigravity {
|
||||
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
|
||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6222,12 +6223,13 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
|
||||
// 中转站不支持 count_tokens 端点时(404),降级返回空值,客户端会 fallback 到本地估算。
|
||||
// 中转站不支持 count_tokens 端点时(404),返回 404 让客户端 fallback 到本地估算。
|
||||
// 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
logger.LegacyPrintf("service.gateway",
|
||||
"[count_tokens] Upstream does not support count_tokens (404), returning fallback: account=%d name=%s msg=%s",
|
||||
"[count_tokens] Upstream does not support count_tokens (404), returning 404: account=%d name=%s msg=%s",
|
||||
account.ID, account.Name, truncateString(upstreamMsg, 512))
|
||||
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
|
||||
s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported by upstream")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user