Add invalid-request fallback routing
This commit is contained in:
@@ -108,6 +108,8 @@ type CreateGroupInput struct {
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool // 是否启用模型路由
|
||||
@@ -130,6 +132,8 @@ type UpdateGroupInput struct {
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||
@@ -572,24 +576,35 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest
|
||||
if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 {
|
||||
fallbackOnInvalidRequest = nil
|
||||
}
|
||||
// 校验无效请求兜底分组
|
||||
if fallbackOnInvalidRequest != nil {
|
||||
if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Platform: platform,
|
||||
RateMultiplier: input.RateMultiplier,
|
||||
IsExclusive: input.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: subscriptionType,
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
ModelRouting: input.ModelRouting,
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Platform: platform,
|
||||
RateMultiplier: input.RateMultiplier,
|
||||
IsExclusive: input.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: subscriptionType,
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
||||
ModelRouting: input.ModelRouting,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -651,6 +666,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
|
||||
}
|
||||
}
|
||||
|
||||
// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性
|
||||
// currentGroupID: 当前分组 ID(新建时为 0)
|
||||
// platform/subscriptionType: 当前分组的有效平台/订阅类型
|
||||
// fallbackGroupID: 兜底分组 ID
|
||||
func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error {
|
||||
if platform != PlatformAnthropic && platform != PlatformAntigravity {
|
||||
return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups")
|
||||
}
|
||||
if subscriptionType == SubscriptionTypeSubscription {
|
||||
return fmt.Errorf("subscription groups cannot set invalid request fallback")
|
||||
}
|
||||
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
|
||||
return fmt.Errorf("cannot set self as invalid request fallback group")
|
||||
}
|
||||
|
||||
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback group not found: %w", err)
|
||||
}
|
||||
if fallbackGroup.Platform != PlatformAnthropic {
|
||||
return fmt.Errorf("fallback group must be anthropic platform")
|
||||
}
|
||||
if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription {
|
||||
return fmt.Errorf("fallback group cannot be subscription type")
|
||||
}
|
||||
if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
|
||||
return fmt.Errorf("fallback group cannot have invalid request fallback configured")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -717,6 +763,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
group.FallbackGroupID = nil
|
||||
}
|
||||
}
|
||||
fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest
|
||||
if input.FallbackGroupIDOnInvalidRequest != nil {
|
||||
if *input.FallbackGroupIDOnInvalidRequest > 0 {
|
||||
fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest
|
||||
} else {
|
||||
fallbackOnInvalidRequest = nil
|
||||
}
|
||||
}
|
||||
if fallbackOnInvalidRequest != nil {
|
||||
if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest
|
||||
|
||||
// 模型路由配置
|
||||
if input.ModelRouting != nil {
|
||||
|
||||
@@ -378,3 +378,374 @@ func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
type groupRepoStubForInvalidRequestFallback struct {
|
||||
groups map[int64]*Group
|
||||
created *Group
|
||||
updated *Group
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error {
|
||||
s.created = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error {
|
||||
s.updated = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
return s.GetByIDLite(ctx, id)
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||
if g, ok := s.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformOpenAI,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeSubscription,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fallback *Group
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "openai_target",
|
||||
fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard},
|
||||
wantMessage: "fallback group must be anthropic platform",
|
||||
},
|
||||
{
|
||||
name: "antigravity_target",
|
||||
fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard},
|
||||
wantMessage: "fallback group must be anthropic platform",
|
||||
},
|
||||
{
|
||||
name: "subscription_group",
|
||||
fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
|
||||
wantMessage: "fallback group cannot be subscription type",
|
||||
},
|
||||
{
|
||||
name: "nested_fallback",
|
||||
fallback: &Group{
|
||||
ID: 10,
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(),
|
||||
},
|
||||
wantMessage: "fallback group cannot have invalid request fallback configured",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
fallbackID := tc.fallback.ID
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: tc.fallback,
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tc.wantMessage)
|
||||
require.Nil(t, repo.created)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group not found")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAntigravity,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.created)
|
||||
require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
|
||||
zero := int64(0)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &zero,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.created)
|
||||
require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
Platform: PlatformOpenAI,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
SubscriptionType: SubscriptionTypeSubscription,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
clear := int64(0)
|
||||
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
Platform: PlatformOpenAI,
|
||||
FallbackGroupIDOnInvalidRequest: &clear,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group cannot be subscription type")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAntigravity,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
@@ -62,6 +62,17 @@ type antigravityRetryLoopResult struct {
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
// PromptTooLongError 表示上游明确返回 prompt too long
|
||||
type PromptTooLongError struct {
|
||||
StatusCode int
|
||||
RequestID string
|
||||
Body []byte
|
||||
}
|
||||
|
||||
func (e *PromptTooLongError) Error() string {
|
||||
return fmt.Sprintf("prompt too long: status=%d", e.StatusCode)
|
||||
}
|
||||
|
||||
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
||||
func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
@@ -930,6 +941,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
// 处理错误响应(重试后仍失败或不触发重试)
|
||||
if resp.StatusCode >= 400 {
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
log.Printf("%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s", prefix, isPromptTooLongError(respBody), upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, 500))
|
||||
}
|
||||
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||
maxBytes := 2048
|
||||
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
}
|
||||
upstreamDetail := ""
|
||||
if logBody {
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "prompt_too_long",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &PromptTooLongError{
|
||||
StatusCode: resp.StatusCode,
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Body: respBody,
|
||||
}
|
||||
}
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
@@ -1019,21 +1063,55 @@ func isSignatureRelatedError(respBody []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func isPromptTooLongError(respBody []byte) bool {
|
||||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||
if msg == "" {
|
||||
msg = strings.ToLower(string(respBody))
|
||||
}
|
||||
return strings.Contains(msg, "prompt is too long")
|
||||
}
|
||||
|
||||
func extractAntigravityErrorMessage(body []byte) string {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
parseNestedMessage := func(msg string) string {
|
||||
trimmed := strings.TrimSpace(msg)
|
||||
if trimmed == "" || !strings.HasPrefix(trimmed, "{") {
|
||||
return ""
|
||||
}
|
||||
var nested map[string]any
|
||||
if err := json.Unmarshal([]byte(trimmed), &nested); err != nil {
|
||||
return ""
|
||||
}
|
||||
if errObj, ok := nested["error"].(map[string]any); ok {
|
||||
if innerMsg, ok := errObj["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
|
||||
return innerMsg
|
||||
}
|
||||
}
|
||||
if innerMsg, ok := nested["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
|
||||
return innerMsg
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Google-style: {"error": {"message": "..."}}
|
||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
||||
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
|
||||
return innerMsg
|
||||
}
|
||||
return msg
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: top-level message
|
||||
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
|
||||
return innerMsg
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
@@ -2209,6 +2287,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
|
||||
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
|
||||
return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
||||
statusStr := "UNKNOWN"
|
||||
switch status {
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -81,3 +87,77 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
|
||||
require.Equal(t, "secret plan", blocks[0]["text"])
|
||||
require.Equal(t, "tool_use", blocks[1]["type"])
|
||||
}
|
||||
|
||||
func TestIsPromptTooLongError(t *testing.T) {
|
||||
require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`)))
|
||||
require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`)))
|
||||
require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`)))
|
||||
}
|
||||
|
||||
type httpUpstreamStub struct {
|
||||
resp *http.Response
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||
return s.resp, s.err
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"model": "claude-opus-4-5",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "hi"},
|
||||
},
|
||||
"max_tokens": 1,
|
||||
"stream": false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
c.Request = req
|
||||
|
||||
respBody := []byte(`{"error":{"message":"Prompt is too long"}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{"X-Request-Id": []string{"req-1"}},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "acc-1",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.Nil(t, result)
|
||||
|
||||
var promptErr *PromptTooLongError
|
||||
require.ErrorAs(t, err, &promptErr)
|
||||
require.Equal(t, http.StatusBadRequest, promptErr.StatusCode)
|
||||
require.Equal(t, "req-1", promptErr.RequestID)
|
||||
require.NotEmpty(t, promptErr.Body)
|
||||
|
||||
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.Len(t, events, 1)
|
||||
require.Equal(t, "prompt_too_long", events[0].Kind)
|
||||
}
|
||||
|
||||
@@ -23,20 +23,21 @@ type APIKeyAuthUserSnapshot struct {
|
||||
|
||||
// APIKeyAuthGroupSnapshot 分组快照
|
||||
type APIKeyAuthGroupSnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Status string `json:"status"`
|
||||
SubscriptionType string `json:"subscription_type"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Status string `json:"status"`
|
||||
SubscriptionType string `json:"subscription_type"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||
|
||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||
// Only anthropic groups use these fields; others may leave them empty.
|
||||
|
||||
@@ -207,22 +207,23 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
}
|
||||
if apiKey.Group != nil {
|
||||
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||
ID: apiKey.Group.ID,
|
||||
Name: apiKey.Group.Name,
|
||||
Platform: apiKey.Group.Platform,
|
||||
Status: apiKey.Group.Status,
|
||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
ModelRouting: apiKey.Group.ModelRouting,
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
ID: apiKey.Group.ID,
|
||||
Name: apiKey.Group.Name,
|
||||
Platform: apiKey.Group.Platform,
|
||||
Status: apiKey.Group.Status,
|
||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: apiKey.Group.ModelRouting,
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
}
|
||||
}
|
||||
return snapshot
|
||||
@@ -250,23 +251,24 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
}
|
||||
if snapshot.Group != nil {
|
||||
apiKey.Group = &Group{
|
||||
ID: snapshot.Group.ID,
|
||||
Name: snapshot.Group.Name,
|
||||
Platform: snapshot.Group.Platform,
|
||||
Status: snapshot.Group.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
ModelRouting: snapshot.Group.ModelRouting,
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
ID: snapshot.Group.ID,
|
||||
Name: snapshot.Group.Name,
|
||||
Platform: snapshot.Group.Platform,
|
||||
Status: snapshot.Group.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: snapshot.Group.ModelRouting,
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
}
|
||||
}
|
||||
return apiKey
|
||||
|
||||
@@ -55,6 +55,15 @@ func shortSessionHash(sessionHash string) string {
|
||||
return sessionHash[:8]
|
||||
}
|
||||
|
||||
func normalizeClaudeModelForAnthropic(requestedModel string) string {
|
||||
for _, prefix := range anthropicPrefixMappings {
|
||||
if strings.HasPrefix(requestedModel, prefix) {
|
||||
return prefix
|
||||
}
|
||||
}
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var (
|
||||
@@ -71,6 +80,12 @@ var (
|
||||
"You are a file search specialist for Claude Code", // Explore Agent 版
|
||||
"You are a helpful AI assistant tasked with summarizing conversations", // Compact 版
|
||||
}
|
||||
|
||||
anthropicPrefixMappings = []string{
|
||||
"claude-opus-4-5",
|
||||
"claude-haiku-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
}
|
||||
)
|
||||
|
||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||
@@ -951,6 +966,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
|
||||
return s.resolveGroupByID(ctx, groupID)
|
||||
}
|
||||
|
||||
func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
|
||||
if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
|
||||
return nil
|
||||
@@ -1016,7 +1035,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID
|
||||
}
|
||||
|
||||
// 强制平台模式不检查 Claude Code 限制
|
||||
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
|
||||
if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" {
|
||||
return nil, groupID, nil
|
||||
}
|
||||
|
||||
@@ -1719,6 +1738,9 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
// Antigravity 平台使用专门的模型支持检查
|
||||
return IsAntigravityModelSupported(requestedModel)
|
||||
}
|
||||
if account.Platform == PlatformAnthropic {
|
||||
requestedModel = normalizeClaudeModelForAnthropic(requestedModel)
|
||||
}
|
||||
// 其他平台使用账户的模型支持检查
|
||||
return account.IsModelSupported(requestedModel)
|
||||
}
|
||||
@@ -2115,17 +2137,29 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
body = enforceCacheControlLimit(body)
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
// 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射)
|
||||
originalModel := reqModel
|
||||
mappedModel := reqModel
|
||||
mappingSource := ""
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
mappedModel = account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
// 替换请求体中的模型名
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
|
||||
mappingSource = "account"
|
||||
}
|
||||
}
|
||||
if mappingSource == "" && account.Platform == PlatformAnthropic {
|
||||
normalized := normalizeClaudeModelForAnthropic(reqModel)
|
||||
if normalized != reqModel {
|
||||
mappedModel = normalized
|
||||
mappingSource = "prefix"
|
||||
}
|
||||
}
|
||||
if mappedModel != reqModel {
|
||||
// 替换请求体中的模型名
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
|
||||
}
|
||||
|
||||
// 获取凭证
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
@@ -3426,16 +3460,28 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对 apikey 类型账号)
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if reqModel != "" {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
// 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射)
|
||||
if reqModel != "" {
|
||||
mappedModel := reqModel
|
||||
mappingSource := ""
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
|
||||
mappingSource = "account"
|
||||
}
|
||||
}
|
||||
if mappingSource == "" && account.Platform == PlatformAnthropic {
|
||||
normalized := normalizeClaudeModelForAnthropic(reqModel)
|
||||
if normalized != reqModel {
|
||||
mappedModel = normalized
|
||||
mappingSource = "prefix"
|
||||
}
|
||||
}
|
||||
if mappedModel != reqModel {
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取凭证
|
||||
|
||||
@@ -29,6 +29,8 @@ type Group struct {
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool
|
||||
FallbackGroupID *int64
|
||||
// 无效请求兜底分组(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
|
||||
// 模型路由配置
|
||||
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
|
||||
|
||||
Reference in New Issue
Block a user