diff --git a/backend/ent/group.go b/backend/ent/group.go index 0d0c0538..f91a4079 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -56,6 +56,8 @@ type Group struct { ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + // 无效请求兜底使用的分组 ID + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` // 模型路由配置:模型模式 -> 优先账号ID列表 ModelRouting map[string][]int64 `json:"model_routing,omitempty"` // 是否启用模型路由配置 @@ -172,7 +174,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: values[i] = new(sql.NullString) @@ -322,6 +324,13 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.FallbackGroupID = new(int64) *_m.FallbackGroupID = value.Int64 } + case group.FieldFallbackGroupIDOnInvalidRequest: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field fallback_group_id_on_invalid_request", values[i]) + } else if value.Valid { + _m.FallbackGroupIDOnInvalidRequest = new(int64) + *_m.FallbackGroupIDOnInvalidRequest = value.Int64 + } case group.FieldModelRouting: if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field model_routing", values[i]) @@ -487,6 +496,11 @@ func (_m *Group) String() string { builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") + if v := _m.FallbackGroupIDOnInvalidRequest; v != nil { + builder.WriteString("fallback_group_id_on_invalid_request=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") builder.WriteString("model_routing=") builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting)) builder.WriteString(", ") diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index d66d3edc..b63827d3 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -53,6 +53,8 @@ const ( FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. FieldFallbackGroupID = "fallback_group_id" + // FieldFallbackGroupIDOnInvalidRequest holds the string denoting the fallback_group_id_on_invalid_request field in the database. + FieldFallbackGroupIDOnInvalidRequest = "fallback_group_id_on_invalid_request" // FieldModelRouting holds the string denoting the model_routing field in the database. FieldModelRouting = "model_routing" // FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database. @@ -151,6 +153,7 @@ var Columns = []string{ FieldImagePrice4k, FieldClaudeCodeOnly, FieldFallbackGroupID, + FieldFallbackGroupIDOnInvalidRequest, FieldModelRouting, FieldModelRoutingEnabled, } @@ -317,6 +320,11 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc() } +// ByFallbackGroupIDOnInvalidRequest orders the results by the fallback_group_id_on_invalid_request field. +func ByFallbackGroupIDOnInvalidRequest(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFallbackGroupIDOnInvalidRequest, opts...).ToFunc() +} + // ByModelRoutingEnabled orders the results by the model_routing_enabled field. func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 6ce9e4c6..02cbb3d5 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -150,6 +150,11 @@ func FallbackGroupID(v int64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v)) } +// FallbackGroupIDOnInvalidRequest applies equality check predicate on the "fallback_group_id_on_invalid_request" field. It's identical to FallbackGroupIDOnInvalidRequestEQ. +func FallbackGroupIDOnInvalidRequest(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v)) +} + // ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ. func ModelRoutingEnabled(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v)) @@ -1070,6 +1075,56 @@ func FallbackGroupIDNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID)) } +// FallbackGroupIDOnInvalidRequestEQ applies the EQ predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestNEQ applies the NEQ predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestNEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestIn applies the In predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldFallbackGroupIDOnInvalidRequest, vs...)) +} + +// FallbackGroupIDOnInvalidRequestNotIn applies the NotIn predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestNotIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldFallbackGroupIDOnInvalidRequest, vs...)) +} + +// FallbackGroupIDOnInvalidRequestGT applies the GT predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestGT(v int64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestGTE applies the GTE predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestGTE(v int64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestLT applies the LT predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestLT(v int64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestLTE applies the LTE predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestLTE(v int64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestIsNil applies the IsNil predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldFallbackGroupIDOnInvalidRequest)) +} + +// FallbackGroupIDOnInvalidRequestNotNil applies the NotNil predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldFallbackGroupIDOnInvalidRequest)) +} + // ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field. func ModelRoutingIsNil() predicate.Group { return predicate.Group(sql.FieldIsNull(FieldModelRouting)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 0f251e0b..b08894da 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -286,6 +286,20 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate { return _c } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (_c *GroupCreate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupCreate { + _c.mutation.SetFallbackGroupIDOnInvalidRequest(v) + return _c +} + +// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil. +func (_c *GroupCreate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupCreate { + if v != nil { + _c.SetFallbackGroupIDOnInvalidRequest(*v) + } + return _c +} + // SetModelRouting sets the "model_routing" field. func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate { _c.mutation.SetModelRouting(v) @@ -640,6 +654,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value) _node.FallbackGroupID = &value } + if value, ok := _c.mutation.FallbackGroupIDOnInvalidRequest(); ok { + _spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + _node.FallbackGroupIDOnInvalidRequest = &value + } if value, ok := _c.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) _node.ModelRouting = value @@ -1128,6 +1146,30 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert { return u } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsert) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert { + u.Set(group.FieldFallbackGroupIDOnInvalidRequest, v) + return u +} + +// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create. +func (u *GroupUpsert) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsert { + u.SetExcluded(group.FieldFallbackGroupIDOnInvalidRequest) + return u +} + +// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsert) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert { + u.Add(group.FieldFallbackGroupIDOnInvalidRequest, v) + return u +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsert) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsert { + u.SetNull(group.FieldFallbackGroupIDOnInvalidRequest) + return u +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert { u.Set(group.FieldModelRouting, v) @@ -1581,6 +1623,34 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne { }) } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetFallbackGroupIDOnInvalidRequest(v) + }) +} + +// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddFallbackGroupIDOnInvalidRequest(v) + }) +} + +// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateFallbackGroupIDOnInvalidRequest() + }) +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearFallbackGroupIDOnInvalidRequest() + }) +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -2205,6 +2275,34 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk { }) } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertBulk) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetFallbackGroupIDOnInvalidRequest(v) + }) +} + +// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertBulk) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddFallbackGroupIDOnInvalidRequest(v) + }) +} + +// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateFallbackGroupIDOnInvalidRequest() + }) +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertBulk) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearFallbackGroupIDOnInvalidRequest() + }) +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index c3cc2708..ce8f3748 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -395,6 +395,33 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate { return _u } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate { + _u.mutation.ResetFallbackGroupIDOnInvalidRequest() + _u.mutation.SetFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdate { + if v != nil { + _u.SetFallbackGroupIDOnInvalidRequest(*v) + } + return _u +} + +// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdate) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate { + _u.mutation.AddFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdate) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdate { + _u.mutation.ClearFallbackGroupIDOnInvalidRequest() + return _u +} + // SetModelRouting sets the "model_routing" field. func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate { _u.mutation.SetModelRouting(v) @@ -829,6 +856,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.FallbackGroupIDCleared() { _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) } + if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok { + _spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok { + _spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() { + _spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64) + } if value, ok := _u.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) } @@ -1513,6 +1549,33 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne { return _u } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdateOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne { + _u.mutation.ResetFallbackGroupIDOnInvalidRequest() + _u.mutation.SetFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdateOne { + if v != nil { + _u.SetFallbackGroupIDOnInvalidRequest(*v) + } + return _u +} + +// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdateOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne { + _u.mutation.AddFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdateOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdateOne { + _u.mutation.ClearFallbackGroupIDOnInvalidRequest() + return _u +} + // SetModelRouting sets the "model_routing" field. func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne { _u.mutation.SetModelRouting(v) @@ -1977,6 +2040,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.FallbackGroupIDCleared() { _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) } + if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok { + _spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok { + _spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() { + _spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64) + } if value, ok := _u.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index b377804f..5624c05b 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -226,6 +226,7 @@ var ( {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, + {Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true}, {Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "model_routing_enabled", Type: field.TypeBool, Default: false}, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index cd2fe8e0..69801b9f 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -3833,61 +3833,63 @@ func (m *AccountGroupMutation) ResetEdge(name string) error { // GroupMutation represents an operation that mutates the Group nodes in the graph. type GroupMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - description *string - rate_multiplier *float64 - addrate_multiplier *float64 - is_exclusive *bool - status *string - platform *string - subscription_type *string - daily_limit_usd *float64 - adddaily_limit_usd *float64 - weekly_limit_usd *float64 - addweekly_limit_usd *float64 - monthly_limit_usd *float64 - addmonthly_limit_usd *float64 - default_validity_days *int - adddefault_validity_days *int - image_price_1k *float64 - addimage_price_1k *float64 - image_price_2k *float64 - addimage_price_2k *float64 - image_price_4k *float64 - addimage_price_4k *float64 - claude_code_only *bool - fallback_group_id *int64 - addfallback_group_id *int64 - model_routing *map[string][]int64 - model_routing_enabled *bool - clearedFields map[string]struct{} - api_keys map[int64]struct{} - removedapi_keys map[int64]struct{} - clearedapi_keys bool - redeem_codes map[int64]struct{} - removedredeem_codes map[int64]struct{} - clearedredeem_codes bool - subscriptions map[int64]struct{} - removedsubscriptions map[int64]struct{} - clearedsubscriptions bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - accounts map[int64]struct{} - removedaccounts map[int64]struct{} - clearedaccounts bool - allowed_users map[int64]struct{} - removedallowed_users map[int64]struct{} - clearedallowed_users bool - done bool - oldValue func(context.Context) (*Group, error) - predicates []predicate.Group + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + description *string + rate_multiplier *float64 + addrate_multiplier *float64 + is_exclusive *bool + status *string + platform *string + subscription_type *string + daily_limit_usd *float64 + adddaily_limit_usd *float64 + weekly_limit_usd *float64 + addweekly_limit_usd *float64 + monthly_limit_usd *float64 + addmonthly_limit_usd *float64 + default_validity_days *int + adddefault_validity_days *int + image_price_1k *float64 + addimage_price_1k *float64 + image_price_2k *float64 + addimage_price_2k *float64 + image_price_4k *float64 + addimage_price_4k *float64 + claude_code_only *bool + fallback_group_id *int64 + addfallback_group_id *int64 + fallback_group_id_on_invalid_request *int64 + addfallback_group_id_on_invalid_request *int64 + model_routing *map[string][]int64 + model_routing_enabled *bool + clearedFields map[string]struct{} + api_keys map[int64]struct{} + removedapi_keys map[int64]struct{} + clearedapi_keys bool + redeem_codes map[int64]struct{} + removedredeem_codes map[int64]struct{} + clearedredeem_codes bool + subscriptions map[int64]struct{} + removedsubscriptions map[int64]struct{} + clearedsubscriptions bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + accounts map[int64]struct{} + removedaccounts map[int64]struct{} + clearedaccounts bool + allowed_users map[int64]struct{} + removedallowed_users map[int64]struct{} + clearedallowed_users bool + done bool + oldValue func(context.Context) (*Group, error) + predicates []predicate.Group } var _ ent.Mutation = (*GroupMutation)(nil) @@ -4976,6 +4978,76 @@ func (m *GroupMutation) ResetFallbackGroupID() { delete(m.clearedFields, group.FieldFallbackGroupID) } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) { + m.fallback_group_id_on_invalid_request = &i + m.addfallback_group_id_on_invalid_request = nil +} + +// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation. +func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) { + v := m.fallback_group_id_on_invalid_request + if v == nil { + return + } + return *v, true +} + +// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err) + } + return oldValue.FallbackGroupIDOnInvalidRequest, nil +} + +// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) { + if m.addfallback_group_id_on_invalid_request != nil { + *m.addfallback_group_id_on_invalid_request += i + } else { + m.addfallback_group_id_on_invalid_request = &i + } +} + +// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation. +func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) { + v := m.addfallback_group_id_on_invalid_request + if v == nil { + return + } + return *v, true +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() { + m.fallback_group_id_on_invalid_request = nil + m.addfallback_group_id_on_invalid_request = nil + m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{} +} + +// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation. +func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool { + _, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] + return ok +} + +// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() { + m.fallback_group_id_on_invalid_request = nil + m.addfallback_group_id_on_invalid_request = nil + delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest) +} + // SetModelRouting sets the "model_routing" field. func (m *GroupMutation) SetModelRouting(value map[string][]int64) { m.model_routing = &value @@ -5419,7 +5491,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 21) + fields := make([]string, 0, 22) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -5477,6 +5549,9 @@ func (m *GroupMutation) Fields() []string { if m.fallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } + if m.fallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } if m.model_routing != nil { fields = append(fields, group.FieldModelRouting) } @@ -5529,6 +5604,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: return m.FallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.FallbackGroupIDOnInvalidRequest() case group.FieldModelRouting: return m.ModelRouting() case group.FieldModelRoutingEnabled: @@ -5580,6 +5657,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldClaudeCodeOnly(ctx) case group.FieldFallbackGroupID: return m.OldFallbackGroupID(ctx) + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.OldFallbackGroupIDOnInvalidRequest(ctx) case group.FieldModelRouting: return m.OldModelRouting(ctx) case group.FieldModelRoutingEnabled: @@ -5726,6 +5805,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetFallbackGroupID(v) return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFallbackGroupIDOnInvalidRequest(v) + return nil case group.FieldModelRouting: v, ok := value.(map[string][]int64) if !ok { @@ -5775,6 +5861,9 @@ func (m *GroupMutation) AddedFields() []string { if m.addfallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } + if m.addfallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } return fields } @@ -5801,6 +5890,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedImagePrice4k() case group.FieldFallbackGroupID: return m.AddedFallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.AddedFallbackGroupIDOnInvalidRequest() } return nil, false } @@ -5873,6 +5964,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddFallbackGroupID(v) return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFallbackGroupIDOnInvalidRequest(v) + return nil } return fmt.Errorf("unknown Group numeric field %s", name) } @@ -5908,6 +6006,9 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldFallbackGroupID) { fields = append(fields, group.FieldFallbackGroupID) } + if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } if m.FieldCleared(group.FieldModelRouting) { fields = append(fields, group.FieldModelRouting) } @@ -5952,6 +6053,9 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldFallbackGroupID: m.ClearFallbackGroupID() return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ClearFallbackGroupIDOnInvalidRequest() + return nil case group.FieldModelRouting: m.ClearModelRouting() return nil @@ -6020,6 +6124,9 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldFallbackGroupID: m.ResetFallbackGroupID() return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ResetFallbackGroupIDOnInvalidRequest() + return nil case group.FieldModelRouting: m.ResetModelRouting() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 0cb10775..3ddb206d 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -281,7 +281,7 @@ func init() { // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[17].Descriptor() + groupDescModelRoutingEnabled := groupFields[18].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) promocodeFields := schema.PromoCode{}.Fields() diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 5d0a1e9a..51cae1a6 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -95,6 +95,10 @@ func (Group) Fields() []ent.Field { Optional(). Nillable(). Comment("非 Claude Code 请求降级使用的分组 ID"), + field.Int64("fallback_group_id_on_invalid_request"). + Optional(). + Nillable(). + Comment("无效请求兜底使用的分组 ID"), // 模型路由配置 (added by migration 040) field.JSON("model_routing", map[string][]int64{}). diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index f6780dee..8229a780 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -35,11 +35,12 @@ type CreateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` - ClaudeCodeOnly bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled bool `json:"model_routing_enabled"` @@ -58,11 +59,12 @@ type UpdateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` - ClaudeCodeOnly *bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + ClaudeCodeOnly *bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled *bool `json:"model_routing_enabled"` @@ -155,22 +157,23 @@ func (h *GroupHandler) Create(c *gin.Context) { } group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, - ModelRouting: req.ModelRouting, - ModelRoutingEnabled: req.ModelRoutingEnabled, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, }) if err != nil { response.ErrorFrom(c, err) @@ -196,23 +199,24 @@ func (h *GroupHandler) Update(c *gin.Context) { } group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - Status: req.Status, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, - ModelRouting: req.ModelRouting, - ModelRoutingEnabled: req.ModelRoutingEnabled, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + Status: req.Status, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index f5bdd008..f1991c30 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -73,27 +73,28 @@ func GroupFromServiceShallow(g *service.Group) *Group { return nil } return &Group{ - ID: g.ID, - Name: g.Name, - Description: g.Description, - Platform: g.Platform, - RateMultiplier: g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUSD, - WeeklyLimitUSD: g.WeeklyLimitUSD, - MonthlyLimitUSD: g.MonthlyLimitUSD, - ImagePrice1K: g.ImagePrice1K, - ImagePrice2K: g.ImagePrice2K, - ImagePrice4K: g.ImagePrice4K, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, - AccountCount: g.AccountCount, + ID: g.ID, + Name: g.Name, + Description: g.Description, + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUSD, + WeeklyLimitUSD: g.WeeklyLimitUSD, + MonthlyLimitUSD: g.MonthlyLimitUSD, + ImagePrice1K: g.ImagePrice1K, + ImagePrice2K: g.ImagePrice2K, + ImagePrice4K: g.ImagePrice4K, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, + AccountCount: g.AccountCount, } } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 4519143c..b425523b 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -57,6 +57,8 @@ type Group struct { // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` + // 无效请求兜底分组 + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 6c8d9ebe..cd622a3b 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -14,6 +14,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" @@ -325,136 +326,186 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } } - maxAccountSwitches := h.maxAccountSwitches - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + currentAPIKey := apiKey + currentSubscription := subscription + var fallbackGroupID *int64 + if apiKey.Group != nil { + fallbackGroupID = apiKey.Group.FallbackGroupIDOnInvalidRequest + } + fallbackUsed := false for { - // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) - if err != nil { - if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) - return - } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) - return - } - account := selection.Account - setOpsSelectedAccount(c, account.ID) + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 + retryWithFallback := false - // 检查预热请求拦截(在账号选择后、转发前检查) - if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { - if selection.Acquired && selection.ReleaseFunc != nil { - selection.ReleaseFunc() - } - if reqStream { - sendMockWarmupStream(c, reqModel) - } else { - sendMockWarmupResponse(c, reqModel) - } - return - } - - // 3. 获取账号并发槽位 - accountReleaseFunc := selection.ReleaseFunc - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - accountWaitCounted := false - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + for { + // 选择支持该模型的账号 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } - if err == nil && canWait { - accountWaitCounted = true - } - defer func() { - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - }() - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, - ) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } - } - // 账号槽位/等待计数需要在超时或断开时安全回收 - accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) - - // 转发请求 - 根据账号平台分流 - var result *service.ForwardResult - if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) - } else { - result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq) - } - if accountReleaseFunc != nil { - accountReleaseFunc() - } - if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} - lastFailoverStatus = failoverErr.StatusCode - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - switchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) - continue + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return } - // 错误响应已在Forward中处理,这里只记录日志 - log.Printf("Account %d: Forward request failed: %v", account.ID, err) + account := selection.Account + setOpsSelectedAccount(c, account.ID) + + // 检查预热请求拦截(在账号选择后、转发前检查) + if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) { + if selection.Acquired && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + if reqStream { + sendMockWarmupStream(c, reqModel) + } else { + sendMockWarmupResponse(c, reqModel) + } + return + } + + // 3. 获取账号并发槽位 + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + accountWaitCounted := false + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + if err == nil && canWait { + accountWaitCounted = true + } + defer func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + }() + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } + if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } + } + // 账号槽位/等待计数需要在超时或断开时安全回收 + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + // 转发请求 - 根据账号平台分流 + var result *service.ForwardResult + if account.Platform == service.PlatformAntigravity { + result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) + } else { + result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq) + } + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var promptTooLongErr *service.PromptTooLongError + if errors.As(err, &promptTooLongErr) { + log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed) + if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 { + fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID) + if err != nil { + log.Printf("Resolve fallback group failed: %v", err) + _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) + return + } + if fallbackGroup.Platform != service.PlatformAnthropic || + fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription || + fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { + log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType) + _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) + return + } + fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup) + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil { + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + // 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度 + ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "") + c.Request = c.Request.WithContext(ctx) + currentAPIKey = fallbackAPIKey + currentSubscription = nil + fallbackUsed = true + retryWithFallback = true + break + } + _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) + return + } + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + lastFailoverStatus = failoverErr.StatusCode + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return + } + switchCount++ + log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + continue + } + // 错误响应已在Forward中处理,这里只记录日志 + log.Printf("Account %d: Forward request failed: %v", account.ID, err) + return + } + + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + + // 异步记录使用量(subscription已在函数开头获取) + go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: currentAPIKey, + User: currentAPIKey.User, + Account: usedAccount, + Subscription: currentSubscription, + UserAgent: ua, + IPAddress: clientIP, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account, userAgent, clientIP) return } - // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) - userAgent := c.GetHeader("User-Agent") - clientIP := ip.GetClientIP(c) - - // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: clientIP, - }); err != nil { - log.Printf("Record usage failed: %v", err) - } - }(result, account, userAgent, clientIP) - return + if !retryWithFallback { + return + } } } @@ -518,6 +569,17 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) { }) } +func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service.APIKey { + if apiKey == nil || group == nil { + return apiKey + } + cloned := *apiKey + groupID := group.ID + cloned.GroupID = &groupID + cloned.Group = group + return &cloned +} + // Usage handles getting account balance for CC Switch integration // GET /v1/usage func (h *GatewayHandler) Usage(c *gin.Context) { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index ab890844..9938a36d 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -136,6 +136,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldImagePrice4k, group.FieldClaudeCodeOnly, group.FieldFallbackGroupID, + group.FieldFallbackGroupIDOnInvalidRequest, group.FieldModelRoutingEnabled, group.FieldModelRouting, ) @@ -406,28 +407,29 @@ func groupEntityToService(g *dbent.Group) *service.Group { return nil } return &service.Group{ - ID: g.ID, - Name: g.Name, - Description: derefString(g.Description), - Platform: g.Platform, - RateMultiplier: g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, - Hydrated: true, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUsd, - WeeklyLimitUSD: g.WeeklyLimitUsd, - MonthlyLimitUSD: g.MonthlyLimitUsd, - ImagePrice1K: g.ImagePrice1k, - ImagePrice2K: g.ImagePrice2k, - ImagePrice4K: g.ImagePrice4k, - DefaultValidityDays: g.DefaultValidityDays, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, + ID: g.ID, + Name: g.Name, + Description: derefString(g.Description), + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, + Hydrated: true, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUsd, + WeeklyLimitUSD: g.WeeklyLimitUsd, + MonthlyLimitUSD: g.MonthlyLimitUsd, + ImagePrice1K: g.ImagePrice1k, + ImagePrice2K: g.ImagePrice2k, + ImagePrice4K: g.ImagePrice4k, + DefaultValidityDays: g.DefaultValidityDays, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, } } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 5c4d6cf4..f207f479 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -50,6 +50,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetNillableFallbackGroupID(groupIn.FallbackGroupID). + SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) // 设置模型路由配置 @@ -116,6 +117,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er } else { builder = builder.ClearFallbackGroupID() } + // 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置 + if groupIn.FallbackGroupIDOnInvalidRequest != nil { + builder = builder.SetFallbackGroupIDOnInvalidRequest(*groupIn.FallbackGroupIDOnInvalidRequest) + } else { + builder = builder.ClearFallbackGroupIDOnInvalidRequest() + } // 处理 ModelRouting:nil 时清除,否则设置 if groupIn.ModelRouting != nil { diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 1b2c7ff4..12f01810 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -108,6 +108,8 @@ type CreateGroupInput struct { ImagePrice4K *float64 ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID + // 无效请求兜底分组 ID(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 ModelRoutingEnabled bool // 是否启用模型路由 @@ -130,6 +132,8 @@ type UpdateGroupInput struct { ImagePrice4K *float64 ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID + // 无效请求兜底分组 ID(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 ModelRoutingEnabled *bool // 是否启用模型路由 @@ -572,24 +576,35 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn return nil, err } } + fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest + if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 { + fallbackOnInvalidRequest = nil + } + // 校验无效请求兜底分组 + if fallbackOnInvalidRequest != nil { + if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil { + return nil, err + } + } group := &Group{ - Name: input.Name, - Description: input.Description, - Platform: platform, - RateMultiplier: input.RateMultiplier, - IsExclusive: input.IsExclusive, - Status: StatusActive, - SubscriptionType: subscriptionType, - DailyLimitUSD: dailyLimit, - WeeklyLimitUSD: weeklyLimit, - MonthlyLimitUSD: monthlyLimit, - ImagePrice1K: imagePrice1K, - ImagePrice2K: imagePrice2K, - ImagePrice4K: imagePrice4K, - ClaudeCodeOnly: input.ClaudeCodeOnly, - FallbackGroupID: input.FallbackGroupID, - ModelRouting: input.ModelRouting, + Name: input.Name, + Description: input.Description, + Platform: platform, + RateMultiplier: input.RateMultiplier, + IsExclusive: input.IsExclusive, + Status: StatusActive, + SubscriptionType: subscriptionType, + DailyLimitUSD: dailyLimit, + WeeklyLimitUSD: weeklyLimit, + MonthlyLimitUSD: monthlyLimit, + ImagePrice1K: imagePrice1K, + ImagePrice2K: imagePrice2K, + ImagePrice4K: imagePrice4K, + ClaudeCodeOnly: input.ClaudeCodeOnly, + FallbackGroupID: input.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, + ModelRouting: input.ModelRouting, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -651,6 +666,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro } } +// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性 +// currentGroupID: 当前分组 ID(新建时为 0) +// platform/subscriptionType: 当前分组的有效平台/订阅类型 +// fallbackGroupID: 兜底分组 ID +func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error { + if platform != PlatformAnthropic && platform != PlatformAntigravity { + return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups") + } + if subscriptionType == SubscriptionTypeSubscription { + return fmt.Errorf("subscription groups cannot set invalid request fallback") + } + if currentGroupID > 0 && currentGroupID == fallbackGroupID { + return fmt.Errorf("cannot set self as invalid request fallback group") + } + + fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID) + if err != nil { + return fmt.Errorf("fallback group not found: %w", err) + } + if fallbackGroup.Platform != PlatformAnthropic { + return fmt.Errorf("fallback group must be anthropic platform") + } + if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription { + return fmt.Errorf("fallback group cannot be subscription type") + } + if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { + return fmt.Errorf("fallback group cannot have invalid request fallback configured") + } + return nil +} + func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { @@ -717,6 +763,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.FallbackGroupID = nil } } + fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest + if input.FallbackGroupIDOnInvalidRequest != nil { + if *input.FallbackGroupIDOnInvalidRequest > 0 { + fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest + } else { + fallbackOnInvalidRequest = nil + } + } + if fallbackOnInvalidRequest != nil { + if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil { + return nil, err + } + } + group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest // 模型路由配置 if input.ModelRouting != nil { diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index e0574e2e..1454dccd 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -378,3 +378,374 @@ func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) { panic("unexpected DeleteAccountGroupsByGroupID call") } + +type groupRepoStubForInvalidRequestFallback struct { + groups map[int64]*Group + created *Group + updated *Group +} + +func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error { + s.created = g + return nil +} + +func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error { + s.updated = g + return nil +} + +func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) { + return s.GetByIDLite(ctx, id) +} + +func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) { + if g, ok := s.groups[id]; ok { + return g, nil + } + return nil, ErrGroupNotFound +} + +func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error { + panic("unexpected Delete call") +} + +func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) { + panic("unexpected DeleteCascade call") +} + +func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) { + panic("unexpected ListActive call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) { + panic("unexpected ListActiveByPlatform call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) { + panic("unexpected ExistsByName call") +} + +func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) { + panic("unexpected GetAccountCount call") +} + +func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) { + panic("unexpected DeleteAccountGroupsByGroupID call") +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformOpenAI, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups") + require.Nil(t, repo.created) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeSubscription, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback") + require.Nil(t, repo.created) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) { + tests := []struct { + name string + fallback *Group + wantMessage string + }{ + { + name: "openai_target", + fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard}, + wantMessage: "fallback group must be anthropic platform", + }, + { + name: "antigravity_target", + fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard}, + wantMessage: "fallback group must be anthropic platform", + }, + { + name: "subscription_group", + fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription}, + wantMessage: "fallback group cannot be subscription type", + }, + { + name: "nested_fallback", + fallback: &Group{ + ID: 10, + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(), + }, + wantMessage: "fallback group cannot have invalid request fallback configured", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fallbackID := tc.fallback.ID + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: tc.fallback, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), tc.wantMessage) + require.Nil(t, repo.created) + }) + } +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{} + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "fallback group not found") + require.Nil(t, repo.created) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAntigravity, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.created) + require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) { + zero := int64(0) + repo := &groupRepoStubForInvalidRequestFallback{} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &zero, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.created) + require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + FallbackGroupIDOnInvalidRequest: &fallbackID, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + Platform: PlatformOpenAI, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups") + require.Nil(t, repo.updated) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + FallbackGroupIDOnInvalidRequest: &fallbackID, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + SubscriptionType: SubscriptionTypeSubscription, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback") + require.Nil(t, repo.updated) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + FallbackGroupIDOnInvalidRequest: &fallbackID, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + clear := int64(0) + group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + Platform: PlatformOpenAI, + FallbackGroupIDOnInvalidRequest: &clear, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "fallback group cannot be subscription type") + require.Nil(t, repo.updated) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAntigravity, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest) +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 85e8eec7..d3c15418 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -62,6 +62,17 @@ type antigravityRetryLoopResult struct { resp *http.Response } +// PromptTooLongError 表示上游明确返回 prompt too long +type PromptTooLongError struct { + StatusCode int + RequestID string + Body []byte +} + +func (e *PromptTooLongError) Error() string { + return fmt.Sprintf("prompt too long: status=%d", e.StatusCode) +} + // antigravityRetryLoop 执行带 URL fallback 的重试循环 func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() @@ -930,6 +941,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 处理错误响应(重试后仍失败或不触发重试) if resp.StatusCode >= 400 { + if resp.StatusCode == http.StatusBadRequest { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + log.Printf("%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s", prefix, isPromptTooLongError(respBody), upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, 500)) + } + if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody + maxBytes := 2048 + if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { + maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + } + upstreamDetail := "" + if logBody { + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "prompt_too_long", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &PromptTooLongError{ + StatusCode: resp.StatusCode, + RequestID: resp.Header.Get("x-request-id"), + Body: respBody, + } + } s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) if s.shouldFailoverUpstreamError(resp.StatusCode) { @@ -1019,21 +1063,55 @@ func isSignatureRelatedError(respBody []byte) bool { return false } +func isPromptTooLongError(respBody []byte) bool { + msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody))) + if msg == "" { + msg = strings.ToLower(string(respBody)) + } + return strings.Contains(msg, "prompt is too long") +} + func extractAntigravityErrorMessage(body []byte) string { var payload map[string]any if err := json.Unmarshal(body, &payload); err != nil { return "" } + parseNestedMessage := func(msg string) string { + trimmed := strings.TrimSpace(msg) + if trimmed == "" || !strings.HasPrefix(trimmed, "{") { + return "" + } + var nested map[string]any + if err := json.Unmarshal([]byte(trimmed), &nested); err != nil { + return "" + } + if errObj, ok := nested["error"].(map[string]any); ok { + if innerMsg, ok := errObj["message"].(string); ok && strings.TrimSpace(innerMsg) != "" { + return innerMsg + } + } + if innerMsg, ok := nested["message"].(string); ok && strings.TrimSpace(innerMsg) != "" { + return innerMsg + } + return "" + } + // Google-style: {"error": {"message": "..."}} if errObj, ok := payload["error"].(map[string]any); ok { if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" { + if innerMsg := parseNestedMessage(msg); innerMsg != "" { + return innerMsg + } return msg } } // Fallback: top-level message if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" { + if innerMsg := parseNestedMessage(msg); innerMsg != "" { + return innerMsg + } return msg } @@ -2209,6 +2287,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg) } +func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error { + return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body) +} + func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error { statusStr := "UNKNOWN" switch status { diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 05ad9bbd..9c1fb415 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -1,10 +1,16 @@ package service import ( + "bytes" + "context" "encoding/json" + "io" + "net/http" + "net/http/httptest" "testing" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -81,3 +87,77 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) { require.Equal(t, "secret plan", blocks[0]["text"]) require.Equal(t, "tool_use", blocks[1]["type"]) } + +func TestIsPromptTooLongError(t *testing.T) { + require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`))) + require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`))) + require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`))) +} + +type httpUpstreamStub struct { + resp *http.Response + err error +} + +func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return s.resp, s.err +} + +func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "model": "claude-opus-4-5", + "messages": []map[string]any{ + {"role": "user", "content": "hi"}, + }, + "max_tokens": 1, + "stream": false, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request = req + + respBody := []byte(`{"error":{"message":"Prompt is too long"}}`) + resp := &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{"X-Request-Id": []string{"req-1"}}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, + } + + account := &Account{ + ID: 1, + Name: "acc-1", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + }, + } + + result, err := svc.Forward(context.Background(), c, account, body) + require.Nil(t, result) + + var promptErr *PromptTooLongError + require.ErrorAs(t, err, &promptErr) + require.Equal(t, http.StatusBadRequest, promptErr.StatusCode) + require.Equal(t, "req-1", promptErr.RequestID) + require.NotEmpty(t, promptErr.Body) + + raw, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := raw.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, "prompt_too_long", events[0].Kind) +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 5b476dbc..4b51fbbb 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -23,20 +23,21 @@ type APIKeyAuthUserSnapshot struct { // APIKeyAuthGroupSnapshot 分组快照 type APIKeyAuthGroupSnapshot struct { - ID int64 `json:"id"` - Name string `json:"name"` - Platform string `json:"platform"` - Status string `json:"status"` - SubscriptionType string `json:"subscription_type"` - RateMultiplier float64 `json:"rate_multiplier"` - DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` - WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` - MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` - ImagePrice1K *float64 `json:"image_price_1k,omitempty"` - ImagePrice2K *float64 `json:"image_price_2k,omitempty"` - ImagePrice4K *float64 `json:"image_price_4k,omitempty"` - ClaudeCodeOnly bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` + Status string `json:"status"` + SubscriptionType string `json:"subscription_type"` + RateMultiplier float64 `json:"rate_multiplier"` + DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` + WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` + MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` + ImagePrice1K *float64 `json:"image_price_1k,omitempty"` + ImagePrice2K *float64 `json:"image_price_2k,omitempty"` + ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` // Model routing is used by gateway account selection, so it must be part of auth cache snapshot. // Only anthropic groups use these fields; others may leave them empty. diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 521f1da5..8b74e7aa 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -207,22 +207,23 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { } if apiKey.Group != nil { snapshot.Group = &APIKeyAuthGroupSnapshot{ - ID: apiKey.Group.ID, - Name: apiKey.Group.Name, - Platform: apiKey.Group.Platform, - Status: apiKey.Group.Status, - SubscriptionType: apiKey.Group.SubscriptionType, - RateMultiplier: apiKey.Group.RateMultiplier, - DailyLimitUSD: apiKey.Group.DailyLimitUSD, - WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, - MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, - ImagePrice1K: apiKey.Group.ImagePrice1K, - ImagePrice2K: apiKey.Group.ImagePrice2K, - ImagePrice4K: apiKey.Group.ImagePrice4K, - ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, - FallbackGroupID: apiKey.Group.FallbackGroupID, - ModelRouting: apiKey.Group.ModelRouting, - ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, + ID: apiKey.Group.ID, + Name: apiKey.Group.Name, + Platform: apiKey.Group.Platform, + Status: apiKey.Group.Status, + SubscriptionType: apiKey.Group.SubscriptionType, + RateMultiplier: apiKey.Group.RateMultiplier, + DailyLimitUSD: apiKey.Group.DailyLimitUSD, + WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, + MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, + ImagePrice1K: apiKey.Group.ImagePrice1K, + ImagePrice2K: apiKey.Group.ImagePrice2K, + ImagePrice4K: apiKey.Group.ImagePrice4K, + ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, + FallbackGroupID: apiKey.Group.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest, + ModelRouting: apiKey.Group.ModelRouting, + ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, } } return snapshot @@ -250,23 +251,24 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho } if snapshot.Group != nil { apiKey.Group = &Group{ - ID: snapshot.Group.ID, - Name: snapshot.Group.Name, - Platform: snapshot.Group.Platform, - Status: snapshot.Group.Status, - Hydrated: true, - SubscriptionType: snapshot.Group.SubscriptionType, - RateMultiplier: snapshot.Group.RateMultiplier, - DailyLimitUSD: snapshot.Group.DailyLimitUSD, - WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, - MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, - ImagePrice1K: snapshot.Group.ImagePrice1K, - ImagePrice2K: snapshot.Group.ImagePrice2K, - ImagePrice4K: snapshot.Group.ImagePrice4K, - ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, - FallbackGroupID: snapshot.Group.FallbackGroupID, - ModelRouting: snapshot.Group.ModelRouting, - ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, + ID: snapshot.Group.ID, + Name: snapshot.Group.Name, + Platform: snapshot.Group.Platform, + Status: snapshot.Group.Status, + Hydrated: true, + SubscriptionType: snapshot.Group.SubscriptionType, + RateMultiplier: snapshot.Group.RateMultiplier, + DailyLimitUSD: snapshot.Group.DailyLimitUSD, + WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, + MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, + ImagePrice1K: snapshot.Group.ImagePrice1K, + ImagePrice2K: snapshot.Group.ImagePrice2K, + ImagePrice4K: snapshot.Group.ImagePrice4K, + ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, + FallbackGroupID: snapshot.Group.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest, + ModelRouting: snapshot.Group.ModelRouting, + ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, } } return apiKey diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 9cb15e4a..a7ded8a9 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -55,6 +55,15 @@ func shortSessionHash(sessionHash string) string { return sessionHash[:8] } +func normalizeClaudeModelForAnthropic(requestedModel string) string { + for _, prefix := range anthropicPrefixMappings { + if strings.HasPrefix(requestedModel, prefix) { + return prefix + } + } + return requestedModel +} + // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). var ( @@ -71,6 +80,12 @@ var ( "You are a file search specialist for Claude Code", // Explore Agent 版 "You are a helpful AI assistant tasked with summarizing conversations", // Compact 版 } + + anthropicPrefixMappings = []string{ + "claude-opus-4-5", + "claude-haiku-4-5", + "claude-sonnet-4-5", + } ) // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 @@ -951,6 +966,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (* return group, nil } +func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) { + return s.resolveGroupByID(ctx, groupID) +} + func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 { if groupID == nil || requestedModel == "" || platform != PlatformAnthropic { return nil @@ -1016,7 +1035,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID } // 强制平台模式不检查 Claude Code 限制 - if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform { + if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" { return nil, groupID, nil } @@ -1719,6 +1738,9 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo // Antigravity 平台使用专门的模型支持检查 return IsAntigravityModelSupported(requestedModel) } + if account.Platform == PlatformAnthropic { + requestedModel = normalizeClaudeModelForAnthropic(requestedModel) + } // 其他平台使用账户的模型支持检查 return account.IsModelSupported(requestedModel) } @@ -2115,17 +2137,29 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 强制执行 cache_control 块数量限制(最多 4 个) body = enforceCacheControlLimit(body) - // 应用模型映射(仅对apikey类型账号) + // 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射) originalModel := reqModel + mappedModel := reqModel + mappingSource := "" if account.Type == AccountTypeAPIKey { - mappedModel := account.GetMappedModel(reqModel) + mappedModel = account.GetMappedModel(reqModel) if mappedModel != reqModel { - // 替换请求体中的模型名 - body = s.replaceModelInBody(body, mappedModel) - reqModel = mappedModel - log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name) + mappingSource = "account" } } + if mappingSource == "" && account.Platform == PlatformAnthropic { + normalized := normalizeClaudeModelForAnthropic(reqModel) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "prefix" + } + } + if mappedModel != reqModel { + // 替换请求体中的模型名 + body = s.replaceModelInBody(body, mappedModel) + reqModel = mappedModel + log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) + } // 获取凭证 token, tokenType, err := s.GetAccessToken(ctx, account) @@ -3426,16 +3460,28 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return nil } - // 应用模型映射(仅对 apikey 类型账号) - if account.Type == AccountTypeAPIKey { - if reqModel != "" { - mappedModel := account.GetMappedModel(reqModel) + // 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射) + if reqModel != "" { + mappedModel := reqModel + mappingSource := "" + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(reqModel) if mappedModel != reqModel { - body = s.replaceModelInBody(body, mappedModel) - reqModel = mappedModel - log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name) + mappingSource = "account" } } + if mappingSource == "" && account.Platform == PlatformAnthropic { + normalized := normalizeClaudeModelForAnthropic(reqModel) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "prefix" + } + } + if mappedModel != reqModel { + body = s.replaceModelInBody(body, mappedModel) + reqModel = mappedModel + log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) + } } // 获取凭证 diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index d6d1269b..9140b6d9 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -29,6 +29,8 @@ type Group struct { // Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 + // 无效请求兜底分组(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置 // key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*") diff --git a/backend/migrations/043_add_group_invalid_request_fallback.sql b/backend/migrations/043_add_group_invalid_request_fallback.sql new file mode 100644 index 00000000..1c792704 --- /dev/null +++ b/backend/migrations/043_add_group_invalid_request_fallback.sql @@ -0,0 +1,13 @@ +-- 043_add_group_invalid_request_fallback.sql +-- 添加无效请求兜底分组配置 + +-- 添加 fallback_group_id_on_invalid_request 字段:无效请求兜底使用的分组 +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS fallback_group_id_on_invalid_request BIGINT REFERENCES groups(id) ON DELETE SET NULL; + +-- 添加索引优化查询 +CREATE INDEX IF NOT EXISTS idx_groups_fallback_group_id_on_invalid_request +ON groups(fallback_group_id_on_invalid_request) WHERE deleted_at IS NULL AND fallback_group_id_on_invalid_request IS NOT NULL; + +-- 添加字段注释 +COMMENT ON COLUMN groups.fallback_group_id_on_invalid_request IS '无效请求兜底使用的分组 ID'; diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index abbd4ff6..2a4ff1c6 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -919,6 +919,11 @@ export default { fallbackHint: 'Non-Claude Code requests will use this group. Leave empty to reject directly.', noFallback: 'No Fallback (Reject)' }, + invalidRequestFallback: { + title: 'Invalid Request Fallback Group', + hint: 'Triggered only when upstream explicitly returns prompt too long. Leave empty to disable fallback.', + noFallback: 'No Fallback' + }, modelRouting: { title: 'Model Routing', tooltip: 'Configure specific model requests to be routed to designated accounts. Supports wildcard matching, e.g., claude-opus-* matches all opus models.', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 1b398e7a..d8c13cfc 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -995,6 +995,11 @@ export default { fallbackHint: '非 Claude Code 请求将使用此分组,留空则直接拒绝', noFallback: '不降级(直接拒绝)' }, + invalidRequestFallback: { + title: '无效请求兜底分组', + hint: '仅当上游明确返回 prompt too long 时才会触发,留空表示不兜底', + noFallback: '不兜底' + }, modelRouting: { title: '模型路由配置', tooltip: '配置特定模型请求优先路由到指定账号。支持通配符匹配,如 claude-opus-* 匹配所有 opus 模型。', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 6fa57c26..fcd3748f 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -269,6 +269,7 @@ export interface Group { // Claude Code 客户端限制 claude_code_only: boolean fallback_group_id: number | null + fallback_group_id_on_invalid_request: number | null // 模型路由配置(仅 anthropic 平台使用) model_routing: Record | null model_routing_enabled: boolean @@ -322,6 +323,7 @@ export interface CreateGroupRequest { image_price_4k?: number | null claude_code_only?: boolean fallback_group_id?: number | null + fallback_group_id_on_invalid_request?: number | null } export interface UpdateGroupRequest { @@ -340,6 +342,7 @@ export interface UpdateGroupRequest { image_price_4k?: number | null claude_code_only?: boolean fallback_group_id?: number | null + fallback_group_id_on_invalid_request?: number | null } // ==================== Account & Proxy Types ==================== diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index 96457172..f3a407d7 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -460,6 +460,20 @@ + +
+ + +

{{ t('admin.groups.invalidRequestFallback.hint') }}

+
+
@@ -1202,6 +1230,44 @@ const fallbackGroupOptionsForEdit = computed(() => { return options }) +// 无效请求兜底分组选项(创建时)- 仅包含 anthropic 平台、非订阅且未配置兜底的分组 +const invalidRequestFallbackOptions = computed(() => { + const options: { value: number | null; label: string }[] = [ + { value: null, label: t('admin.groups.invalidRequestFallback.noFallback') } + ] + const eligibleGroups = groups.value.filter( + (g) => + g.platform === 'anthropic' && + g.status === 'active' && + g.subscription_type !== 'subscription' && + g.fallback_group_id_on_invalid_request === null + ) + eligibleGroups.forEach((g) => { + options.push({ value: g.id, label: g.name }) + }) + return options +}) + +// 无效请求兜底分组选项(编辑时)- 排除自身 +const invalidRequestFallbackOptionsForEdit = computed(() => { + const options: { value: number | null; label: string }[] = [ + { value: null, label: t('admin.groups.invalidRequestFallback.noFallback') } + ] + const currentId = editingGroup.value?.id + const eligibleGroups = groups.value.filter( + (g) => + g.platform === 'anthropic' && + g.status === 'active' && + g.subscription_type !== 'subscription' && + g.fallback_group_id_on_invalid_request === null && + g.id !== currentId + ) + eligibleGroups.forEach((g) => { + options.push({ value: g.id, label: g.name }) + }) + return options +}) + const groups = ref([]) const loading = ref(false) const searchQuery = ref('') @@ -1243,6 +1309,7 @@ const createForm = reactive({ // Claude Code 客户端限制(仅 anthropic 平台使用) claude_code_only: false, fallback_group_id: null as number | null, + fallback_group_id_on_invalid_request: null as number | null, // 模型路由开关 model_routing_enabled: false }) @@ -1414,6 +1481,7 @@ const editForm = reactive({ // Claude Code 客户端限制(仅 anthropic 平台使用) claude_code_only: false, fallback_group_id: null as number | null, + fallback_group_id_on_invalid_request: null as number | null, // 模型路由开关 model_routing_enabled: false }) @@ -1497,6 +1565,7 @@ const closeCreateModal = () => { createForm.image_price_4k = null createForm.claude_code_only = false createForm.fallback_group_id = null + createForm.fallback_group_id_on_invalid_request = null createModelRoutingRules.value = [] } @@ -1546,6 +1615,7 @@ const handleEdit = async (group: Group) => { editForm.image_price_4k = group.image_price_4k editForm.claude_code_only = group.claude_code_only || false editForm.fallback_group_id = group.fallback_group_id + editForm.fallback_group_id_on_invalid_request = group.fallback_group_id_on_invalid_request editForm.model_routing_enabled = group.model_routing_enabled || false // 加载模型路由规则(异步加载账号名称) editModelRoutingRules.value = await convertApiFormatToRoutingRules(group.model_routing) @@ -1571,6 +1641,10 @@ const handleUpdateGroup = async () => { const payload = { ...editForm, fallback_group_id: editForm.fallback_group_id === null ? 0 : editForm.fallback_group_id, + fallback_group_id_on_invalid_request: + editForm.fallback_group_id_on_invalid_request === null + ? 0 + : editForm.fallback_group_id_on_invalid_request, model_routing: convertRoutingRulesToApiFormat(editModelRoutingRules.value) } await adminAPI.groups.update(editingGroup.value.id, payload) @@ -1612,6 +1686,16 @@ watch( if (newVal === 'subscription') { createForm.rate_multiplier = 1.0 createForm.is_exclusive = true + createForm.fallback_group_id_on_invalid_request = null + } + } +) + +watch( + () => createForm.platform, + (newVal) => { + if (!['anthropic', 'antigravity'].includes(newVal)) { + createForm.fallback_group_id_on_invalid_request = null } } )