Merge pull request #635 from alfadb/fix/count-tokens-fallback-for-proxy
fix: count_tokens 端点不支持时降级返回空值
This commit is contained in:
@@ -262,6 +262,101 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
|||||||
require.Empty(t, rec.Header().Get("Set-Cookie"))
|
require.Empty(t, rec.Header().Get("Set-Cookie"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokensFallbackOn404(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
statusCode int
|
||||||
|
respBody string
|
||||||
|
wantFallback bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "404 endpoint not found triggers fallback",
|
||||||
|
statusCode: http.StatusNotFound,
|
||||||
|
respBody: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`,
|
||||||
|
wantFallback: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "404 generic not found triggers fallback",
|
||||||
|
statusCode: http.StatusNotFound,
|
||||||
|
respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`,
|
||||||
|
wantFallback: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "400 Invalid URL does not fallback",
|
||||||
|
statusCode: http.StatusBadRequest,
|
||||||
|
respBody: `{"error":{"message":"Invalid URL (POST /v1/messages/count_tokens)","type":"invalid_request_error"}}`,
|
||||||
|
wantFallback: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "400 model error does not fallback",
|
||||||
|
statusCode: http.StatusBadRequest,
|
||||||
|
respBody: `{"error":{"message":"model not found: claude-unknown","type":"invalid_request_error"}}`,
|
||||||
|
wantFallback: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "500 internal error does not fallback",
|
||||||
|
statusCode: http.StatusInternalServerError,
|
||||||
|
respBody: `{"error":{"message":"internal error","type":"api_error"}}`,
|
||||||
|
wantFallback: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||||
|
|
||||||
|
body := []byte(`{"model":"claude-sonnet-4-5-20250929","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
parsed := &ParsedRequest{Body: body, Model: "claude-sonnet-4-5-20250929"}
|
||||||
|
|
||||||
|
upstream := &anthropicHTTPUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: tt.statusCode,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(tt.respBody)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{
|
||||||
|
Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize},
|
||||||
|
},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 200,
|
||||||
|
Name: "proxy-acc",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "sk-proxy",
|
||||||
|
"base_url": "https://proxy.example.com",
|
||||||
|
},
|
||||||
|
Extra: map[string]any{"anthropic_passthrough": true},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||||
|
|
||||||
|
if tt.wantFallback {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.JSONEq(t, `{"input_tokens":0}`, rec.Body.String())
|
||||||
|
} else {
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, tt.statusCode, rec.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayService_AnthropicAPIKeyPassthrough_BuildRequestRejectsInvalidBaseURL(t *testing.T) {
|
func TestGatewayService_AnthropicAPIKeyPassthrough_BuildRequestRejectsInvalidBaseURL(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
|
|||||||
@@ -6199,6 +6199,16 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex
|
|||||||
|
|
||||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
|
||||||
|
// 中转站不支持 count_tokens 端点时(404),降级返回空值,客户端会 fallback 到本地估算。
|
||||||
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
|
logger.LegacyPrintf("service.gateway",
|
||||||
|
"[count_tokens] Upstream does not support count_tokens (404), returning fallback: account=%d name=%s msg=%s",
|
||||||
|
account.ID, account.Name, truncateString(upstreamMsg, 512))
|
||||||
|
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
upstreamDetail := ""
|
upstreamDetail := ""
|
||||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
|||||||
Reference in New Issue
Block a user