From aeb464f3ca27c9bf29ececdc1245fd4ad19883b7 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 5 Mar 2026 14:49:28 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=A8=A1=E5=9E=8B=E6=98=A0=E5=B0=84?= =?UTF-8?q?=E5=BA=94=E7=94=A8=20/v1/messages/count=5Ftokens=E7=AB=AF?= =?UTF-8?q?=E7=82=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...teway_anthropic_apikey_passthrough_test.go | 275 +++++++++++++++++- backend/internal/service/gateway_service.go | 20 +- 2 files changed, 288 insertions(+), 7 deletions(-) diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index f8c0ecda..5dcda1de 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -171,8 +171,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd require.NotNil(t, result) require.True(t, result.Stream) - require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体") - require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(upstream.lastBody, "model").String(), "透传模式应应用账号级模型映射") require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) require.Empty(t, upstream.lastReq.Header.Get("authorization")) @@ -190,7 +189,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd require.True(t, ok) bodyBytes, ok := rawBody.([]byte) require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝") - require.Equal(t, body, bodyBytes) + require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(bodyBytes, "model").String(), "缓存的上游请求体应包含映射后的模型") } func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) { @@ -253,8 +252,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo err := svc.ForwardCountTokens(context.Background(), c, account, parsed) require.NoError(t, err) - require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体") - require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "claude-3-opus-20240229", gjson.GetBytes(upstream.lastBody, "model").String(), "count_tokens 透传模式应应用账号级模型映射") require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) require.Empty(t, upstream.lastReq.Header.Get("authorization")) require.Empty(t, upstream.lastReq.Header.Get("cookie")) @@ -263,6 +261,273 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo require.Empty(t, rec.Header().Get("Set-Cookie")) } +// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases 覆盖透传模式下模型映射的各种边界情况 +func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + model string + modelMapping map[string]any // nil = 不配置映射 + expectedModel string + endpoint string // "messages" or "count_tokens" + }{ + { + name: "Forward: 无映射配置时不改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: nil, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "messages", + }, + { + name: "Forward: 空映射配置时不改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{}, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "messages", + }, + { + name: "Forward: 模型不在映射表中时不改写", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"}, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "messages", + }, + { + name: "Forward: 精确匹配映射应改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"}, + expectedModel: "claude-sonnet-4-5-20241022", + endpoint: "messages", + }, + { + name: "Forward: 通配符映射应改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"}, + expectedModel: "claude-sonnet-4-5-20241022", + endpoint: "messages", + }, + { + name: "CountTokens: 无映射配置时不改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: nil, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "count_tokens", + }, + { + name: "CountTokens: 模型不在映射表中时不改写", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"}, + expectedModel: "claude-sonnet-4-20250514", + endpoint: "count_tokens", + }, + { + name: "CountTokens: 精确匹配映射应改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"}, + expectedModel: "claude-sonnet-4-5-20241022", + endpoint: "count_tokens", + }, + { + name: "CountTokens: 通配符映射应改写模型", + model: "claude-sonnet-4-20250514", + modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"}, + expectedModel: "claude-sonnet-4-5-20241022", + endpoint: "count_tokens", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + body := []byte(`{"model":"` + tt.model + `","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + parsed := &ParsedRequest{ + Body: body, + Model: tt.model, + } + + credentials := map[string]any{ + "api_key": "upstream-key", + "base_url": "https://api.anthropic.com", + } + if tt.modelMapping != nil { + credentials["model_mapping"] = tt.modelMapping + } + + account := &Account{ + ID: 300, + Name: "edge-case-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: credentials, + Extra: map[string]any{"anthropic_passthrough": true}, + Status: StatusActive, + Schedulable: true, + } + + if tt.endpoint == "messages" { + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + parsed.Stream = false + + upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":5,"output_tokens":3}}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(upstreamJSON)), + }, + } + svc := &GatewayService{ + cfg: &config.Config{}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + result, err := svc.Forward(context.Background(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(), + "Forward 上游请求体中的模型应为: %s", tt.expectedModel) + } else { + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + + upstreamRespBody := `{"input_tokens":42}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(upstreamRespBody)), + }, + } + svc := &GatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + require.NoError(t, err) + require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(), + "CountTokens 上游请求体中的模型应为: %s", tt.expectedModel) + } + }) + } +} + +// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields +// 确保模型映射只替换 model 字段,不影响请求体中的其他字段 +func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + + // 包含复杂字段的请求体:system、thinking、messages + body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"You are a helpful assistant."}],"messages":[{"role":"user","content":[{"type":"text","text":"hello world"}]}],"thinking":{"type":"enabled","budget_tokens":5000},"max_tokens":1024}`) + parsed := &ParsedRequest{ + Body: body, + Model: "claude-sonnet-4-20250514", + } + + upstreamRespBody := `{"input_tokens":42}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(upstreamRespBody)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 301, + Name: "preserve-fields-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-key", + "base_url": "https://api.anthropic.com", + "model_mapping": map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"}, + }, + Extra: map[string]any{"anthropic_passthrough": true}, + Status: StatusActive, + Schedulable: true, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + require.NoError(t, err) + + sentBody := upstream.lastBody + require.Equal(t, "claude-sonnet-4-5-20241022", gjson.GetBytes(sentBody, "model").String(), "model 应被映射") + require.Equal(t, "You are a helpful assistant.", gjson.GetBytes(sentBody, "system.0.text").String(), "system 字段不应被修改") + require.Equal(t, "hello world", gjson.GetBytes(sentBody, "messages.0.content.0.text").String(), "messages 字段不应被修改") + require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String(), "thinking 字段不应被修改") + require.Equal(t, int64(5000), gjson.GetBytes(sentBody, "thinking.budget_tokens").Int(), "thinking.budget_tokens 不应被修改") + require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改") +} + +// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping +// 确保空模型名不会触发映射逻辑 +func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + + body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`) + parsed := &ParsedRequest{ + Body: body, + Model: "", // 空模型 + } + + upstreamRespBody := `{"input_tokens":10}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(upstreamRespBody)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 302, + Name: "empty-model-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-key", + "base_url": "https://api.anthropic.com", + "model_mapping": map[string]any{"*": "claude-3-opus-20240229"}, + }, + Extra: map[string]any{"anthropic_passthrough": true}, + Status: StatusActive, + Schedulable: true, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + require.NoError(t, err) + // 空模型名时,body 应原样透传,不应触发映射 + require.Equal(t, body, upstream.lastBody, "空模型名时请求体不应被修改") +} + func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 02f9a6a3..d26ed24e 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3889,7 +3889,16 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { - return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime) + passthroughBody := parsed.Body + passthroughModel := parsed.Model + if passthroughModel != "" { + if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel { + passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel) + logger.LegacyPrintf("service.gateway", "Passthrough model mapping: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name) + passthroughModel = mappedModel + } + } + return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime) } body := parsed.Body @@ -6781,7 +6790,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { - return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body) + passthroughBody := parsed.Body + if reqModel := parsed.Model; reqModel != "" { + if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel { + passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel) + logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name) + } + } + return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody) } body := parsed.Body