Merge branch 'main' into test
This commit is contained in:
@@ -116,9 +116,14 @@ type CreateGroupInput struct {
|
||||
SoraVideoPricePerRequestHD *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool // 是否启用模型路由
|
||||
MCPXMLInject *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string
|
||||
// 从指定分组复制账号(创建分组后在同一事务内绑定)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -145,9 +150,14 @@ type UpdateGroupInput struct {
|
||||
SoraVideoPricePerRequestHD *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||
MCPXMLInject *bool
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes *[]string
|
||||
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
|
||||
CopyAccountsFromGroupIDs []int64
|
||||
}
|
||||
@@ -611,6 +621,22 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest
|
||||
if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 {
|
||||
fallbackOnInvalidRequest = nil
|
||||
}
|
||||
// 校验无效请求兜底分组
|
||||
if fallbackOnInvalidRequest != nil {
|
||||
if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// MCPXMLInject:默认为 true,仅当显式传入 false 时关闭
|
||||
mcpXMLInject := true
|
||||
if input.MCPXMLInject != nil {
|
||||
mcpXMLInject = *input.MCPXMLInject
|
||||
}
|
||||
|
||||
// 如果指定了复制账号的源分组,先获取账号 ID 列表
|
||||
var accountIDsToCopy []int64
|
||||
@@ -645,26 +671,29 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Platform: platform,
|
||||
RateMultiplier: input.RateMultiplier,
|
||||
IsExclusive: input.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: subscriptionType,
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
SoraImagePrice360: soraImagePrice360,
|
||||
SoraImagePrice540: soraImagePrice540,
|
||||
SoraVideoPricePerRequest: soraVideoPrice,
|
||||
SoraVideoPricePerRequestHD: soraVideoPriceHD,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
ModelRouting: input.ModelRouting,
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Platform: platform,
|
||||
RateMultiplier: input.RateMultiplier,
|
||||
IsExclusive: input.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: subscriptionType,
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
SoraImagePrice360: soraImagePrice360,
|
||||
SoraImagePrice540: soraImagePrice540,
|
||||
SoraVideoPricePerRequest: soraVideoPrice,
|
||||
SoraVideoPricePerRequestHD: soraVideoPriceHD,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
|
||||
ModelRouting: input.ModelRouting,
|
||||
MCPXMLInject: mcpXMLInject,
|
||||
SupportedModelScopes: input.SupportedModelScopes,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -735,6 +764,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
|
||||
}
|
||||
}
|
||||
|
||||
// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性
|
||||
// currentGroupID: 当前分组 ID(新建时为 0)
|
||||
// platform/subscriptionType: 当前分组的有效平台/订阅类型
|
||||
// fallbackGroupID: 兜底分组 ID
|
||||
func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error {
|
||||
if platform != PlatformAnthropic && platform != PlatformAntigravity {
|
||||
return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups")
|
||||
}
|
||||
if subscriptionType == SubscriptionTypeSubscription {
|
||||
return fmt.Errorf("subscription groups cannot set invalid request fallback")
|
||||
}
|
||||
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
|
||||
return fmt.Errorf("cannot set self as invalid request fallback group")
|
||||
}
|
||||
|
||||
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback group not found: %w", err)
|
||||
}
|
||||
if fallbackGroup.Platform != PlatformAnthropic {
|
||||
return fmt.Errorf("fallback group must be anthropic platform")
|
||||
}
|
||||
if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription {
|
||||
return fmt.Errorf("fallback group cannot be subscription type")
|
||||
}
|
||||
if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil {
|
||||
return fmt.Errorf("fallback group cannot have invalid request fallback configured")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -813,6 +873,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
group.FallbackGroupID = nil
|
||||
}
|
||||
}
|
||||
fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest
|
||||
if input.FallbackGroupIDOnInvalidRequest != nil {
|
||||
if *input.FallbackGroupIDOnInvalidRequest > 0 {
|
||||
fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest
|
||||
} else {
|
||||
fallbackOnInvalidRequest = nil
|
||||
}
|
||||
}
|
||||
if fallbackOnInvalidRequest != nil {
|
||||
if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest
|
||||
|
||||
// 模型路由配置
|
||||
if input.ModelRouting != nil {
|
||||
@@ -821,6 +895,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.ModelRoutingEnabled != nil {
|
||||
group.ModelRoutingEnabled = *input.ModelRoutingEnabled
|
||||
}
|
||||
if input.MCPXMLInject != nil {
|
||||
group.MCPXMLInject = *input.MCPXMLInject
|
||||
}
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
if input.SupportedModelScopes != nil {
|
||||
group.SupportedModelScopes = *input.SupportedModelScopes
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -394,3 +394,382 @@ func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _
|
||||
func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
type groupRepoStubForInvalidRequestFallback struct {
|
||||
groups map[int64]*Group
|
||||
created *Group
|
||||
updated *Group
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error {
|
||||
s.created = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error {
|
||||
s.updated = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
return s.GetByIDLite(ctx, id)
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||
if g, ok := s.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformOpenAI,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeSubscription,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fallback *Group
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "openai_target",
|
||||
fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard},
|
||||
wantMessage: "fallback group must be anthropic platform",
|
||||
},
|
||||
{
|
||||
name: "antigravity_target",
|
||||
fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard},
|
||||
wantMessage: "fallback group must be anthropic platform",
|
||||
},
|
||||
{
|
||||
name: "subscription_group",
|
||||
fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
|
||||
wantMessage: "fallback group cannot be subscription type",
|
||||
},
|
||||
{
|
||||
name: "nested_fallback",
|
||||
fallback: &Group{
|
||||
ID: 10,
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(),
|
||||
},
|
||||
wantMessage: "fallback group cannot have invalid request fallback configured",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
fallbackID := tc.fallback.ID
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: tc.fallback,
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tc.wantMessage)
|
||||
require.Nil(t, repo.created)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group not found")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAntigravity,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.created)
|
||||
require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
|
||||
zero := int64(0)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &zero,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.created)
|
||||
require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
Platform: PlatformOpenAI,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
SubscriptionType: SubscriptionTypeSubscription,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
clear := int64(0)
|
||||
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
Platform: PlatformOpenAI,
|
||||
FallbackGroupIDOnInvalidRequest: &clear,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group cannot be subscription type")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAntigravity,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
@@ -13,23 +13,34 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityStickySessionTTL = time.Hour
|
||||
antigravityMaxRetries = 3
|
||||
antigravityRetryBaseDelay = 1 * time.Second
|
||||
antigravityRetryMaxDelay = 16 * time.Second
|
||||
antigravityStickySessionTTL = time.Hour
|
||||
antigravityDefaultMaxRetries = 3
|
||||
antigravityRetryBaseDelay = 1 * time.Second
|
||||
antigravityRetryMaxDelay = 16 * time.Second
|
||||
)
|
||||
|
||||
const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
|
||||
const (
|
||||
antigravityMaxRetriesEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES"
|
||||
antigravityMaxRetriesAfterSwitchEnv = "GATEWAY_ANTIGRAVITY_AFTER_SWITCHMAX_RETRIES"
|
||||
antigravityMaxRetriesClaudeEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE"
|
||||
antigravityMaxRetriesGeminiTextEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT"
|
||||
antigravityMaxRetriesGeminiImageEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE"
|
||||
antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
|
||||
antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL"
|
||||
antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS"
|
||||
)
|
||||
|
||||
// antigravityRetryLoopParams 重试循环的参数
|
||||
type antigravityRetryLoopParams struct {
|
||||
@@ -41,6 +52,7 @@ type antigravityRetryLoopParams struct {
|
||||
action string
|
||||
body []byte
|
||||
quotaScope AntigravityQuotaScope
|
||||
maxRetries int
|
||||
c *gin.Context
|
||||
httpUpstream HTTPUpstream
|
||||
settingService *SettingService
|
||||
@@ -52,11 +64,28 @@ type antigravityRetryLoopResult struct {
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
// PromptTooLongError 表示上游明确返回 prompt too long
|
||||
type PromptTooLongError struct {
|
||||
StatusCode int
|
||||
RequestID string
|
||||
Body []byte
|
||||
}
|
||||
|
||||
func (e *PromptTooLongError) Error() string {
|
||||
return fmt.Sprintf("prompt too long: status=%d", e.StatusCode)
|
||||
}
|
||||
|
||||
// antigravityRetryLoop 执行带 URL fallback 的重试循环
|
||||
func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
|
||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
baseURLs := antigravity.ForwardBaseURLs()
|
||||
availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLsWithBase(baseURLs)
|
||||
if len(availableURLs) == 0 {
|
||||
availableURLs = antigravity.BaseURLs
|
||||
availableURLs = baseURLs
|
||||
}
|
||||
|
||||
maxRetries := p.maxRetries
|
||||
if maxRetries <= 0 {
|
||||
maxRetries = antigravityDefaultMaxRetries
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
@@ -76,7 +105,7 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe
|
||||
urlFallbackLoop:
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
usedBaseURL = baseURL
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
|
||||
@@ -109,8 +138,8 @@ urlFallbackLoop:
|
||||
log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
|
||||
continue urlFallbackLoop
|
||||
}
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
|
||||
if attempt < maxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, maxRetries, err)
|
||||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||||
return nil, p.ctx.Err()
|
||||
@@ -134,7 +163,7 @@ urlFallbackLoop:
|
||||
}
|
||||
|
||||
// 账户/模型配额限流,重试 3 次(指数退避)
|
||||
if attempt < antigravityMaxRetries {
|
||||
if attempt < maxRetries {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||||
@@ -147,7 +176,7 @@ urlFallbackLoop:
|
||||
Message: upstreamMsg,
|
||||
Detail: getUpstreamDetail(respBody),
|
||||
})
|
||||
log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
|
||||
log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, maxRetries, truncateForLog(respBody, 200))
|
||||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||||
return nil, p.ctx.Err()
|
||||
@@ -171,7 +200,7 @@ urlFallbackLoop:
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
if attempt < maxRetries {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
|
||||
@@ -184,7 +213,7 @@ urlFallbackLoop:
|
||||
Message: upstreamMsg,
|
||||
Detail: getUpstreamDetail(respBody),
|
||||
})
|
||||
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
|
||||
log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, maxRetries, truncateForLog(respBody, 500))
|
||||
if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", p.prefix)
|
||||
return nil, p.ctx.Err()
|
||||
@@ -390,6 +419,11 @@ type TestConnectionResult struct {
|
||||
// TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费)
|
||||
// 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择
|
||||
func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
||||
// 上游透传账号使用专用测试方法
|
||||
if account.Type == AccountTypeUpstream {
|
||||
return s.testUpstreamConnection(ctx, account, modelID)
|
||||
}
|
||||
|
||||
// 获取 token
|
||||
if s.tokenProvider == nil {
|
||||
return nil, errors.New("antigravity token provider not configured")
|
||||
@@ -484,6 +518,87 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// testUpstreamConnection 测试上游透传账号连接
|
||||
func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) {
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if baseURL == "" || apiKey == "" {
|
||||
return nil, errors.New("upstream account missing base_url or api_key")
|
||||
}
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
|
||||
// 使用 Claude 模型进行测试
|
||||
if modelID == "" {
|
||||
modelID = "claude-sonnet-4-20250514"
|
||||
}
|
||||
|
||||
// 构建最小测试请求
|
||||
testReq := map[string]any{
|
||||
"model": modelID,
|
||||
"max_tokens": 1,
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "."},
|
||||
},
|
||||
}
|
||||
requestBody, err := json.Marshal(testReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("构建请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建 HTTP 请求
|
||||
upstreamURL := baseURL + "/v1/messages"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("x-api-key", apiKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
// 代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, upstreamURL)
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// 提取响应文本
|
||||
var respData map[string]any
|
||||
text := ""
|
||||
if json.Unmarshal(respBody, &respData) == nil {
|
||||
if content, ok := respData["content"].([]any); ok && len(content) > 0 {
|
||||
if block, ok := content[0].(map[string]any); ok {
|
||||
if t, ok := block["text"].(string); ok {
|
||||
text = t
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &TestConnectionResult{
|
||||
Text: text,
|
||||
MappedModel: modelID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildGeminiTestRequest 构建 Gemini 格式测试请求
|
||||
// 使用最小 token 消耗:输入 "." + maxOutputTokens: 1
|
||||
func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) {
|
||||
@@ -534,6 +649,10 @@ func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Contex
|
||||
}
|
||||
opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx)
|
||||
opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx)
|
||||
|
||||
if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && group != nil {
|
||||
opts.EnableMCPXML = group.MCPXMLInject
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
@@ -702,6 +821,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool {
|
||||
|
||||
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
|
||||
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
// 上游透传账号直接转发,不走 OAuth token 刷新
|
||||
if account.Type == AccountTypeUpstream {
|
||||
return s.ForwardUpstream(ctx, c, account, body)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
sessionID := getSessionID(c)
|
||||
prefix := logPrefix(sessionID, account.Name)
|
||||
@@ -718,6 +842,12 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
originalModel := claudeReq.Model
|
||||
mappedModel := s.getMappedModel(account, claudeReq.Model)
|
||||
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
|
||||
billingModel := originalModel
|
||||
if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" {
|
||||
billingModel = mappedModel
|
||||
}
|
||||
afterSwitch := antigravityHasAccountSwitch(ctx)
|
||||
maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch)
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
@@ -766,6 +896,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
httpUpstream: s.httpUpstream,
|
||||
settingService: s.settingService,
|
||||
handleError: s.handleUpstreamError,
|
||||
maxRetries: maxRetries,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||
@@ -842,6 +973,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
httpUpstream: s.httpUpstream,
|
||||
settingService: s.settingService,
|
||||
handleError: s.handleUpstreamError,
|
||||
maxRetries: maxRetries,
|
||||
})
|
||||
if retryErr != nil {
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
@@ -917,6 +1049,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
// 处理错误响应(重试后仍失败或不触发重试)
|
||||
if resp.StatusCode >= 400 {
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
log.Printf("%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s", prefix, isPromptTooLongError(respBody), upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, 500))
|
||||
}
|
||||
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
|
||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody
|
||||
maxBytes := 2048
|
||||
if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
|
||||
maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||
}
|
||||
upstreamDetail := ""
|
||||
if logBody {
|
||||
upstreamDetail = truncateString(string(respBody), maxBytes)
|
||||
}
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
AccountName: account.Name,
|
||||
UpstreamStatusCode: resp.StatusCode,
|
||||
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||
Kind: "prompt_too_long",
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &PromptTooLongError{
|
||||
StatusCode: resp.StatusCode,
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Body: respBody,
|
||||
}
|
||||
}
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
@@ -978,7 +1143,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Model: billingModel, // 计费模型(可按映射模型覆盖)
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
@@ -1003,24 +1168,64 @@ func isSignatureRelatedError(respBody []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Detect thinking block modification errors:
|
||||
// "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
|
||||
if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isPromptTooLongError(respBody []byte) bool {
|
||||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||
if msg == "" {
|
||||
msg = strings.ToLower(string(respBody))
|
||||
}
|
||||
return strings.Contains(msg, "prompt is too long")
|
||||
}
|
||||
|
||||
func extractAntigravityErrorMessage(body []byte) string {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
parseNestedMessage := func(msg string) string {
|
||||
trimmed := strings.TrimSpace(msg)
|
||||
if trimmed == "" || !strings.HasPrefix(trimmed, "{") {
|
||||
return ""
|
||||
}
|
||||
var nested map[string]any
|
||||
if err := json.Unmarshal([]byte(trimmed), &nested); err != nil {
|
||||
return ""
|
||||
}
|
||||
if errObj, ok := nested["error"].(map[string]any); ok {
|
||||
if innerMsg, ok := errObj["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
|
||||
return innerMsg
|
||||
}
|
||||
}
|
||||
if innerMsg, ok := nested["message"].(string); ok && strings.TrimSpace(innerMsg) != "" {
|
||||
return innerMsg
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Google-style: {"error": {"message": "..."}}
|
||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
||||
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
|
||||
return innerMsg
|
||||
}
|
||||
return msg
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: top-level message
|
||||
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||
if innerMsg := parseNestedMessage(msg); innerMsg != "" {
|
||||
return innerMsg
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
@@ -1248,6 +1453,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
// ForwardUpstream 透传请求到上游 Antigravity 服务
|
||||
// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token
|
||||
func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
sessionID := getSessionID(c)
|
||||
prefix := logPrefix(sessionID, account.Name)
|
||||
|
||||
// 获取上游配置
|
||||
baseURL := strings.TrimSpace(account.GetCredential("base_url"))
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if baseURL == "" || apiKey == "" {
|
||||
return nil, fmt.Errorf("upstream account missing base_url or api_key")
|
||||
}
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
|
||||
// 解析请求获取模型信息
|
||||
var claudeReq antigravity.ClaudeRequest
|
||||
if err := json.Unmarshal(body, &claudeReq); err != nil {
|
||||
return nil, fmt.Errorf("parse claude request: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(claudeReq.Model) == "" {
|
||||
return nil, fmt.Errorf("missing model")
|
||||
}
|
||||
originalModel := claudeReq.Model
|
||||
billingModel := originalModel
|
||||
|
||||
// 构建上游请求 URL
|
||||
upstreamURL := baseURL + "/v1/messages"
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create upstream request: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("x-api-key", apiKey) // Claude API 兼容
|
||||
|
||||
// 透传 Claude 相关 headers
|
||||
if v := c.GetHeader("anthropic-version"); v != "" {
|
||||
req.Header.Set("anthropic-version", v)
|
||||
}
|
||||
if v := c.GetHeader("anthropic-beta"); v != "" {
|
||||
req.Header.Set("anthropic-beta", v)
|
||||
}
|
||||
|
||||
// 代理 URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
log.Printf("%s upstream request failed: %v", prefix, err)
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
// 429 错误时标记账号限流
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude)
|
||||
}
|
||||
|
||||
// 透传上游错误
|
||||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||
c.Status(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 处理成功响应(流式/非流式)
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
|
||||
if claudeReq.Stream {
|
||||
// 流式响应:透传
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
c.Status(http.StatusOK)
|
||||
|
||||
usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime)
|
||||
} else {
|
||||
// 非流式响应:直接透传
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read upstream response: %w", err)
|
||||
}
|
||||
|
||||
// 提取 usage
|
||||
usage = s.extractClaudeUsage(respBody)
|
||||
|
||||
c.Header("Content-Type", resp.Header.Get("Content-Type"))
|
||||
c.Status(http.StatusOK)
|
||||
_, _ = c.Writer.Write(respBody)
|
||||
}
|
||||
|
||||
// 构建计费结果
|
||||
duration := time.Since(startTime)
|
||||
log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds())
|
||||
|
||||
return &ForwardResult{
|
||||
Model: billingModel,
|
||||
Stream: claudeReq.Stream,
|
||||
Duration: duration,
|
||||
FirstTokenMs: firstTokenMs,
|
||||
Usage: ClaudeUsage{
|
||||
InputTokens: usage.InputTokens,
|
||||
OutputTokens: usage.OutputTokens,
|
||||
CacheReadInputTokens: usage.CacheReadInputTokens,
|
||||
CacheCreationInputTokens: usage.CacheCreationInputTokens,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// streamUpstreamResponse 透传上游流式响应并提取 usage
|
||||
func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*ClaudeUsage, *int) {
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
var firstTokenRecorded bool
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
scanner.Buffer(buf, 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
|
||||
// 记录首 token 时间
|
||||
if !firstTokenRecorded && len(line) > 0 {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
firstTokenRecorded = true
|
||||
}
|
||||
|
||||
// 尝试从 message_delta 或 message_stop 事件提取 usage
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
dataStr := bytes.TrimPrefix(line, []byte("data: "))
|
||||
var event map[string]any
|
||||
if json.Unmarshal(dataStr, &event) == nil {
|
||||
if u, ok := event["usage"].(map[string]any); ok {
|
||||
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.InputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.OutputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.CacheReadInputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
|
||||
usage.CacheCreationInputTokens = int(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 透传行
|
||||
_, _ = c.Writer.Write(line)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
return usage, firstTokenMs
|
||||
}
|
||||
|
||||
// extractClaudeUsage 从非流式 Claude 响应提取 usage
|
||||
func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage {
|
||||
usage := &ClaudeUsage{}
|
||||
var resp map[string]any
|
||||
if json.Unmarshal(body, &resp) != nil {
|
||||
return usage
|
||||
}
|
||||
if u, ok := resp["usage"].(map[string]any); ok {
|
||||
if v, ok := u["input_tokens"].(float64); ok {
|
||||
usage.InputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["output_tokens"].(float64); ok {
|
||||
usage.OutputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_read_input_tokens"].(float64); ok {
|
||||
usage.CacheReadInputTokens = int(v)
|
||||
}
|
||||
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
|
||||
usage.CacheCreationInputTokens = int(v)
|
||||
}
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
// ForwardGemini 转发 Gemini 协议请求
|
||||
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -1287,6 +1694,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
}
|
||||
|
||||
mappedModel := s.getMappedModel(account, originalModel)
|
||||
billingModel := originalModel
|
||||
if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" {
|
||||
billingModel = mappedModel
|
||||
}
|
||||
afterSwitch := antigravityHasAccountSwitch(ctx)
|
||||
maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch)
|
||||
|
||||
// 获取 access_token
|
||||
if s.tokenProvider == nil {
|
||||
@@ -1306,8 +1719,15 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
|
||||
filteredBody, err := filterEmptyPartsFromGeminiRequest(body)
|
||||
if err != nil {
|
||||
log.Printf("[Antigravity] Failed to filter empty parts: %v", err)
|
||||
filteredBody = body
|
||||
}
|
||||
|
||||
// Antigravity 上游要求必须包含身份提示词,注入到请求中
|
||||
injectedBody, err := injectIdentityPatchToGeminiRequest(body)
|
||||
injectedBody, err := injectIdentityPatchToGeminiRequest(filteredBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1344,6 +1764,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
httpUpstream: s.httpUpstream,
|
||||
settingService: s.settingService,
|
||||
handleError: s.handleUpstreamError,
|
||||
maxRetries: maxRetries,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||
@@ -1493,7 +1914,7 @@ handleSuccess:
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Model: billingModel,
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
@@ -1544,6 +1965,81 @@ func antigravityUseScopeRateLimit() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func antigravityHasAccountSwitch(ctx context.Context) bool {
|
||||
if ctx == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := ctx.Value(ctxkey.AccountSwitchCount).(int); ok {
|
||||
return v > 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func antigravityMaxRetries() int {
|
||||
raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesEnv))
|
||||
if raw == "" {
|
||||
return antigravityDefaultMaxRetries
|
||||
}
|
||||
value, err := strconv.Atoi(raw)
|
||||
if err != nil || value <= 0 {
|
||||
return antigravityDefaultMaxRetries
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func antigravityMaxRetriesAfterSwitch() int {
|
||||
raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesAfterSwitchEnv))
|
||||
if raw == "" {
|
||||
return antigravityMaxRetries()
|
||||
}
|
||||
value, err := strconv.Atoi(raw)
|
||||
if err != nil || value <= 0 {
|
||||
return antigravityMaxRetries()
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// antigravityMaxRetriesForModel 根据模型类型获取重试次数
|
||||
// 优先使用模型细分配置,未设置则回退到平台级配置
|
||||
func antigravityMaxRetriesForModel(model string, afterSwitch bool) int {
|
||||
var envKey string
|
||||
if strings.HasPrefix(model, "claude-") {
|
||||
envKey = antigravityMaxRetriesClaudeEnv
|
||||
} else if isImageGenerationModel(model) {
|
||||
envKey = antigravityMaxRetriesGeminiImageEnv
|
||||
} else if strings.HasPrefix(model, "gemini-") {
|
||||
envKey = antigravityMaxRetriesGeminiTextEnv
|
||||
}
|
||||
|
||||
if envKey != "" {
|
||||
if raw := strings.TrimSpace(os.Getenv(envKey)); raw != "" {
|
||||
if value, err := strconv.Atoi(raw); err == nil && value > 0 {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
if afterSwitch {
|
||||
return antigravityMaxRetriesAfterSwitch()
|
||||
}
|
||||
return antigravityMaxRetries()
|
||||
}
|
||||
|
||||
func antigravityUseMappedModelForBilling() bool {
|
||||
v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityBillingModelEnv)))
|
||||
return v == "1" || v == "true" || v == "yes" || v == "on"
|
||||
}
|
||||
|
||||
func antigravityFallbackCooldownSeconds() (time.Duration, bool) {
|
||||
raw := strings.TrimSpace(os.Getenv(antigravityFallbackSecondsEnv))
|
||||
if raw == "" {
|
||||
return 0, false
|
||||
}
|
||||
seconds, err := strconv.Atoi(raw)
|
||||
if err != nil || seconds <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return time.Duration(seconds) * time.Second, true
|
||||
}
|
||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
|
||||
if statusCode == 429 {
|
||||
@@ -1556,6 +2052,9 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre
|
||||
fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes
|
||||
}
|
||||
defaultDur := time.Duration(fallbackMinutes) * time.Minute
|
||||
if fallbackDur, ok := antigravityFallbackCooldownSeconds(); ok {
|
||||
defaultDur = fallbackDur
|
||||
}
|
||||
ra := time.Now().Add(defaultDur)
|
||||
if useScopeLimit {
|
||||
log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
|
||||
@@ -2193,6 +2692,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
|
||||
return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error {
|
||||
return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body)
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error {
|
||||
statusStr := "UNKNOWN"
|
||||
switch status {
|
||||
@@ -2618,3 +3121,55 @@ func cleanGeminiRequest(body []byte) ([]byte, error) {
|
||||
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
// filterEmptyPartsFromGeminiRequest 过滤 Gemini 请求中 parts 为空的消息
|
||||
// Gemini API 不接受 parts 为空数组的消息,会返回 400 错误
|
||||
func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contents, ok := payload["contents"].([]any)
|
||||
if !ok || len(contents) == 0 {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
filtered := make([]any, 0, len(contents))
|
||||
modified := false
|
||||
|
||||
for _, c := range contents {
|
||||
contentMap, ok := c.(map[string]any)
|
||||
if !ok {
|
||||
filtered = append(filtered, c)
|
||||
continue
|
||||
}
|
||||
|
||||
parts, hasParts := contentMap["parts"]
|
||||
if !hasParts {
|
||||
filtered = append(filtered, c)
|
||||
continue
|
||||
}
|
||||
|
||||
partsSlice, ok := parts.([]any)
|
||||
if !ok {
|
||||
filtered = append(filtered, c)
|
||||
continue
|
||||
}
|
||||
|
||||
// 跳过 parts 为空数组的消息
|
||||
if len(partsSlice) == 0 {
|
||||
modified = true
|
||||
continue
|
||||
}
|
||||
|
||||
filtered = append(filtered, c)
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
payload["contents"] = filtered
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -81,3 +87,106 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
|
||||
require.Equal(t, "secret plan", blocks[0]["text"])
|
||||
require.Equal(t, "tool_use", blocks[1]["type"])
|
||||
}
|
||||
|
||||
func TestIsPromptTooLongError(t *testing.T) {
|
||||
require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`)))
|
||||
require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`)))
|
||||
require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`)))
|
||||
}
|
||||
|
||||
type httpUpstreamStub struct {
|
||||
resp *http.Response
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||
return s.resp, s.err
|
||||
}
|
||||
|
||||
func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
|
||||
return s.resp, s.err
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
writer := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(writer)
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"model": "claude-opus-4-5",
|
||||
"messages": []map[string]any{
|
||||
{"role": "user", "content": "hi"},
|
||||
},
|
||||
"max_tokens": 1,
|
||||
"stream": false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||
c.Request = req
|
||||
|
||||
respBody := []byte(`{"error":{"message":"Prompt is too long"}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Header: http.Header{"X-Request-Id": []string{"req-1"}},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{
|
||||
tokenProvider: &AntigravityTokenProvider{},
|
||||
httpUpstream: &httpUpstreamStub{resp: resp},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "acc-1",
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.Forward(context.Background(), c, account, body)
|
||||
require.Nil(t, result)
|
||||
|
||||
var promptErr *PromptTooLongError
|
||||
require.ErrorAs(t, err, &promptErr)
|
||||
require.Equal(t, http.StatusBadRequest, promptErr.StatusCode)
|
||||
require.Equal(t, "req-1", promptErr.RequestID)
|
||||
require.NotEmpty(t, promptErr.Body)
|
||||
|
||||
raw, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := raw.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.Len(t, events, 1)
|
||||
require.Equal(t, "prompt_too_long", events[0].Kind)
|
||||
}
|
||||
|
||||
func TestAntigravityMaxRetriesForModel_AfterSwitch(t *testing.T) {
|
||||
t.Setenv(antigravityMaxRetriesEnv, "4")
|
||||
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "7")
|
||||
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
|
||||
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
|
||||
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
|
||||
|
||||
got := antigravityMaxRetriesForModel("claude-sonnet-4-5", false)
|
||||
require.Equal(t, 4, got)
|
||||
|
||||
got = antigravityMaxRetriesForModel("claude-sonnet-4-5", true)
|
||||
require.Equal(t, 7, got)
|
||||
}
|
||||
|
||||
func TestAntigravityMaxRetriesForModel_AfterSwitchFallback(t *testing.T) {
|
||||
t.Setenv(antigravityMaxRetriesEnv, "5")
|
||||
t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "")
|
||||
t.Setenv(antigravityMaxRetriesClaudeEnv, "")
|
||||
t.Setenv(antigravityMaxRetriesGeminiTextEnv, "")
|
||||
t.Setenv(antigravityMaxRetriesGeminiImageEnv, "")
|
||||
|
||||
got := antigravityMaxRetriesForModel("gemini-2.5-flash", true)
|
||||
require.Equal(t, 5, got)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -16,6 +17,21 @@ const (
|
||||
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
|
||||
)
|
||||
|
||||
// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中
|
||||
func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool {
|
||||
if len(supportedScopes) == 0 {
|
||||
// 未配置时默认全部支持
|
||||
return true
|
||||
}
|
||||
supported := slices.Contains(supportedScopes, string(scope))
|
||||
return supported
|
||||
}
|
||||
|
||||
// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本)
|
||||
func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
||||
return resolveAntigravityQuotaScope(requestedModel)
|
||||
}
|
||||
|
||||
// resolveAntigravityQuotaScope 根据模型名称解析配额域
|
||||
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
||||
model := normalizeAntigravityModelName(requestedModel)
|
||||
|
||||
@@ -2,6 +2,14 @@ package service
|
||||
|
||||
import "time"
|
||||
|
||||
// API Key status constants
|
||||
const (
|
||||
StatusAPIKeyActive = "active"
|
||||
StatusAPIKeyDisabled = "disabled"
|
||||
StatusAPIKeyQuotaExhausted = "quota_exhausted"
|
||||
StatusAPIKeyExpired = "expired"
|
||||
)
|
||||
|
||||
type APIKey struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
@@ -15,8 +23,53 @@ type APIKey struct {
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
|
||||
// Quota fields
|
||||
Quota float64 // Quota limit in USD (0 = unlimited)
|
||||
QuotaUsed float64 // Used quota amount
|
||||
ExpiresAt *time.Time // Expiration time (nil = never expires)
|
||||
}
|
||||
|
||||
func (k *APIKey) IsActive() bool {
|
||||
return k.Status == StatusActive
|
||||
}
|
||||
|
||||
// IsExpired checks if the API key has expired
|
||||
func (k *APIKey) IsExpired() bool {
|
||||
if k.ExpiresAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(*k.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsQuotaExhausted checks if the API key quota is exhausted
|
||||
func (k *APIKey) IsQuotaExhausted() bool {
|
||||
if k.Quota <= 0 {
|
||||
return false // unlimited
|
||||
}
|
||||
return k.QuotaUsed >= k.Quota
|
||||
}
|
||||
|
||||
// GetQuotaRemaining returns remaining quota (-1 for unlimited)
|
||||
func (k *APIKey) GetQuotaRemaining() float64 {
|
||||
if k.Quota <= 0 {
|
||||
return -1 // unlimited
|
||||
}
|
||||
remaining := k.Quota - k.QuotaUsed
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
// GetDaysUntilExpiry returns days until expiry (-1 for never expires)
|
||||
func (k *APIKey) GetDaysUntilExpiry() int {
|
||||
if k.ExpiresAt == nil {
|
||||
return -1 // never expires
|
||||
}
|
||||
duration := time.Until(*k.ExpiresAt)
|
||||
if duration < 0 {
|
||||
return 0
|
||||
}
|
||||
return int(duration.Hours() / 24)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
|
||||
type APIKeyAuthSnapshot struct {
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
@@ -10,6 +12,13 @@ type APIKeyAuthSnapshot struct {
|
||||
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
||||
User APIKeyAuthUserSnapshot `json:"user"`
|
||||
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
|
||||
|
||||
// Quota fields for API Key independent quota feature
|
||||
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
|
||||
QuotaUsed float64 `json:"quota_used"` // Used quota amount
|
||||
|
||||
// Expiration field for API Key expiration feature
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires)
|
||||
}
|
||||
|
||||
// APIKeyAuthUserSnapshot 用户快照
|
||||
@@ -23,29 +32,34 @@ type APIKeyAuthUserSnapshot struct {
|
||||
|
||||
// APIKeyAuthGroupSnapshot 分组快照
|
||||
type APIKeyAuthGroupSnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Status string `json:"status"`
|
||||
SubscriptionType string `json:"subscription_type"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Status string `json:"status"`
|
||||
SubscriptionType string `json:"subscription_type"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||
|
||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||
// Only anthropic groups use these fields; others may leave them empty.
|
||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||
|
||||
@@ -213,6 +213,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
Status: apiKey.Status,
|
||||
IPWhitelist: apiKey.IPWhitelist,
|
||||
IPBlacklist: apiKey.IPBlacklist,
|
||||
Quota: apiKey.Quota,
|
||||
QuotaUsed: apiKey.QuotaUsed,
|
||||
ExpiresAt: apiKey.ExpiresAt,
|
||||
User: APIKeyAuthUserSnapshot{
|
||||
ID: apiKey.User.ID,
|
||||
Status: apiKey.User.Status,
|
||||
@@ -223,26 +226,29 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
}
|
||||
if apiKey.Group != nil {
|
||||
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||
ID: apiKey.Group.ID,
|
||||
Name: apiKey.Group.Name,
|
||||
Platform: apiKey.Group.Platform,
|
||||
Status: apiKey.Group.Status,
|
||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
|
||||
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
ModelRouting: apiKey.Group.ModelRouting,
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
ID: apiKey.Group.ID,
|
||||
Name: apiKey.Group.Name,
|
||||
Platform: apiKey.Group.Platform,
|
||||
Status: apiKey.Group.Status,
|
||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
|
||||
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: apiKey.Group.ModelRouting,
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
MCPXMLInject: apiKey.Group.MCPXMLInject,
|
||||
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
|
||||
}
|
||||
}
|
||||
return snapshot
|
||||
@@ -260,6 +266,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
Status: snapshot.Status,
|
||||
IPWhitelist: snapshot.IPWhitelist,
|
||||
IPBlacklist: snapshot.IPBlacklist,
|
||||
Quota: snapshot.Quota,
|
||||
QuotaUsed: snapshot.QuotaUsed,
|
||||
ExpiresAt: snapshot.ExpiresAt,
|
||||
User: &User{
|
||||
ID: snapshot.User.ID,
|
||||
Status: snapshot.User.Status,
|
||||
@@ -270,27 +279,30 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
}
|
||||
if snapshot.Group != nil {
|
||||
apiKey.Group = &Group{
|
||||
ID: snapshot.Group.ID,
|
||||
Name: snapshot.Group.Name,
|
||||
Platform: snapshot.Group.Platform,
|
||||
Status: snapshot.Group.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
|
||||
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
ModelRouting: snapshot.Group.ModelRouting,
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
ID: snapshot.Group.ID,
|
||||
Name: snapshot.Group.Name,
|
||||
Platform: snapshot.Group.Platform,
|
||||
Status: snapshot.Group.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
|
||||
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: snapshot.Group.ModelRouting,
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
MCPXMLInject: snapshot.Group.MCPXMLInject,
|
||||
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
|
||||
}
|
||||
}
|
||||
return apiKey
|
||||
|
||||
@@ -24,6 +24,10 @@ var (
|
||||
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
|
||||
// ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key has expired")
|
||||
ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期")
|
||||
// ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted")
|
||||
ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -51,6 +55,9 @@ type APIKeyRepository interface {
|
||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
|
||||
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
|
||||
|
||||
// Quota methods
|
||||
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error)
|
||||
}
|
||||
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
@@ -85,6 +92,10 @@ type CreateAPIKeyRequest struct {
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||
|
||||
// Quota fields
|
||||
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
|
||||
ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires)
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest 更新API Key请求
|
||||
@@ -94,6 +105,12 @@ type UpdateAPIKeyRequest struct {
|
||||
Status *string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
|
||||
|
||||
// Quota fields
|
||||
Quota *float64 `json:"quota"` // Quota limit in USD (nil = no change, 0 = unlimited)
|
||||
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change)
|
||||
ClearExpiration bool `json:"-"` // Clear expiration (internal use)
|
||||
ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0
|
||||
}
|
||||
|
||||
// APIKeyService API Key服务
|
||||
@@ -289,6 +306,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
Status: StatusActive,
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
Quota: req.Quota,
|
||||
QuotaUsed: 0,
|
||||
}
|
||||
|
||||
// Set expiration time if specified
|
||||
if req.ExpiresInDays != nil && *req.ExpiresInDays > 0 {
|
||||
expiresAt := time.Now().AddDate(0, 0, *req.ExpiresInDays)
|
||||
apiKey.ExpiresAt = &expiresAt
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
||||
@@ -436,6 +461,35 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
}
|
||||
}
|
||||
|
||||
// Update quota fields
|
||||
if req.Quota != nil {
|
||||
apiKey.Quota = *req.Quota
|
||||
// If quota is increased and status was quota_exhausted, reactivate
|
||||
if apiKey.Status == StatusAPIKeyQuotaExhausted && *req.Quota > apiKey.QuotaUsed {
|
||||
apiKey.Status = StatusActive
|
||||
}
|
||||
}
|
||||
if req.ResetQuota != nil && *req.ResetQuota {
|
||||
apiKey.QuotaUsed = 0
|
||||
// If resetting quota and status was quota_exhausted, reactivate
|
||||
if apiKey.Status == StatusAPIKeyQuotaExhausted {
|
||||
apiKey.Status = StatusActive
|
||||
}
|
||||
}
|
||||
if req.ClearExpiration {
|
||||
apiKey.ExpiresAt = nil
|
||||
// If clearing expiry and status was expired, reactivate
|
||||
if apiKey.Status == StatusAPIKeyExpired {
|
||||
apiKey.Status = StatusActive
|
||||
}
|
||||
} else if req.ExpiresAt != nil {
|
||||
apiKey.ExpiresAt = req.ExpiresAt
|
||||
// If extending expiry and status was expired, reactivate
|
||||
if apiKey.Status == StatusAPIKeyExpired && time.Now().Before(*req.ExpiresAt) {
|
||||
apiKey.Status = StatusActive
|
||||
}
|
||||
}
|
||||
|
||||
// 更新 IP 限制(空数组会清空设置)
|
||||
apiKey.IPWhitelist = req.IPWhitelist
|
||||
apiKey.IPBlacklist = req.IPBlacklist
|
||||
@@ -572,3 +626,51 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
|
||||
// Returns nil if valid, error if invalid
|
||||
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {
|
||||
// Check expiration
|
||||
if apiKey.IsExpired() {
|
||||
return ErrAPIKeyExpired
|
||||
}
|
||||
|
||||
// Check quota
|
||||
if apiKey.IsQuotaExhausted() {
|
||||
return ErrAPIKeyQuotaExhausted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateQuotaUsed updates the quota_used field after a request
|
||||
// Also checks if quota is exhausted and updates status accordingly
|
||||
func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||
if cost <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use repository to atomically increment quota_used
|
||||
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("increment quota used: %w", err)
|
||||
}
|
||||
|
||||
// Check if quota is now exhausted and update status if needed
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, apiKeyID)
|
||||
if err != nil {
|
||||
return nil // Don't fail the request, just log
|
||||
}
|
||||
|
||||
// If quota is set and now exhausted, update status
|
||||
if apiKey.Quota > 0 && newQuotaUsed >= apiKey.Quota {
|
||||
apiKey.Status = StatusAPIKeyQuotaExhausted
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil // Don't fail the request
|
||||
}
|
||||
// Invalidate cache so next request sees the new status
|
||||
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -99,6 +99,10 @@ func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]
|
||||
return s.listKeysByGroupID(ctx, groupID)
|
||||
}
|
||||
|
||||
func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||
panic("unexpected IncrementQuotaUsed call")
|
||||
}
|
||||
|
||||
type authCacheStub struct {
|
||||
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||
setAuthKeys []string
|
||||
|
||||
@@ -118,6 +118,10 @@ func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) (
|
||||
panic("unexpected ListKeysByGroupID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||
panic("unexpected IncrementQuotaUsed call")
|
||||
}
|
||||
|
||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||
//
|
||||
|
||||
@@ -185,7 +185,6 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用优惠码(如果提供且功能已启用)
|
||||
if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) {
|
||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||
|
||||
@@ -32,6 +32,7 @@ const (
|
||||
AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号
|
||||
AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游)
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
|
||||
@@ -257,6 +257,9 @@ var (
|
||||
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
||||
|
||||
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
|
||||
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
var allowedHeaders = map[string]bool{
|
||||
"accept": true,
|
||||
@@ -589,12 +592,18 @@ func (s *GatewayService) hashContent(content string) string {
|
||||
}
|
||||
|
||||
// replaceModelInBody 替换请求体中的model字段
|
||||
// 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改
|
||||
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
|
||||
var req map[string]any
|
||||
var req map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body
|
||||
}
|
||||
req["model"] = newModel
|
||||
// 只序列化 model 字段
|
||||
modelBytes, err := json.Marshal(newModel)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
req["model"] = modelBytes
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body
|
||||
@@ -791,12 +800,21 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
if len(body) == 0 {
|
||||
return body, modelID, nil
|
||||
}
|
||||
|
||||
// 使用 json.RawMessage 保留 messages 的原始字节,避免 thinking 块被修改
|
||||
var reqRaw map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &reqRaw); err != nil {
|
||||
return body, modelID, nil
|
||||
}
|
||||
|
||||
// 同时解析为 map[string]any 用于修改非 messages 字段
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body, modelID, nil
|
||||
}
|
||||
|
||||
toolNameMap := make(map[string]string)
|
||||
modified := false
|
||||
|
||||
if system, ok := req["system"]; ok {
|
||||
switch v := system.(type) {
|
||||
@@ -804,6 +822,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
sanitized := sanitizeSystemText(v)
|
||||
if sanitized != v {
|
||||
req["system"] = sanitized
|
||||
modified = true
|
||||
}
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
@@ -821,6 +840,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
sanitized := sanitizeSystemText(text)
|
||||
if sanitized != text {
|
||||
block["text"] = sanitized
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -831,6 +851,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
if normalized != rawModel {
|
||||
req["model"] = normalized
|
||||
modelID = normalized
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -846,16 +867,19 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||
if normalized != "" && normalized != name {
|
||||
toolMap["name"] = normalized
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if desc, ok := toolMap["description"].(string); ok {
|
||||
sanitized := sanitizeToolDescription(desc)
|
||||
if sanitized != desc {
|
||||
toolMap["description"] = sanitized
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if schema, ok := toolMap["input_schema"]; ok {
|
||||
normalizeToolInputSchema(schema, toolNameMap)
|
||||
modified = true
|
||||
}
|
||||
tools[idx] = toolMap
|
||||
}
|
||||
@@ -884,11 +908,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
normalizedTools[normalized] = value
|
||||
}
|
||||
req["tools"] = normalizedTools
|
||||
modified = true
|
||||
}
|
||||
} else {
|
||||
req["tools"] = []any{}
|
||||
modified = true
|
||||
}
|
||||
|
||||
// 处理 messages 中的 tool_use 块,但保留包含 thinking 块的消息的原始字节
|
||||
messagesModified := false
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
@@ -899,6 +927,24 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// 检查此消息是否包含 thinking 块
|
||||
hasThinking := false
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
if blockType == "thinking" || blockType == "redacted_thinking" {
|
||||
hasThinking = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// 如果包含 thinking 块,跳过此消息的修改
|
||||
if hasThinking {
|
||||
continue
|
||||
}
|
||||
// 只修改不包含 thinking 块的消息中的 tool_use
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
@@ -911,6 +957,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
normalized := normalizeToolNameForClaude(name, toolNameMap)
|
||||
if normalized != "" && normalized != name {
|
||||
blockMap["name"] = normalized
|
||||
messagesModified = true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -920,6 +967,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
if opts.stripSystemCacheControl {
|
||||
if system, ok := req["system"]; ok {
|
||||
_ = stripCacheControlFromSystemBlocks(system)
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -931,12 +979,46 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
|
||||
}
|
||||
if existing, ok := metadata["user_id"].(string); !ok || existing == "" {
|
||||
metadata["user_id"] = opts.metadataUserID
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
delete(req, "temperature")
|
||||
delete(req, "tool_choice")
|
||||
if _, hasTemp := req["temperature"]; hasTemp {
|
||||
delete(req, "temperature")
|
||||
modified = true
|
||||
}
|
||||
if _, hasChoice := req["tool_choice"]; hasChoice {
|
||||
delete(req, "tool_choice")
|
||||
modified = true
|
||||
}
|
||||
|
||||
if !modified && !messagesModified {
|
||||
return body, modelID, toolNameMap
|
||||
}
|
||||
|
||||
// 如果 messages 没有被修改,保留原始 messages 字节
|
||||
if !messagesModified {
|
||||
// 序列化非 messages 字段
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body, modelID, toolNameMap
|
||||
}
|
||||
// 替换回原始的 messages
|
||||
var newReq map[string]json.RawMessage
|
||||
if err := json.Unmarshal(newBody, &newReq); err != nil {
|
||||
return newBody, modelID, toolNameMap
|
||||
}
|
||||
if origMessages, ok := reqRaw["messages"]; ok {
|
||||
newReq["messages"] = origMessages
|
||||
}
|
||||
finalBody, err := json.Marshal(newReq)
|
||||
if err != nil {
|
||||
return newBody, modelID, toolNameMap
|
||||
}
|
||||
return finalBody, modelID, toolNameMap
|
||||
}
|
||||
|
||||
// messages 被修改了,需要完整序列化
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body, modelID, toolNameMap
|
||||
@@ -1139,6 +1221,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
|
||||
}
|
||||
|
||||
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
|
||||
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
|
||||
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1636,6 +1725,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
|
||||
return s.resolveGroupByID(ctx, groupID)
|
||||
}
|
||||
|
||||
func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
|
||||
if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
|
||||
return nil
|
||||
@@ -1701,7 +1794,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID
|
||||
}
|
||||
|
||||
// 强制平台模式不检查 Claude Code 限制
|
||||
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
|
||||
if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" {
|
||||
return nil, groupID, nil
|
||||
}
|
||||
|
||||
@@ -2030,6 +2123,13 @@ func shuffleWithinPriority(accounts []*Account) {
|
||||
|
||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
|
||||
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
|
||||
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
preferOAuth := platform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
||||
|
||||
@@ -2465,6 +2565,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
// Antigravity 平台使用专门的模型支持检查
|
||||
return IsAntigravityModelSupported(requestedModel)
|
||||
}
|
||||
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
|
||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
}
|
||||
// Gemini API Key 账户直接透传,由上游判断模型是否支持
|
||||
if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey {
|
||||
return true
|
||||
@@ -2914,16 +3018,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// 强制执行 cache_control 块数量限制(最多 4 个)
|
||||
body = enforceCacheControlLimit(body)
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
// 应用模型映射:
|
||||
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
|
||||
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
|
||||
mappedModel := reqModel
|
||||
mappingSource := ""
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
mappedModel = account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
// 替换请求体中的模型名
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
|
||||
mappingSource = "account"
|
||||
}
|
||||
}
|
||||
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
normalized := claude.NormalizeModelID(reqModel)
|
||||
if normalized != reqModel {
|
||||
mappedModel = normalized
|
||||
mappingSource = "prefix"
|
||||
}
|
||||
}
|
||||
if mappedModel != reqModel {
|
||||
// 替换请求体中的模型名
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
|
||||
}
|
||||
|
||||
// 获取凭证
|
||||
token, tokenType, err := s.GetAccessToken(ctx, account)
|
||||
@@ -3625,6 +3743,13 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检测 thinking block 被修改的错误
|
||||
// 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified"
|
||||
if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
|
||||
log.Printf("[SignatureCheck] Detected thinking block modification error")
|
||||
return true
|
||||
}
|
||||
|
||||
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
|
||||
// 例如: "all messages must have non-empty content"
|
||||
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") {
|
||||
@@ -4493,13 +4618,19 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap
|
||||
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
Result *ForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额
|
||||
}
|
||||
|
||||
// APIKeyQuotaUpdater defines the interface for updating API Key quota
|
||||
type APIKeyQuotaUpdater interface {
|
||||
UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
@@ -4661,6 +4792,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
}
|
||||
}
|
||||
|
||||
// 更新 API Key 配额(如果设置了配额限制)
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Update API key quota failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Schedule batch update for account last_used_at
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
|
||||
@@ -4678,6 +4816,7 @@ type RecordUsageLongContextInput struct {
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
LongContextThreshold int // 长上下文阈值(如 200000)
|
||||
LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0)
|
||||
APIKeyService *APIKeyService // API Key 配额服务(可选)
|
||||
}
|
||||
|
||||
// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini)
|
||||
@@ -4814,6 +4953,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
// API Key 独立配额扣费
|
||||
if input.APIKeyService != nil && apiKey.Quota > 0 {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Add API key quota used failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4848,16 +4993,30 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对 apikey 类型账号)
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if reqModel != "" {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
// 应用模型映射:
|
||||
// - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名
|
||||
// - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID)
|
||||
if reqModel != "" {
|
||||
mappedModel := reqModel
|
||||
mappingSource := ""
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
|
||||
mappingSource = "account"
|
||||
}
|
||||
}
|
||||
if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
normalized := claude.NormalizeModelID(reqModel)
|
||||
if normalized != reqModel {
|
||||
mappedModel = normalized
|
||||
mappingSource = "prefix"
|
||||
}
|
||||
}
|
||||
if mappedModel != reqModel {
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
reqModel = mappedModel
|
||||
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取凭证
|
||||
@@ -5109,6 +5268,27 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
|
||||
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
|
||||
scope, ok := ResolveAntigravityQuotaScope(requestedModel)
|
||||
if !ok {
|
||||
return nil // 无法解析 scope,跳过检查
|
||||
}
|
||||
|
||||
group, err := s.resolveGroupByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil // 查询失败时放行
|
||||
}
|
||||
if group == nil {
|
||||
return nil // 分组不存在时放行
|
||||
}
|
||||
|
||||
if !IsScopeSupported(group.SupportedModelScopes, scope) {
|
||||
return ErrModelScopeNotSupported
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAvailableModels returns the list of models available for a group
|
||||
// It aggregates model_mapping keys from all schedulable accounts in the group
|
||||
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
|
||||
|
||||
@@ -977,6 +977,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
||||
}
|
||||
|
||||
// 过滤掉 parts 为空的消息(Gemini API 不接受空 parts)
|
||||
if filteredBody, err := filterEmptyPartsFromGeminiRequest(body); err == nil {
|
||||
body = filteredBody
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "generateContent", "streamGenerateContent", "countTokens":
|
||||
// ok
|
||||
|
||||
@@ -2,20 +2,22 @@ package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
|
||||
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中替换 thoughtSignature 字段为 dummy 签名,
|
||||
// 以避免跨账号签名验证错误。
|
||||
//
|
||||
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
|
||||
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
|
||||
// 会导致新账号的签名验证失败。通过替换为 dummy 签名,跳过签名验证。
|
||||
//
|
||||
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
|
||||
// to avoid cross-account signature validation errors.
|
||||
// CleanGeminiNativeThoughtSignatures replaces thoughtSignature fields with dummy signature
|
||||
// in Gemini native API requests to avoid cross-account signature validation errors.
|
||||
//
|
||||
// When sticky session switches accounts (e.g., original account becomes unavailable),
|
||||
// thoughtSignatures from the old account will cause validation failures on the new account.
|
||||
// By removing these signatures, we allow the new account to generate valid signatures.
|
||||
// By replacing with dummy signature, we skip signature validation.
|
||||
func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
@@ -28,11 +30,11 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
// 递归清理 thoughtSignature
|
||||
cleaned := cleanThoughtSignaturesRecursive(data)
|
||||
// 递归替换 thoughtSignature 为 dummy 签名
|
||||
replaced := replaceThoughtSignaturesRecursive(data)
|
||||
|
||||
// 重新序列化
|
||||
result, err := json.Marshal(cleaned)
|
||||
result, err := json.Marshal(replaced)
|
||||
if err != nil {
|
||||
// 如果序列化失败,返回原始 body
|
||||
return body
|
||||
@@ -41,19 +43,20 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
|
||||
return result
|
||||
}
|
||||
|
||||
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
|
||||
func cleanThoughtSignaturesRecursive(data any) any {
|
||||
// replaceThoughtSignaturesRecursive 递归遍历数据结构,将所有 thoughtSignature 字段替换为 dummy 签名
|
||||
func replaceThoughtSignaturesRecursive(data any) any {
|
||||
switch v := data.(type) {
|
||||
case map[string]any:
|
||||
// 创建新的 map,移除 thoughtSignature
|
||||
// 创建新的 map,替换 thoughtSignature 为 dummy 签名
|
||||
result := make(map[string]any, len(v))
|
||||
for key, value := range v {
|
||||
// 跳过 thoughtSignature 字段
|
||||
// 替换 thoughtSignature 字段为 dummy 签名
|
||||
if key == "thoughtSignature" {
|
||||
result[key] = antigravity.DummyThoughtSignature
|
||||
continue
|
||||
}
|
||||
// 递归处理嵌套结构
|
||||
result[key] = cleanThoughtSignaturesRecursive(value)
|
||||
result[key] = replaceThoughtSignaturesRecursive(value)
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -61,7 +64,7 @@ func cleanThoughtSignaturesRecursive(data any) any {
|
||||
// 递归处理数组中的每个元素
|
||||
result := make([]any, len(v))
|
||||
for i, item := range v {
|
||||
result[i] = cleanThoughtSignaturesRecursive(item)
|
||||
result[i] = replaceThoughtSignaturesRecursive(item)
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
@@ -35,6 +35,8 @@ type Group struct {
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool
|
||||
FallbackGroupID *int64
|
||||
// 无效请求兜底分组(仅 anthropic 平台使用)
|
||||
FallbackGroupIDOnInvalidRequest *int64
|
||||
|
||||
// 模型路由配置
|
||||
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
|
||||
@@ -42,6 +44,13 @@ type Group struct {
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool
|
||||
|
||||
// MCP XML 协议注入开关(仅 antigravity 平台使用)
|
||||
MCPXMLInject bool
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
// 可选值: claude, gemini_text, gemini_image
|
||||
SupportedModelScopes []string
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
|
||||
@@ -169,22 +169,31 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
|
||||
// RewriteUserID 重写body中的metadata.user_id
|
||||
// 输入格式:user_{clientId}_account__session_{sessionUUID}
|
||||
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
|
||||
//
|
||||
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
|
||||
// 避免重新序列化导致 thinking 块等内容被修改。
|
||||
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
|
||||
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// 解析JSON
|
||||
var reqMap map[string]any
|
||||
// 使用 RawMessage 保留其他字段的原始字节
|
||||
var reqMap map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &reqMap); err != nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
metadata, ok := reqMap["metadata"].(map[string]any)
|
||||
// 解析 metadata 字段
|
||||
metadataRaw, ok := reqMap["metadata"]
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
var metadata map[string]any
|
||||
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
userID, ok := metadata["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
return body, nil
|
||||
@@ -207,7 +216,13 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
||||
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
|
||||
|
||||
metadata["user_id"] = newUserID
|
||||
reqMap["metadata"] = metadata
|
||||
|
||||
// 只重新序列化 metadata 字段
|
||||
newMetadataRaw, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return body, nil
|
||||
}
|
||||
reqMap["metadata"] = newMetadataRaw
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
}
|
||||
@@ -215,6 +230,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
||||
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
|
||||
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
|
||||
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
|
||||
//
|
||||
// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节,
|
||||
// 避免重新序列化导致 thinking 块等内容被修改。
|
||||
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
|
||||
// 先执行常规的 RewriteUserID 逻辑
|
||||
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
|
||||
@@ -227,17 +245,23 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
// 解析重写后的 body,提取 user_id
|
||||
var reqMap map[string]any
|
||||
// 使用 RawMessage 保留其他字段的原始字节
|
||||
var reqMap map[string]json.RawMessage
|
||||
if err := json.Unmarshal(newBody, &reqMap); err != nil {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
metadata, ok := reqMap["metadata"].(map[string]any)
|
||||
// 解析 metadata 字段
|
||||
metadataRaw, ok := reqMap["metadata"]
|
||||
if !ok {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
var metadata map[string]any
|
||||
if err := json.Unmarshal(metadataRaw, &metadata); err != nil {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
userID, ok := metadata["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
return newBody, nil
|
||||
@@ -278,7 +302,13 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b
|
||||
)
|
||||
|
||||
metadata["user_id"] = newUserID
|
||||
reqMap["metadata"] = metadata
|
||||
|
||||
// 只重新序列化 metadata 字段
|
||||
newMetadataRaw, marshalErr := json.Marshal(metadata)
|
||||
if marshalErr != nil {
|
||||
return newBody, nil
|
||||
}
|
||||
reqMap["metadata"] = newMetadataRaw
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ type opencodeCacheMetadata struct {
|
||||
LastChecked int64 `json:"lastChecked"`
|
||||
}
|
||||
|
||||
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
||||
func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult {
|
||||
result := codexTransformResult{}
|
||||
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||
needsToolContinuation := NeedsToolContinuation(reqBody)
|
||||
@@ -118,22 +118,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
||||
result.PromptCacheKey = strings.TrimSpace(v)
|
||||
}
|
||||
|
||||
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
existingInstructions = strings.TrimSpace(existingInstructions)
|
||||
|
||||
if instructions != "" {
|
||||
if existingInstructions != instructions {
|
||||
reqBody["instructions"] = instructions
|
||||
result.Modified = true
|
||||
}
|
||||
} else if existingInstructions == "" {
|
||||
// 未获取到 opencode 指令时,回退使用 Codex CLI 指令。
|
||||
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||
if codexInstructions != "" {
|
||||
reqBody["instructions"] = codexInstructions
|
||||
result.Modified = true
|
||||
}
|
||||
// instructions 处理逻辑:根据是否是 Codex CLI 分别调用不同方法
|
||||
if applyInstructions(reqBody, isCodexCLI) {
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
|
||||
@@ -276,6 +263,72 @@ func GetCodexCLIInstructions() string {
|
||||
return getCodexCLIInstructions()
|
||||
}
|
||||
|
||||
// applyInstructions 处理 instructions 字段
|
||||
// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令)
|
||||
// isCodexCLI=false: 优先使用 opencode 指令覆盖
|
||||
func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool {
|
||||
if isCodexCLI {
|
||||
return applyCodexCLIInstructions(reqBody)
|
||||
}
|
||||
return applyOpenCodeInstructions(reqBody)
|
||||
}
|
||||
|
||||
// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions
|
||||
// 仅在 instructions 为空时添加 opencode 指令
|
||||
func applyCodexCLIInstructions(reqBody map[string]any) bool {
|
||||
if !isInstructionsEmpty(reqBody) {
|
||||
return false // 已有有效 instructions,不修改
|
||||
}
|
||||
|
||||
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||
if instructions != "" {
|
||||
reqBody["instructions"] = instructions
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令
|
||||
// 优先使用 opencode 指令覆盖
|
||||
func applyOpenCodeInstructions(reqBody map[string]any) bool {
|
||||
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
existingInstructions = strings.TrimSpace(existingInstructions)
|
||||
|
||||
if instructions != "" {
|
||||
if existingInstructions != instructions {
|
||||
reqBody["instructions"] = instructions
|
||||
return true
|
||||
}
|
||||
} else if existingInstructions == "" {
|
||||
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||
if codexInstructions != "" {
|
||||
reqBody["instructions"] = codexInstructions
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isInstructionsEmpty 检查 instructions 字段是否为空
|
||||
// 处理以下情况:字段不存在、nil、空字符串、纯空白字符串
|
||||
func isInstructionsEmpty(reqBody map[string]any) bool {
|
||||
val, exists := reqBody["instructions"]
|
||||
if !exists {
|
||||
return true
|
||||
}
|
||||
if val == nil {
|
||||
return true
|
||||
}
|
||||
str, ok := val.(string)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
return strings.TrimSpace(str) == ""
|
||||
}
|
||||
|
||||
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
|
||||
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
|
||||
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||
|
||||
@@ -23,7 +23,7 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
applyCodexOAuthTransform(reqBody, false)
|
||||
|
||||
// 未显式设置 store=true,默认为 false。
|
||||
store, ok := reqBody["store"].(bool)
|
||||
@@ -59,7 +59,7 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
applyCodexOAuthTransform(reqBody, false)
|
||||
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
@@ -79,7 +79,7 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
applyCodexOAuthTransform(reqBody, false)
|
||||
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
@@ -97,7 +97,7 @@ func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
applyCodexOAuthTransform(reqBody, false)
|
||||
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
@@ -148,7 +148,7 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
applyCodexOAuthTransform(reqBody, false)
|
||||
|
||||
tools, ok := reqBody["tools"].([]any)
|
||||
require.True(t, ok)
|
||||
@@ -169,7 +169,7 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
"input": []any{},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
applyCodexOAuthTransform(reqBody, false)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
@@ -196,3 +196,77 @@ func setupCodexCache(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) {
|
||||
// Codex CLI 场景:已有 instructions 时不修改
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"instructions": "existing instructions",
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true
|
||||
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "existing instructions", instructions)
|
||||
// Modified 仍可能为 true(因为其他字段被修改),但 instructions 应保持不变
|
||||
_ = result
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) {
|
||||
// Codex CLI 场景:无 instructions 时补充默认值
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
// 没有 instructions 字段
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true
|
||||
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, instructions)
|
||||
require.True(t, result.Modified)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) {
|
||||
// 非 Codex CLI 场景:使用 opencode 指令覆盖
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"instructions": "old instructions",
|
||||
}
|
||||
|
||||
result := applyCodexOAuthTransform(reqBody, false) // isCodexCLI=false
|
||||
|
||||
instructions, ok := reqBody["instructions"].(string)
|
||||
require.True(t, ok)
|
||||
require.NotEqual(t, "old instructions", instructions)
|
||||
require.True(t, result.Modified)
|
||||
}
|
||||
|
||||
func TestIsInstructionsEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reqBody map[string]any
|
||||
expected bool
|
||||
}{
|
||||
{"missing field", map[string]any{}, true},
|
||||
{"nil value", map[string]any{"instructions": nil}, true},
|
||||
{"empty string", map[string]any{"instructions": ""}, true},
|
||||
{"whitespace only", map[string]any{"instructions": " "}, true},
|
||||
{"non-string", map[string]any{"instructions": 123}, true},
|
||||
{"valid string", map[string]any{"instructions": "hello"}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isInstructionsEmpty(tt.reqBody)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -796,8 +796,8 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
}
|
||||
|
||||
if account.Type == AccountTypeOAuth && !isCodexCLI {
|
||||
codexResult := applyCodexOAuthTransform(reqBody)
|
||||
if account.Type == AccountTypeOAuth {
|
||||
codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI)
|
||||
if codexResult.Modified {
|
||||
bodyModified = true
|
||||
}
|
||||
@@ -1681,13 +1681,14 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
Result *OpenAIForwardResult
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
UserAgent string // 请求的 User-Agent
|
||||
IPAddress string // 请求的客户端 IP 地址
|
||||
APIKeyService APIKeyQuotaUpdater
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
@@ -1799,6 +1800,13 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
}
|
||||
}
|
||||
|
||||
// Update API key quota if applicable (only for balance mode with quota set)
|
||||
if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil {
|
||||
if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Update API key quota failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Schedule batch update for account last_used_at
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
|
||||
|
||||
@@ -285,6 +285,11 @@ func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error {
|
||||
return fmt.Errorf("query error counts: %w", err)
|
||||
}
|
||||
|
||||
accountSwitchCount, err := c.queryAccountSwitchCount(ctx, windowStart, windowEnd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query account switch counts: %w", err)
|
||||
}
|
||||
|
||||
windowSeconds := windowEnd.Sub(windowStart).Seconds()
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = 60
|
||||
@@ -309,9 +314,10 @@ func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error {
|
||||
Upstream429Count: upstream429,
|
||||
Upstream529Count: upstream529,
|
||||
|
||||
TokenConsumed: tokenConsumed,
|
||||
QPS: float64Ptr(roundTo1DP(qps)),
|
||||
TPS: float64Ptr(roundTo1DP(tps)),
|
||||
TokenConsumed: tokenConsumed,
|
||||
AccountSwitchCount: accountSwitchCount,
|
||||
QPS: float64Ptr(roundTo1DP(qps)),
|
||||
TPS: float64Ptr(roundTo1DP(tps)),
|
||||
|
||||
DurationP50Ms: duration.p50,
|
||||
DurationP90Ms: duration.p90,
|
||||
@@ -551,6 +557,27 @@ WHERE created_at >= $1 AND created_at < $2`
|
||||
return errorTotal, businessLimited, errorSLA, upstreamExcl429529, upstream429, upstream529, nil
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) queryAccountSwitchCount(ctx context.Context, start, end time.Time) (int64, error) {
|
||||
q := `
|
||||
SELECT
|
||||
COALESCE(SUM(CASE
|
||||
WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1
|
||||
ELSE 0
|
||||
END), 0) AS switch_count
|
||||
FROM ops_error_logs o
|
||||
CROSS JOIN LATERAL jsonb_array_elements(
|
||||
COALESCE(NULLIF(o.upstream_errors, 'null'::jsonb), '[]'::jsonb)
|
||||
) AS ev
|
||||
WHERE o.created_at >= $1 AND o.created_at < $2
|
||||
AND o.is_count_tokens = FALSE`
|
||||
|
||||
var count int64
|
||||
if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
type opsCollectedSystemStats struct {
|
||||
cpuUsagePercent *float64
|
||||
memoryUsedMB *int64
|
||||
|
||||
@@ -161,7 +161,8 @@ type OpsInsertSystemMetricsInput struct {
|
||||
Upstream429Count int64
|
||||
Upstream529Count int64
|
||||
|
||||
TokenConsumed int64
|
||||
TokenConsumed int64
|
||||
AccountSwitchCount int64
|
||||
|
||||
QPS *float64
|
||||
TPS *float64
|
||||
@@ -223,8 +224,9 @@ type OpsSystemMetricsSnapshot struct {
|
||||
DBConnIdle *int `json:"db_conn_idle"`
|
||||
DBConnWaiting *int `json:"db_conn_waiting"`
|
||||
|
||||
GoroutineCount *int `json:"goroutine_count"`
|
||||
ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"`
|
||||
GoroutineCount *int `json:"goroutine_count"`
|
||||
ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"`
|
||||
AccountSwitchCount *int64 `json:"account_switch_count"`
|
||||
}
|
||||
|
||||
type OpsUpsertJobHeartbeatInput struct {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lib/pq"
|
||||
@@ -476,9 +477,13 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq
|
||||
continue
|
||||
}
|
||||
|
||||
attemptCtx := ctx
|
||||
if switches > 0 {
|
||||
attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches)
|
||||
}
|
||||
exec := func() *opsRetryExecution {
|
||||
defer selection.ReleaseFunc()
|
||||
return s.executeWithAccount(ctx, reqType, errorLog, body, account)
|
||||
return s.executeWithAccount(attemptCtx, reqType, errorLog, body, account)
|
||||
}()
|
||||
|
||||
if exec != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ type OpsThroughputTrendPoint struct {
|
||||
BucketStart time.Time `json:"bucket_start"`
|
||||
RequestCount int64 `json:"request_count"`
|
||||
TokenConsumed int64 `json:"token_consumed"`
|
||||
SwitchCount int64 `json:"switch_count"`
|
||||
QPS float64 `json:"qps"`
|
||||
TPS float64 `json:"tps"`
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ type UserRepository interface {
|
||||
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
||||
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
|
||||
|
||||
// TOTP 相关方法
|
||||
// TOTP 双因素认证
|
||||
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
|
||||
EnableTotp(ctx context.Context, userID int64) error
|
||||
DisableTotp(ctx context.Context, userID int64) error
|
||||
|
||||
Reference in New Issue
Block a user