fix(antigravity): bill with mapped model and use final model key for rate limiting
- Use mapped model (billingModel) instead of original request model for billing - Use resolveFinalAntigravityModelKey for 429 rate limit model key, ensuring rate limit records match the actual upstream model - Add regression tests for both fixes
This commit is contained in:
@@ -87,7 +87,6 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
|
|
||||||
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
|
antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL"
|
||||||
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
||||||
)
|
)
|
||||||
@@ -1309,6 +1308,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
|
// 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本
|
||||||
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
|
thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive")
|
||||||
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
||||||
|
billingModel := mappedModel
|
||||||
|
|
||||||
// 获取 access_token
|
// 获取 access_token
|
||||||
if s.tokenProvider == nil {
|
if s.tokenProvider == nil {
|
||||||
@@ -1622,7 +1622,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
|||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
Model: originalModel, // 使用原始模型用于计费和日志
|
Model: billingModel, // 使用映射模型用于计费和日志
|
||||||
Stream: claudeReq.Stream,
|
Stream: claudeReq.Stream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
@@ -1976,6 +1976,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
if mappedModel == "" {
|
if mappedModel == "" {
|
||||||
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
|
return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel))
|
||||||
}
|
}
|
||||||
|
billingModel := mappedModel
|
||||||
|
|
||||||
// 获取 access_token
|
// 获取 access_token
|
||||||
if s.tokenProvider == nil {
|
if s.tokenProvider == nil {
|
||||||
@@ -2205,7 +2206,7 @@ handleSuccess:
|
|||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Usage: *usage,
|
Usage: *usage,
|
||||||
Model: originalModel,
|
Model: billingModel,
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
Duration: time.Since(startTime),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
@@ -2650,7 +2651,16 @@ func (s *AntigravityGatewayService) handleUpstreamError(
|
|||||||
defaultDur := s.getDefaultRateLimitDuration()
|
defaultDur := s.getDefaultRateLimitDuration()
|
||||||
|
|
||||||
// 尝试解析模型 key 并设置模型级限流
|
// 尝试解析模型 key 并设置模型级限流
|
||||||
modelKey := resolveAntigravityModelKey(requestedModel)
|
//
|
||||||
|
// 注意:requestedModel 可能是"映射前"的请求模型名(例如 claude-opus-4-6),
|
||||||
|
// 调度与限流判定使用的是 Antigravity 最终模型名(包含映射与 thinking 后缀)。
|
||||||
|
// 因此这里必须写入最终模型 key,确保后续调度能正确避开已限流模型。
|
||||||
|
modelKey := resolveFinalAntigravityModelKey(ctx, account, requestedModel)
|
||||||
|
if strings.TrimSpace(modelKey) == "" {
|
||||||
|
// 极少数情况下无法映射(理论上不应发生:能转发成功说明映射已通过),
|
||||||
|
// 保持旧行为作为兜底,避免完全丢失模型级限流记录。
|
||||||
|
modelKey = resolveAntigravityModelKey(requestedModel)
|
||||||
|
}
|
||||||
if modelKey != "" {
|
if modelKey != "" {
|
||||||
ra := s.resolveResetTime(resetAt, defaultDur)
|
ra := s.resolveResetTime(resetAt, defaultDur)
|
||||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
|
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil {
|
||||||
@@ -3889,7 +3899,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
|||||||
return nil, fmt.Errorf("missing model")
|
return nil, fmt.Errorf("missing model")
|
||||||
}
|
}
|
||||||
originalModel := claudeReq.Model
|
originalModel := claudeReq.Model
|
||||||
billingModel := originalModel
|
|
||||||
|
|
||||||
// 构建上游请求 URL
|
// 构建上游请求 URL
|
||||||
upstreamURL := baseURL + "/v1/messages"
|
upstreamURL := baseURL + "/v1/messages"
|
||||||
@@ -3942,7 +3951,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
|||||||
_, _ = c.Writer.Write(respBody)
|
_, _ = c.Writer.Write(respBody)
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
Model: billingModel,
|
Model: originalModel,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3983,7 +3992,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
|||||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||||||
|
|
||||||
return &ForwardResult{
|
return &ForwardResult{
|
||||||
Model: billingModel,
|
Model: originalModel,
|
||||||
Stream: claudeReq.Stream,
|
Stream: claudeReq.Stream,
|
||||||
Duration: duration,
|
Duration: duration,
|
||||||
FirstTokenMs: firstTokenMs,
|
FirstTokenMs: firstTokenMs,
|
||||||
|
|||||||
@@ -134,6 +134,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int,
|
|||||||
return s.resp, s.err
|
return s.resp, s.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type antigravitySettingRepoStub struct{}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||||
|
panic("unexpected Get call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||||
|
return "", ErrSettingNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) Set(ctx context.Context, key, value string) error {
|
||||||
|
panic("unexpected Set call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||||
|
panic("unexpected GetMultiple call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||||
|
panic("unexpected SetMultiple call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||||
|
panic("unexpected GetAll call")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *antigravitySettingRepoStub) Delete(ctx context.Context, key string) error {
|
||||||
|
panic("unexpected Delete call")
|
||||||
|
}
|
||||||
|
|
||||||
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
writer := httptest.NewRecorder()
|
writer := httptest.NewRecorder()
|
||||||
@@ -160,8 +190,9 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
svc := &AntigravityGatewayService{
|
svc := &AntigravityGatewayService{
|
||||||
tokenProvider: &AntigravityTokenProvider{},
|
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||||
}
|
}
|
||||||
|
|
||||||
account := &Account{
|
account := &Account{
|
||||||
@@ -418,6 +449,113 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(
|
|||||||
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAntigravityGatewayService_Forward_BillsWithMappedModel
|
||||||
|
// 验证:Antigravity Claude 转发返回的计费模型使用映射后的模型
|
||||||
|
func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
writer := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(writer)
|
||||||
|
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"model": "claude-sonnet-4-5",
|
||||||
|
"messages": []map[string]any{
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
},
|
||||||
|
"max_tokens": 16,
|
||||||
|
"stream": true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"X-Request-Id": []string{"req-bill-1"}},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||||
|
}
|
||||||
|
|
||||||
|
const mappedModel = "gemini-3-pro-high"
|
||||||
|
account := &Account{
|
||||||
|
ID: 5,
|
||||||
|
Name: "acc-forward-billing",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"claude-sonnet-4-5": mappedModel,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.Forward(context.Background(), c, account, body, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, mappedModel, result.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel
|
||||||
|
// 验证:Antigravity Gemini 转发返回的计费模型使用映射后的模型
|
||||||
|
func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
writer := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(writer)
|
||||||
|
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"contents": []map[string]any{
|
||||||
|
{"role": "user", "parts": []map[string]any{{"text": "hello"}}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
upstreamBody := []byte("data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n")
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"X-Request-Id": []string{"req-bill-2"}},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(upstreamBody)),
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}),
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||||
|
}
|
||||||
|
|
||||||
|
const mappedModel = "gemini-3-pro-high"
|
||||||
|
account := &Account{
|
||||||
|
ID: 6,
|
||||||
|
Name: "acc-gemini-billing",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "token",
|
||||||
|
"model_mapping": map[string]any{
|
||||||
|
"gemini-2.5-flash": mappedModel,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, mappedModel, result.Model)
|
||||||
|
}
|
||||||
|
|
||||||
// TestStreamUpstreamResponse_UsageAndFirstToken
|
// TestStreamUpstreamResponse_UsageAndFirstToken
|
||||||
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
|
// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间
|
||||||
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) {
|
||||||
|
|||||||
@@ -197,6 +197,22 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) {
|
|||||||
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey 测试 429 非模型限流场景
|
||||||
|
// 验证:requestedModel 会被映射到 Antigravity 最终模型(例如 claude-opus-4-6 -> claude-opus-4-6-thinking)
|
||||||
|
func TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey(t *testing.T) {
|
||||||
|
repo := &stubAntigravityAccountRepo{}
|
||||||
|
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||||
|
account := &Account{ID: 20, Name: "acc-20", Platform: PlatformAntigravity}
|
||||||
|
|
||||||
|
body := buildGeminiRateLimitBody("5s")
|
||||||
|
|
||||||
|
result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-opus-4-6", 0, "", false)
|
||||||
|
|
||||||
|
require.Nil(t, result)
|
||||||
|
require.Len(t, repo.modelRateLimitCalls, 1)
|
||||||
|
require.Equal(t, "claude-opus-4-6-thinking", repo.modelRateLimitCalls[0].modelKey)
|
||||||
|
}
|
||||||
|
|
||||||
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
|
// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景
|
||||||
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
|
// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号
|
||||||
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
|
func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user