feat: 模型映射应用 /v1/messages/count_tokens端点
This commit is contained in:
@@ -171,8 +171,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
|||||||
require.NotNil(t, result)
|
require.NotNil(t, result)
|
||||||
require.True(t, result.Stream)
|
require.True(t, result.Stream)
|
||||||
|
|
||||||
require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体")
|
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(upstream.lastBody, "model").String(), "透传模式应应用账号级模型映射")
|
||||||
require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String())
|
|
||||||
|
|
||||||
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
||||||
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
||||||
@@ -190,7 +189,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
|
|||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
bodyBytes, ok := rawBody.([]byte)
|
bodyBytes, ok := rawBody.([]byte)
|
||||||
require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝")
|
require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝")
|
||||||
require.Equal(t, body, bodyBytes)
|
require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(bodyBytes, "model").String(), "缓存的上游请求体应包含映射后的模型")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) {
|
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) {
|
||||||
@@ -253,8 +252,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
|||||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体")
|
require.Equal(t, "claude-3-opus-20240229", gjson.GetBytes(upstream.lastBody, "model").String(), "count_tokens 透传模式应应用账号级模型映射")
|
||||||
require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String())
|
|
||||||
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key"))
|
||||||
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
require.Empty(t, upstream.lastReq.Header.Get("authorization"))
|
||||||
require.Empty(t, upstream.lastReq.Header.Get("cookie"))
|
require.Empty(t, upstream.lastReq.Header.Get("cookie"))
|
||||||
@@ -263,6 +261,273 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBo
|
|||||||
require.Empty(t, rec.Header().Get("Set-Cookie"))
|
require.Empty(t, rec.Header().Get("Set-Cookie"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases 覆盖透传模式下模型映射的各种边界情况
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingEdgeCases(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
modelMapping map[string]any // nil = 不配置映射
|
||||||
|
expectedModel string
|
||||||
|
endpoint string // "messages" or "count_tokens"
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Forward: 无映射配置时不改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: nil,
|
||||||
|
expectedModel: "claude-sonnet-4-20250514",
|
||||||
|
endpoint: "messages",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Forward: 空映射配置时不改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{},
|
||||||
|
expectedModel: "claude-sonnet-4-20250514",
|
||||||
|
endpoint: "messages",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Forward: 模型不在映射表中时不改写",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
|
||||||
|
expectedModel: "claude-sonnet-4-20250514",
|
||||||
|
endpoint: "messages",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Forward: 精确匹配映射应改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||||
|
expectedModel: "claude-sonnet-4-5-20241022",
|
||||||
|
endpoint: "messages",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Forward: 通配符映射应改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
|
||||||
|
expectedModel: "claude-sonnet-4-5-20241022",
|
||||||
|
endpoint: "messages",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CountTokens: 无映射配置时不改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: nil,
|
||||||
|
expectedModel: "claude-sonnet-4-20250514",
|
||||||
|
endpoint: "count_tokens",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CountTokens: 模型不在映射表中时不改写",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{"claude-3-haiku-20240307": "claude-3-opus-20240229"},
|
||||||
|
expectedModel: "claude-sonnet-4-20250514",
|
||||||
|
endpoint: "count_tokens",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CountTokens: 精确匹配映射应改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||||
|
expectedModel: "claude-sonnet-4-5-20241022",
|
||||||
|
endpoint: "count_tokens",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CountTokens: 通配符映射应改写模型",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
modelMapping: map[string]any{"claude-sonnet-4-*": "claude-sonnet-4-5-20241022"},
|
||||||
|
expectedModel: "claude-sonnet-4-5-20241022",
|
||||||
|
endpoint: "count_tokens",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
body := []byte(`{"model":"` + tt.model + `","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Body: body,
|
||||||
|
Model: tt.model,
|
||||||
|
}
|
||||||
|
|
||||||
|
credentials := map[string]any{
|
||||||
|
"api_key": "upstream-key",
|
||||||
|
"base_url": "https://api.anthropic.com",
|
||||||
|
}
|
||||||
|
if tt.modelMapping != nil {
|
||||||
|
credentials["model_mapping"] = tt.modelMapping
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 300,
|
||||||
|
Name: "edge-case-test",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: credentials,
|
||||||
|
Extra: map[string]any{"anthropic_passthrough": true},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.endpoint == "messages" {
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||||
|
parsed.Stream = false
|
||||||
|
|
||||||
|
upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":5,"output_tokens":3}}`
|
||||||
|
upstream := &anthropicHTTPUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamJSON)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
|
||||||
|
"Forward 上游请求体中的模型应为: %s", tt.expectedModel)
|
||||||
|
} else {
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||||
|
|
||||||
|
upstreamRespBody := `{"input_tokens":42}`
|
||||||
|
upstream := &anthropicHTTPUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.expectedModel, gjson.GetBytes(upstream.lastBody, "model").String(),
|
||||||
|
"CountTokens 上游请求体中的模型应为: %s", tt.expectedModel)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields
|
||||||
|
// 确保模型映射只替换 model 字段,不影响请求体中的其他字段
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFields(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||||
|
|
||||||
|
// 包含复杂字段的请求体:system、thinking、messages
|
||||||
|
body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"You are a helpful assistant."}],"messages":[{"role":"user","content":[{"type":"text","text":"hello world"}]}],"thinking":{"type":"enabled","budget_tokens":5000},"max_tokens":1024}`)
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Body: body,
|
||||||
|
Model: "claude-sonnet-4-20250514",
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamRespBody := `{"input_tokens":42}`
|
||||||
|
upstream := &anthropicHTTPUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 301,
|
||||||
|
Name: "preserve-fields-test",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "upstream-key",
|
||||||
|
"base_url": "https://api.anthropic.com",
|
||||||
|
"model_mapping": map[string]any{"claude-sonnet-4-20250514": "claude-sonnet-4-5-20241022"},
|
||||||
|
},
|
||||||
|
Extra: map[string]any{"anthropic_passthrough": true},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sentBody := upstream.lastBody
|
||||||
|
require.Equal(t, "claude-sonnet-4-5-20241022", gjson.GetBytes(sentBody, "model").String(), "model 应被映射")
|
||||||
|
require.Equal(t, "You are a helpful assistant.", gjson.GetBytes(sentBody, "system.0.text").String(), "system 字段不应被修改")
|
||||||
|
require.Equal(t, "hello world", gjson.GetBytes(sentBody, "messages.0.content.0.text").String(), "messages 字段不应被修改")
|
||||||
|
require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String(), "thinking 字段不应被修改")
|
||||||
|
require.Equal(t, int64(5000), gjson.GetBytes(sentBody, "thinking.budget_tokens").Int(), "thinking.budget_tokens 不应被修改")
|
||||||
|
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping
|
||||||
|
// 确保空模型名不会触发映射逻辑
|
||||||
|
func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||||
|
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Body: body,
|
||||||
|
Model: "", // 空模型
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamRespBody := `{"input_tokens":10}`
|
||||||
|
upstream := &anthropicHTTPUpstreamRecorder{
|
||||||
|
resp: &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &GatewayService{
|
||||||
|
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||||
|
httpUpstream: upstream,
|
||||||
|
rateLimitService: &RateLimitService{},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
ID: 302,
|
||||||
|
Name: "empty-model-test",
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"api_key": "upstream-key",
|
||||||
|
"base_url": "https://api.anthropic.com",
|
||||||
|
"model_mapping": map[string]any{"*": "claude-3-opus-20240229"},
|
||||||
|
},
|
||||||
|
Extra: map[string]any{"anthropic_passthrough": true},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// 空模型名时,body 应原样透传,不应触发映射
|
||||||
|
require.Equal(t, body, upstream.lastBody, "空模型名时请求体不应被修改")
|
||||||
|
}
|
||||||
|
|
||||||
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
|
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@@ -3889,7 +3889,16 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
}
|
}
|
||||||
|
|
||||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||||
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime)
|
passthroughBody := parsed.Body
|
||||||
|
passthroughModel := parsed.Model
|
||||||
|
if passthroughModel != "" {
|
||||||
|
if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel {
|
||||||
|
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
|
||||||
|
logger.LegacyPrintf("service.gateway", "Passthrough model mapping: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
|
||||||
|
passthroughModel = mappedModel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
body := parsed.Body
|
body := parsed.Body
|
||||||
@@ -6781,7 +6790,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
|
||||||
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body)
|
passthroughBody := parsed.Body
|
||||||
|
if reqModel := parsed.Model; reqModel != "" {
|
||||||
|
if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel {
|
||||||
|
passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
|
||||||
|
logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
body := parsed.Body
|
body := parsed.Body
|
||||||
|
|||||||
Reference in New Issue
Block a user