diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml index 05dd1d1a..fd0c7a41 100644 --- a/.github/workflows/security-scan.yml +++ b/.github/workflows/security-scan.yml @@ -32,7 +32,7 @@ jobs: working-directory: backend run: | go install github.com/securego/gosec/v2/cmd/gosec@latest - gosec -severity high -confidence high ./... + gosec -conf .gosec.json -severity high -confidence high ./... frontend-security: runs-on: ubuntu-latest diff --git a/backend/.gosec.json b/backend/.gosec.json new file mode 100644 index 00000000..b34e140c --- /dev/null +++ b/backend/.gosec.json @@ -0,0 +1,5 @@ +{ + "global": { + "exclude": "G704" + } +} diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 5087e794..8b063cd5 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.76 \ No newline at end of file +0.1.83 \ No newline at end of file diff --git a/backend/ent/errorpassthroughrule.go b/backend/ent/errorpassthroughrule.go index 1932f626..62468719 100644 --- a/backend/ent/errorpassthroughrule.go +++ b/backend/ent/errorpassthroughrule.go @@ -44,6 +44,8 @@ type ErrorPassthroughRule struct { PassthroughBody bool `json:"passthrough_body,omitempty"` // CustomMessage holds the value of the "custom_message" field. CustomMessage *string `json:"custom_message,omitempty"` + // SkipMonitoring holds the value of the "skip_monitoring" field. + SkipMonitoring bool `json:"skip_monitoring,omitempty"` // Description holds the value of the "description" field. 
Description *string `json:"description,omitempty"` selectValues sql.SelectValues @@ -56,7 +58,7 @@ func (*ErrorPassthroughRule) scanValues(columns []string) ([]any, error) { switch columns[i] { case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms: values[i] = new([]byte) - case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody: + case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody, errorpassthroughrule.FieldSkipMonitoring: values[i] = new(sql.NullBool) case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode: values[i] = new(sql.NullInt64) @@ -171,6 +173,12 @@ func (_m *ErrorPassthroughRule) assignValues(columns []string, values []any) err _m.CustomMessage = new(string) *_m.CustomMessage = value.String } + case errorpassthroughrule.FieldSkipMonitoring: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field skip_monitoring", values[i]) + } else if value.Valid { + _m.SkipMonitoring = value.Bool + } case errorpassthroughrule.FieldDescription: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field description", values[i]) @@ -257,6 +265,9 @@ func (_m *ErrorPassthroughRule) String() string { builder.WriteString(*v) } builder.WriteString(", ") + builder.WriteString("skip_monitoring=") + builder.WriteString(fmt.Sprintf("%v", _m.SkipMonitoring)) + builder.WriteString(", ") if v := _m.Description; v != nil { builder.WriteString("description=") builder.WriteString(*v) diff --git a/backend/ent/errorpassthroughrule/errorpassthroughrule.go b/backend/ent/errorpassthroughrule/errorpassthroughrule.go index d7be4f03..859fc761 100644 --- a/backend/ent/errorpassthroughrule/errorpassthroughrule.go +++ 
b/backend/ent/errorpassthroughrule/errorpassthroughrule.go @@ -39,6 +39,8 @@ const ( FieldPassthroughBody = "passthrough_body" // FieldCustomMessage holds the string denoting the custom_message field in the database. FieldCustomMessage = "custom_message" + // FieldSkipMonitoring holds the string denoting the skip_monitoring field in the database. + FieldSkipMonitoring = "skip_monitoring" // FieldDescription holds the string denoting the description field in the database. FieldDescription = "description" // Table holds the table name of the errorpassthroughrule in the database. @@ -61,6 +63,7 @@ var Columns = []string{ FieldResponseCode, FieldPassthroughBody, FieldCustomMessage, + FieldSkipMonitoring, FieldDescription, } @@ -95,6 +98,8 @@ var ( DefaultPassthroughCode bool // DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field. DefaultPassthroughBody bool + // DefaultSkipMonitoring holds the default value on creation for the "skip_monitoring" field. + DefaultSkipMonitoring bool ) // OrderOption defines the ordering options for the ErrorPassthroughRule queries. @@ -155,6 +160,11 @@ func ByCustomMessage(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCustomMessage, opts...).ToFunc() } +// BySkipMonitoring orders the results by the skip_monitoring field. +func BySkipMonitoring(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSkipMonitoring, opts...).ToFunc() +} + // ByDescription orders the results by the description field. 
func ByDescription(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldDescription, opts...).ToFunc() diff --git a/backend/ent/errorpassthroughrule/where.go b/backend/ent/errorpassthroughrule/where.go index 56839d52..87654678 100644 --- a/backend/ent/errorpassthroughrule/where.go +++ b/backend/ent/errorpassthroughrule/where.go @@ -104,6 +104,11 @@ func CustomMessage(v string) predicate.ErrorPassthroughRule { return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v)) } +// SkipMonitoring applies equality check predicate on the "skip_monitoring" field. It's identical to SkipMonitoringEQ. +func SkipMonitoring(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v)) +} + // Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ. func Description(v string) predicate.ErrorPassthroughRule { return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v)) @@ -544,6 +549,16 @@ func CustomMessageContainsFold(v string) predicate.ErrorPassthroughRule { return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v)) } +// SkipMonitoringEQ applies the EQ predicate on the "skip_monitoring" field. +func SkipMonitoringEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldSkipMonitoring, v)) +} + +// SkipMonitoringNEQ applies the NEQ predicate on the "skip_monitoring" field. +func SkipMonitoringNEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldSkipMonitoring, v)) +} + // DescriptionEQ applies the EQ predicate on the "description" field. 
func DescriptionEQ(v string) predicate.ErrorPassthroughRule { return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v)) diff --git a/backend/ent/errorpassthroughrule_create.go b/backend/ent/errorpassthroughrule_create.go index 4dc08dce..8173936b 100644 --- a/backend/ent/errorpassthroughrule_create.go +++ b/backend/ent/errorpassthroughrule_create.go @@ -172,6 +172,20 @@ func (_c *ErrorPassthroughRuleCreate) SetNillableCustomMessage(v *string) *Error return _c } +// SetSkipMonitoring sets the "skip_monitoring" field. +func (_c *ErrorPassthroughRuleCreate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleCreate { + _c.mutation.SetSkipMonitoring(v) + return _c +} + +// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetSkipMonitoring(*v) + } + return _c +} + // SetDescription sets the "description" field. func (_c *ErrorPassthroughRuleCreate) SetDescription(v string) *ErrorPassthroughRuleCreate { _c.mutation.SetDescription(v) @@ -249,6 +263,10 @@ func (_c *ErrorPassthroughRuleCreate) defaults() { v := errorpassthroughrule.DefaultPassthroughBody _c.mutation.SetPassthroughBody(v) } + if _, ok := _c.mutation.SkipMonitoring(); !ok { + v := errorpassthroughrule.DefaultSkipMonitoring + _c.mutation.SetSkipMonitoring(v) + } } // check runs all checks and user-defined validators on the builder. 
@@ -287,6 +305,9 @@ func (_c *ErrorPassthroughRuleCreate) check() error { if _, ok := _c.mutation.PassthroughBody(); !ok { return &ValidationError{Name: "passthrough_body", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_body"`)} } + if _, ok := _c.mutation.SkipMonitoring(); !ok { + return &ValidationError{Name: "skip_monitoring", err: errors.New(`ent: missing required field "ErrorPassthroughRule.skip_monitoring"`)} + } return nil } @@ -366,6 +387,10 @@ func (_c *ErrorPassthroughRuleCreate) createSpec() (*ErrorPassthroughRule, *sqlg _spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value) _node.CustomMessage = &value } + if value, ok := _c.mutation.SkipMonitoring(); ok { + _spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value) + _node.SkipMonitoring = value + } if value, ok := _c.mutation.Description(); ok { _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) _node.Description = &value @@ -608,6 +633,18 @@ func (u *ErrorPassthroughRuleUpsert) ClearCustomMessage() *ErrorPassthroughRuleU return u } +// SetSkipMonitoring sets the "skip_monitoring" field. +func (u *ErrorPassthroughRuleUpsert) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldSkipMonitoring, v) + return u +} + +// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldSkipMonitoring) + return u +} + // SetDescription sets the "description" field. func (u *ErrorPassthroughRuleUpsert) SetDescription(v string) *ErrorPassthroughRuleUpsert { u.Set(errorpassthroughrule.FieldDescription, v) @@ -888,6 +925,20 @@ func (u *ErrorPassthroughRuleUpsertOne) ClearCustomMessage() *ErrorPassthroughRu }) } +// SetSkipMonitoring sets the "skip_monitoring" field. 
+func (u *ErrorPassthroughRuleUpsertOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetSkipMonitoring(v) + }) +} + +// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateSkipMonitoring() + }) +} + // SetDescription sets the "description" field. func (u *ErrorPassthroughRuleUpsertOne) SetDescription(v string) *ErrorPassthroughRuleUpsertOne { return u.Update(func(s *ErrorPassthroughRuleUpsert) { @@ -1337,6 +1388,20 @@ func (u *ErrorPassthroughRuleUpsertBulk) ClearCustomMessage() *ErrorPassthroughR }) } +// SetSkipMonitoring sets the "skip_monitoring" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetSkipMonitoring(v) + }) +} + +// UpdateSkipMonitoring sets the "skip_monitoring" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateSkipMonitoring() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateSkipMonitoring() + }) +} + // SetDescription sets the "description" field. func (u *ErrorPassthroughRuleUpsertBulk) SetDescription(v string) *ErrorPassthroughRuleUpsertBulk { return u.Update(func(s *ErrorPassthroughRuleUpsert) { diff --git a/backend/ent/errorpassthroughrule_update.go b/backend/ent/errorpassthroughrule_update.go index 9d52aa49..7e42d9fc 100644 --- a/backend/ent/errorpassthroughrule_update.go +++ b/backend/ent/errorpassthroughrule_update.go @@ -227,6 +227,20 @@ func (_u *ErrorPassthroughRuleUpdate) ClearCustomMessage() *ErrorPassthroughRule return _u } +// SetSkipMonitoring sets the "skip_monitoring" field. 
+func (_u *ErrorPassthroughRuleUpdate) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdate { + _u.mutation.SetSkipMonitoring(v) + return _u +} + +// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetSkipMonitoring(*v) + } + return _u +} + // SetDescription sets the "description" field. func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate { _u.mutation.SetDescription(v) @@ -387,6 +401,9 @@ func (_u *ErrorPassthroughRuleUpdate) sqlSave(ctx context.Context) (_node int, e if _u.mutation.CustomMessageCleared() { _spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString) } + if value, ok := _u.mutation.SkipMonitoring(); ok { + _spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value) + } if value, ok := _u.mutation.Description(); ok { _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) } @@ -611,6 +628,20 @@ func (_u *ErrorPassthroughRuleUpdateOne) ClearCustomMessage() *ErrorPassthroughR return _u } +// SetSkipMonitoring sets the "skip_monitoring" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetSkipMonitoring(v bool) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetSkipMonitoring(v) + return _u +} + +// SetNillableSkipMonitoring sets the "skip_monitoring" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableSkipMonitoring(v *bool) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetSkipMonitoring(*v) + } + return _u +} + // SetDescription sets the "description" field. 
func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne { _u.mutation.SetDescription(v) @@ -801,6 +832,9 @@ func (_u *ErrorPassthroughRuleUpdateOne) sqlSave(ctx context.Context) (_node *Er if _u.mutation.CustomMessageCleared() { _spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString) } + if value, ok := _u.mutation.SkipMonitoring(); ok { + _spec.SetField(errorpassthroughrule.FieldSkipMonitoring, field.TypeBool, value) + } if value, ok := _u.mutation.Description(); ok { _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index cfd4a72b..07f2a68e 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -325,6 +325,7 @@ var ( {Name: "response_code", Type: field.TypeInt, Nullable: true}, {Name: "passthrough_body", Type: field.TypeBool, Default: true}, {Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647}, + {Name: "skip_monitoring", Type: field.TypeBool, Default: false}, {Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647}, } // ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table. 
@@ -649,6 +650,7 @@ var ( {Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45}, {Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, + {Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "api_key_id", Type: field.TypeInt64}, {Name: "account_id", Type: field.TypeInt64}, @@ -664,31 +666,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[27]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -697,32 +699,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: 
[]*schema.Column{UsageLogsColumns[27]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[26]}, }, { Name: "usagelog_model", @@ -737,12 +739,12 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[26]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[26]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 969d9357..34b3268e 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -5776,6 +5776,7 @@ type ErrorPassthroughRuleMutation struct { addresponse_code *int passthrough_body *bool custom_message *string + skip_monitoring *bool description *string clearedFields map[string]struct{} done bool @@ -6503,6 +6504,42 @@ func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() { delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage) } +// SetSkipMonitoring sets the "skip_monitoring" field. +func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) { + m.skip_monitoring = &b +} + +// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation. 
+func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) { + v := m.skip_monitoring + if v == nil { + return + } + return *v, true +} + +// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSkipMonitoring requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err) + } + return oldValue.SkipMonitoring, nil +} + +// ResetSkipMonitoring resets all changes to the "skip_monitoring" field. +func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() { + m.skip_monitoring = nil +} + // SetDescription sets the "description" field. func (m *ErrorPassthroughRuleMutation) SetDescription(s string) { m.description = &s @@ -6586,7 +6623,7 @@ func (m *ErrorPassthroughRuleMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *ErrorPassthroughRuleMutation) Fields() []string { - fields := make([]string, 0, 14) + fields := make([]string, 0, 15) if m.created_at != nil { fields = append(fields, errorpassthroughrule.FieldCreatedAt) } @@ -6626,6 +6663,9 @@ func (m *ErrorPassthroughRuleMutation) Fields() []string { if m.custom_message != nil { fields = append(fields, errorpassthroughrule.FieldCustomMessage) } + if m.skip_monitoring != nil { + fields = append(fields, errorpassthroughrule.FieldSkipMonitoring) + } if m.description != nil { fields = append(fields, errorpassthroughrule.FieldDescription) } @@ -6663,6 +6703,8 @@ func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) { return m.PassthroughBody() case errorpassthroughrule.FieldCustomMessage: return m.CustomMessage() + case errorpassthroughrule.FieldSkipMonitoring: + return m.SkipMonitoring() case errorpassthroughrule.FieldDescription: return m.Description() } @@ -6700,6 +6742,8 @@ func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string return m.OldPassthroughBody(ctx) case errorpassthroughrule.FieldCustomMessage: return m.OldCustomMessage(ctx) + case errorpassthroughrule.FieldSkipMonitoring: + return m.OldSkipMonitoring(ctx) case errorpassthroughrule.FieldDescription: return m.OldDescription(ctx) } @@ -6802,6 +6846,13 @@ func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) er } m.SetCustomMessage(v) return nil + case errorpassthroughrule.FieldSkipMonitoring: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSkipMonitoring(v) + return nil case errorpassthroughrule.FieldDescription: v, ok := value.(string) if !ok { @@ -6963,6 +7014,9 @@ func (m *ErrorPassthroughRuleMutation) ResetField(name string) error { case errorpassthroughrule.FieldCustomMessage: m.ResetCustomMessage() return nil + case errorpassthroughrule.FieldSkipMonitoring: + m.ResetSkipMonitoring() + return nil case 
errorpassthroughrule.FieldDescription: m.ResetDescription() return nil @@ -15007,6 +15061,7 @@ type UsageLogMutation struct { image_count *int addimage_count *int image_size *string + cache_ttl_overridden *bool created_at *time.Time clearedFields map[string]struct{} user *int64 @@ -16633,6 +16688,42 @@ func (m *UsageLogMutation) ResetImageSize() { delete(m.clearedFields, usagelog.FieldImageSize) } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) { + m.cache_ttl_overridden = &b +} + +// CacheTTLOverridden returns the value of the "cache_ttl_overridden" field in the mutation. +func (m *UsageLogMutation) CacheTTLOverridden() (r bool, exists bool) { + v := m.cache_ttl_overridden + if v == nil { + return + } + return *v, true +} + +// OldCacheTTLOverridden returns the old "cache_ttl_overridden" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheTTLOverridden(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheTTLOverridden is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheTTLOverridden requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheTTLOverridden: %w", err) + } + return oldValue.CacheTTLOverridden, nil +} + +// ResetCacheTTLOverridden resets all changes to the "cache_ttl_overridden" field. +func (m *UsageLogMutation) ResetCacheTTLOverridden() { + m.cache_ttl_overridden = nil +} + // SetCreatedAt sets the "created_at" field. 
func (m *UsageLogMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -16838,7 +16929,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 30) + fields := make([]string, 0, 31) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -16926,6 +17017,9 @@ func (m *UsageLogMutation) Fields() []string { if m.image_size != nil { fields = append(fields, usagelog.FieldImageSize) } + if m.cache_ttl_overridden != nil { + fields = append(fields, usagelog.FieldCacheTTLOverridden) + } if m.created_at != nil { fields = append(fields, usagelog.FieldCreatedAt) } @@ -16995,6 +17089,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.ImageCount() case usagelog.FieldImageSize: return m.ImageSize() + case usagelog.FieldCacheTTLOverridden: + return m.CacheTTLOverridden() case usagelog.FieldCreatedAt: return m.CreatedAt() } @@ -17064,6 +17160,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldImageCount(ctx) case usagelog.FieldImageSize: return m.OldImageSize(ctx) + case usagelog.FieldCacheTTLOverridden: + return m.OldCacheTTLOverridden(ctx) case usagelog.FieldCreatedAt: return m.OldCreatedAt(ctx) } @@ -17278,6 +17376,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetImageSize(v) return nil + case usagelog.FieldCacheTTLOverridden: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheTTLOverridden(v) + return nil case usagelog.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -17691,6 +17796,9 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldImageSize: m.ResetImageSize() return nil + case usagelog.FieldCacheTTLOverridden: + m.ResetCacheTTLOverridden() + return nil case usagelog.FieldCreatedAt: 
m.ResetCreatedAt() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index e5c34929..d96f9a00 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -326,6 +326,10 @@ func init() { errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor() // errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field. errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool) + // errorpassthroughruleDescSkipMonitoring is the schema descriptor for skip_monitoring field. + errorpassthroughruleDescSkipMonitoring := errorpassthroughruleFields[11].Descriptor() + // errorpassthroughrule.DefaultSkipMonitoring holds the default value on creation for the skip_monitoring field. + errorpassthroughrule.DefaultSkipMonitoring = errorpassthroughruleDescSkipMonitoring.Default.(bool) groupMixin := schema.Group{}.Mixin() groupMixinHooks1 := groupMixin[1].Hooks() group.Hooks[0] = groupMixinHooks1[0] @@ -775,8 +779,12 @@ func init() { usagelogDescImageSize := usagelogFields[28].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) + // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. + usagelogDescCacheTTLOverridden := usagelogFields[29].Descriptor() + // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. + usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[29].Descriptor() + usagelogDescCreatedAt := usagelogFields[30].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. 
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/error_passthrough_rule.go b/backend/ent/schema/error_passthrough_rule.go index 4a861f38..63a81230 100644 --- a/backend/ent/schema/error_passthrough_rule.go +++ b/backend/ent/schema/error_passthrough_rule.go @@ -105,6 +105,12 @@ func (ErrorPassthroughRule) Fields() []ent.Field { Optional(). Nillable(), + // skip_monitoring: whether to skip ops-monitoring logging for matched errors + // true: errors matching this rule are not recorded in ops_error_logs + // false: errors are recorded in ops monitoring as usual (default behavior) + field.Bool("skip_monitoring"). + Default(false), + // description: 规则描述,用于说明规则的用途 field.Text("description"). Optional(). diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index fc7c7165..a5032605 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -119,6 +119,10 @@ func (UsageLog) Fields() []ent.Field { Optional(). Nillable(), + // Cache TTL override flag (an administrator forcibly replaced the cache TTL used for billing) + field.Bool("cache_ttl_overridden"). + Default(false), + // 时间戳(只有 created_at,日志不可修改) field.Time("created_at"). Default(time.Now). diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index 81c466b4..59b25b99 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -80,6 +80,8 @@ type UsageLog struct { ImageCount int `json:"image_count,omitempty"` // ImageSize holds the value of the "image_size" field. ImageSize *string `json:"image_size,omitempty"` + // CacheTTLOverridden holds the value of the "cache_ttl_overridden" field. + CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. 
@@ -165,7 +167,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case usagelog.FieldStream: + case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden: values[i] = new(sql.NullBool) case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier: values[i] = new(sql.NullFloat64) @@ -378,6 +380,12 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { _m.ImageSize = new(string) *_m.ImageSize = value.String } + case usagelog.FieldCacheTTLOverridden: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i]) + } else if value.Valid { + _m.CacheTTLOverridden = value.Bool + } case usagelog.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -548,6 +556,9 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") + builder.WriteString("cache_ttl_overridden=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden)) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteByte(')') diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index 980f1e58..fca720d2 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -72,6 +72,8 @@ const ( FieldImageCount = "image_count" // FieldImageSize holds the string denoting the image_size field in the database. FieldImageSize = "image_size" + // FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database. 
+ FieldCacheTTLOverridden = "cache_ttl_overridden" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // EdgeUser holds the string denoting the user edge name in mutations. @@ -155,6 +157,7 @@ var Columns = []string{ FieldIPAddress, FieldImageCount, FieldImageSize, + FieldCacheTTLOverridden, FieldCreatedAt, } @@ -211,6 +214,8 @@ var ( DefaultImageCount int // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. ImageSizeValidator func(string) error + // DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field. + DefaultCacheTTLOverridden bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time ) @@ -368,6 +373,11 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImageSize, opts...).ToFunc() } +// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field. +func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index 28e2ab4c..ae832959 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -200,6 +200,11 @@ func ImageSize(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v)) } +// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ. 
+func CacheTTLOverridden(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) @@ -1440,6 +1445,16 @@ func ImageSizeContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v)) } +// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field. +func CacheTTLOverriddenEQ(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) +} + +// CacheTTLOverriddenNEQ applies the NEQ predicate on the "cache_ttl_overridden" field. +func CacheTTLOverriddenNEQ(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheTTLOverridden, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index a17d6507..5b9cdf14 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -393,6 +393,20 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate { return _c } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate { + _c.mutation.SetCacheTTLOverridden(v) + return _c +} + +// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheTTLOverridden(v *bool) *UsageLogCreate { + if v != nil { + _c.SetCacheTTLOverridden(*v) + } + return _c +} + // SetCreatedAt sets the "created_at" field. 
func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate { _c.mutation.SetCreatedAt(v) @@ -531,6 +545,10 @@ func (_c *UsageLogCreate) defaults() { v := usagelog.DefaultImageCount _c.mutation.SetImageCount(v) } + if _, ok := _c.mutation.CacheTTLOverridden(); !ok { + v := usagelog.DefaultCacheTTLOverridden + _c.mutation.SetCacheTTLOverridden(v) + } if _, ok := _c.mutation.CreatedAt(); !ok { v := usagelog.DefaultCreatedAt() _c.mutation.SetCreatedAt(v) @@ -627,6 +645,9 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if _, ok := _c.mutation.CacheTTLOverridden(); !ok { + return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)} + } if _, ok := _c.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)} } @@ -762,6 +783,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldImageSize, field.TypeString, value) _node.ImageSize = &value } + if value, ok := _c.mutation.CacheTTLOverridden(); ok { + _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) + _node.CacheTTLOverridden = value + } if value, ok := _c.mutation.CreatedAt(); ok { _spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -1407,6 +1432,18 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert { return u } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert { + u.Set(usagelog.FieldCacheTTLOverridden, v) + return u +} + +// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create. 
+func (u *UsageLogUpsert) UpdateCacheTTLOverridden() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheTTLOverridden) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -2040,6 +2077,20 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne { }) } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheTTLOverridden(v) + }) +} + +// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheTTLOverridden() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheTTLOverridden() + }) +} + // Exec executes the query. func (u *UsageLogUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2839,6 +2890,20 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk { }) } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheTTLOverridden(v) + }) +} + +// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheTTLOverridden() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheTTLOverridden() + }) +} + // Exec executes the query. 
func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index 571a7b3c..22f3cb31 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -612,6 +612,20 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate { return _u } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate { + _u.mutation.SetCacheTTLOverridden(v) + return _u +} + +// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdate { + if v != nil { + _u.SetCacheTTLOverridden(*v) + } + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate { return _u.SetUserID(v.ID) @@ -894,6 +908,9 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.CacheTTLOverridden(); ok { + _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -1639,6 +1656,20 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne { return _u } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne { + _u.mutation.SetCacheTTLOverridden(v) + return _u +} + +// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheTTLOverridden(*v) + } + return _u +} + // SetUser sets the "user" edge to the User entity. 
func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne { return _u.SetUserID(v.ID) @@ -1951,6 +1982,9 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.CacheTTLOverridden(); ok { + _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index b5d1dd0a..34397696 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc pageSize := dataPageCap var out []service.Account for { - items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search) + items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0) if err != nil { return nil, err } diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 85400c6f..0fae04ac 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -156,7 +156,12 @@ func (h *AccountHandler) List(c *gin.Context) { search = search[:100] } - accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search) + var groupID int64 + if groupIDStr := c.Query("group"); groupIDStr != "" { + groupID, _ = strconv.ParseInt(groupIDStr, 10, 64) + } + + accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID) if err != nil { response.ErrorFrom(c, err) return @@ -1429,7 +1434,7 @@ 
func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { accounts := make([]*service.Account, 0) if len(req.AccountIDs) == 0 { - allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "") + allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index cbbfe942..d44c99ea 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p return s.apiKeys, int64(len(s.apiKeys)), nil } -func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) { +func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) { return s.accounts, int64(len(s.accounts)), nil } diff --git a/backend/internal/handler/admin/antigravity_oauth_handler.go b/backend/internal/handler/admin/antigravity_oauth_handler.go index 18541684..7488965d 100644 --- a/backend/internal/handler/admin/antigravity_oauth_handler.go +++ b/backend/internal/handler/admin/antigravity_oauth_handler.go @@ -65,3 +65,27 @@ func (h *AntigravityOAuthHandler) ExchangeCode(c *gin.Context) { response.Success(c, tokenInfo) } + +// AntigravityRefreshTokenRequest represents the request for validating Antigravity refresh token +type AntigravityRefreshTokenRequest struct { + RefreshToken string `json:"refresh_token" binding:"required"` + ProxyID *int64 `json:"proxy_id"` +} + +// RefreshToken validates an Antigravity refresh token and returns full token info +// POST 
/api/v1/admin/antigravity/oauth/refresh-token +func (h *AntigravityOAuthHandler) RefreshToken(c *gin.Context) { + var req AntigravityRefreshTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "请求无效: "+err.Error()) + return + } + + tokenInfo, err := h.antigravityOAuthService.ValidateRefreshToken(c.Request.Context(), req.RefreshToken, req.ProxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, tokenInfo) +} diff --git a/backend/internal/handler/admin/error_passthrough_handler.go b/backend/internal/handler/admin/error_passthrough_handler.go index c32db561..25aaa5c7 100644 --- a/backend/internal/handler/admin/error_passthrough_handler.go +++ b/backend/internal/handler/admin/error_passthrough_handler.go @@ -32,6 +32,7 @@ type CreateErrorPassthroughRuleRequest struct { ResponseCode *int `json:"response_code"` PassthroughBody *bool `json:"passthrough_body"` CustomMessage *string `json:"custom_message"` + SkipMonitoring *bool `json:"skip_monitoring"` Description *string `json:"description"` } @@ -48,6 +49,7 @@ type UpdateErrorPassthroughRuleRequest struct { ResponseCode *int `json:"response_code"` PassthroughBody *bool `json:"passthrough_body"` CustomMessage *string `json:"custom_message"` + SkipMonitoring *bool `json:"skip_monitoring"` Description *string `json:"description"` } @@ -122,6 +124,9 @@ func (h *ErrorPassthroughHandler) Create(c *gin.Context) { } else { rule.PassthroughBody = true } + if req.SkipMonitoring != nil { + rule.SkipMonitoring = *req.SkipMonitoring + } rule.ResponseCode = req.ResponseCode rule.CustomMessage = req.CustomMessage rule.Description = req.Description @@ -190,6 +195,7 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) { ResponseCode: existing.ResponseCode, PassthroughBody: existing.PassthroughBody, CustomMessage: existing.CustomMessage, + SkipMonitoring: existing.SkipMonitoring, Description: existing.Description, } @@ -230,6 +236,9 @@ func (h 
*ErrorPassthroughHandler) Update(c *gin.Context) { if req.Description != nil { rule.Description = req.Description } + if req.SkipMonitoring != nil { + rule.SkipMonitoring = *req.SkipMonitoring + } // 确保切片不为 nil if rule.ErrorCodes == nil { diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index e229385f..02752fea 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -202,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) { writer := csv.NewWriter(&buf) // Write header - if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil { + if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "created_at"}); err != nil { response.InternalError(c, "Failed to export redeem codes: "+err.Error()) return } @@ -213,6 +213,10 @@ func (h *RedeemHandler) Export(c *gin.Context) { if code.UsedBy != nil { usedBy = fmt.Sprintf("%d", *code.UsedBy) } + usedByEmail := "" + if code.User != nil { + usedByEmail = code.User.Email + } usedAt := "" if code.UsedAt != nil { usedAt = code.UsedAt.Format("2006-01-02 15:04:05") @@ -224,6 +228,7 @@ func (h *RedeemHandler) Export(c *gin.Context) { fmt.Sprintf("%.2f", code.Value), code.Status, usedBy, + usedByEmail, usedAt, code.CreatedAt.Format("2006-01-02 15:04:05"), }); err != nil { diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 2caf6847..eee5910e 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -211,6 +211,13 @@ func AccountFromServiceShallow(a *service.Account) *Account { enabled := true out.EnableSessionIDMasking = &enabled } + // 缓存 TTL 强制替换 + if a.IsCacheTTLOverrideEnabled() { + enabled := true + out.CacheTTLOverrideEnabled = &enabled + target := a.GetCacheTTLOverrideTarget() + 
out.CacheTTLOverrideTarget = &target + } } return out @@ -398,6 +405,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { ImageCount: l.ImageCount, ImageSize: l.ImageSize, UserAgent: l.UserAgent, + CacheTTLOverridden: l.CacheTTLOverridden, CreatedAt: l.CreatedAt, User: UserFromServiceShallow(l.User), APIKey: APIKeyFromService(l.APIKey), diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 2c1ae83c..0253caf7 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -150,6 +150,11 @@ type Account struct { // 从 extra 字段提取,方便前端显示和编辑 EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"` + // 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效) + // 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费 + CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"` + CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"` + Proxy *Proxy `json:"proxy,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` @@ -273,6 +278,9 @@ type UsageLog struct { // User-Agent UserAgent *string `json:"user_agent"` + // Cache TTL Override 标记 + CacheTTLOverridden bool `json:"cache_ttl_overridden"` + CreatedAt time.Time `json:"created_at"` User *User `json:"user,omitempty"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 6900fa55..c2b6bf09 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -235,9 +235,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) { maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 failedAccountIDs := make(map[int64]struct{}) + sameAccountRetryCount := make(map[int64]int) // 同账号重试计数 var lastFailoverErr *service.UpstreamFailoverError var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 + // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 + // 避免单账号分组收到 503 
(MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 + if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) { + ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + c.Request = c.Request.WithContext(ctx) + } + for { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { @@ -245,6 +253,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } + // Antigravity 单账号退避重试:分组内没有其他可用账号时, + // 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 + // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 + if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { + if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { + log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches) + failedAccountIDs = make(map[int64]struct{}) + // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 + ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + c.Request = c.Request.WithContext(ctx) + continue + } + } if lastFailoverErr != nil { h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted) } else { @@ -339,11 +360,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr if needForceCacheBilling(hasBoundSession, failoverErr) { forceCacheBilling = true } + + // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 + if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries { + 
sameAccountRetryCount[account.ID]++ + log.Printf("Account %d: retryable error %d, same-account retry %d/%d", + account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries) + if !sleepSameAccountRetryDelay(c.Request.Context()) { + return + } + continue + } + + // 同账号重试用尽,执行临时封禁并切换账号 + if failoverErr.RetryableOnSameAccount { + h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr) + } + + failedAccountIDs[account.ID] = struct{}{} if switchCount >= maxAccountSwitches { h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted) return @@ -396,10 +434,18 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } fallbackUsed := false + // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 + // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 + if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) { + ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + c.Request = c.Request.WithContext(ctx) + } + for { maxAccountSwitches := h.maxAccountSwitches switchCount := 0 failedAccountIDs := make(map[int64]struct{}) + sameAccountRetryCount := make(map[int64]int) // 同账号重试计数 var lastFailoverErr *service.UpstreamFailoverError retryWithFallback := false var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 @@ -412,6 +458,19 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } + // Antigravity 单账号退避重试:分组内没有其他可用账号时, + // 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 + // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 + if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { + if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { + log.Printf("Antigravity single-account 503 retry: 
clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches) + failedAccountIDs = make(map[int64]struct{}) + // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 + ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + c.Request = c.Request.WithContext(ctx) + continue + } + } if lastFailoverErr != nil { h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted) } else { @@ -539,11 +598,28 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr if needForceCacheBilling(hasBoundSession, failoverErr) { forceCacheBilling = true } + + // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 + if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries { + sameAccountRetryCount[account.ID]++ + log.Printf("Account %d: retryable error %d, same-account retry %d/%d", + account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries) + if !sleepSameAccountRetryDelay(c.Request.Context()) { + return + } + continue + } + + // 同账号重试用尽,执行临时封禁并切换账号 + if failoverErr.RetryableOnSameAccount { + h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr) + } + + failedAccountIDs[account.ID] = struct{}{} if switchCount >= maxAccountSwitches { h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted) return @@ -823,6 +899,23 @@ func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFa return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling) } +const ( + // maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误) + maxSameAccountRetries = 2 + // sameAccountRetryDelay 同账号重试间隔 + sameAccountRetryDelay = 500 * time.Millisecond +) + +// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。 +func sleepSameAccountRetryDelay(ctx 
context.Context) bool { + select { + case <-ctx.Done(): + return false + case <-time.After(sameAccountRetryDelay): + return true + } +} + // sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s… // 返回 false 表示 context 已取消。 func sleepFailoverDelay(ctx context.Context, switchCount int) bool { @@ -838,6 +931,27 @@ func sleepFailoverDelay(ctx context.Context, switchCount int) bool { } } +// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。 +// 当分组内只有一个可用账号且上游返回 503(MODEL_CAPACITY_EXHAUSTED)时使用, +// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试 +// (最多 3 次、总等待 30s),所以 Handler 层的退避只需短暂等待即可。 +// 返回 false 表示 context 已取消。 +func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool { + // 固定短延时:2s + // Service 层已经在原地等待了足够长的时间(retryDelay × 重试次数), + // Handler 层只需短暂间隔后重新进入 Service 层即可。 + const delay = 2 * time.Second + + log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", delay, retryCount) + + select { + case <-ctx.Done(): + return false + case <-time.After(delay): + return true + } +} + func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { statusCode := failoverErr.StatusCode responseBody := failoverErr.ResponseBody @@ -857,6 +971,10 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *se msg = *rule.CustomMessage } + if rule.SkipMonitoring { + c.Set(service.OpsSkipPassthroughKey, true) + } + h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted) return } diff --git a/backend/internal/handler/gateway_handler_single_account_retry_test.go b/backend/internal/handler/gateway_handler_single_account_retry_test.go new file mode 100644 index 00000000..96aa14c6 --- /dev/null +++ b/backend/internal/handler/gateway_handler_single_account_retry_test.go @@ -0,0 +1,51 @@ +package handler + +import ( + "context" + "testing" + "time" + + 
"github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// sleepAntigravitySingleAccountBackoff 测试 +// --------------------------------------------------------------------------- + +func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) { + ctx := context.Background() + start := time.Now() + ok := sleepAntigravitySingleAccountBackoff(ctx, 1) + elapsed := time.Since(start) + + require.True(t, ok, "should return true when context is not canceled") + // 固定延迟 2s + require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s") + require.Less(t, elapsed, 5*time.Second, "should not wait too long") +} + +func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 立即取消 + + start := time.Now() + ok := sleepAntigravitySingleAccountBackoff(ctx, 1) + elapsed := time.Since(start) + + require.False(t, ok, "should return false when context is canceled") + require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel") +} + +func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) { + // 验证不同 retryCount 都使用固定 2s 延迟 + ctx := context.Background() + + start := time.Now() + ok := sleepAntigravitySingleAccountBackoff(ctx, 5) + elapsed := time.Since(start) + + require.True(t, ok) + // 即使 retryCount=5,延迟仍然是固定的 2s + require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond) + require.Less(t, elapsed, 5*time.Second) +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index d5149f22..3d25505b 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -327,6 +327,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var lastFailoverErr *service.UpstreamFailoverError var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 + // 
单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 + // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 + if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) { + ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + c.Request = c.Request.WithContext(ctx) + } + for { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { @@ -334,6 +341,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) return } + // Antigravity 单账号退避重试:分组内没有其他可用账号时, + // 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 + // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 + if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { + if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { + log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches) + failedAccountIDs = make(map[int64]struct{}) + // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 + ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + c.Request = c.Request.WithContext(ctx) + continue + } + } h.handleGeminiFailoverExhausted(c, lastFailoverErr) return } @@ -534,6 +554,10 @@ func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverE msg = *rule.CustomMessage } + if rule.SkipMonitoring { + c.Set(service.OpsSkipPassthroughKey, true) + } + googleError(c, respCode, msg) return } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 835297b8..c08a8b0e 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go 
@@ -354,6 +354,10 @@ func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverE msg = *rule.CustomMessage } + if rule.SkipMonitoring { + c.Set(service.OpsSkipPassthroughKey, true) + } + h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted) return } diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index 36ffde63..cb62ceae 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -537,6 +537,13 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { // Store request headers/body only when an upstream error occurred to keep overhead minimal. entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c) + // Skip logging if a passthrough rule with skip_monitoring=true matched. + if v, ok := c.Get(service.OpsSkipPassthroughKey); ok { + if skip, _ := v.(bool); skip { + return + } + } + enqueueOpsErrorLog(ops, entry, requestBody) return } @@ -544,6 +551,13 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { body := w.buf.Bytes() parsed := parseOpsErrorResponse(body) + // Skip logging if a passthrough rule with skip_monitoring=true matched. 
+ if v, ok := c.Get(service.OpsSkipPassthroughKey); ok { + if skip, _ := v.(bool); skip { + return + } + } + // Skip logging if the error should be filtered based on settings if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) { return diff --git a/backend/internal/model/error_passthrough_rule.go b/backend/internal/model/error_passthrough_rule.go index d4fc16e3..620736cd 100644 --- a/backend/internal/model/error_passthrough_rule.go +++ b/backend/internal/model/error_passthrough_rule.go @@ -18,6 +18,7 @@ type ErrorPassthroughRule struct { ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用) PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息 CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用) + SkipMonitoring bool `json:"skip_monitoring"` // 是否跳过运维监控记录 Description *string `json:"description"` // 规则描述 CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 8a29cd10..7c127b90 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -27,7 +27,7 @@ type ClaudeMessage struct { // ThinkingConfig Thinking 配置 type ThinkingConfig struct { - Type string `json:"type"` // "enabled" or "disabled" + Type string `json:"type"` // "enabled" / "adaptive" / "disabled" BudgetTokens int `json:"budget_tokens,omitempty"` // thinking budget } diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index a6279b11..ac32fae5 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -115,6 +115,23 @@ type LoadCodeAssistResponse struct { IneligibleTiers []*IneligibleTier `json:"ineligibleTiers,omitempty"` } +// OnboardUserRequest onboardUser 请求 +type OnboardUserRequest struct { + 
TierID string `json:"tierId"` + Metadata struct { + IDEType string `json:"ideType"` + Platform string `json:"platform,omitempty"` + PluginType string `json:"pluginType,omitempty"` + } `json:"metadata"` +} + +// OnboardUserResponse onboardUser 响应 +type OnboardUserResponse struct { + Name string `json:"name,omitempty"` + Done bool `json:"done"` + Response map[string]any `json:"response,omitempty"` +} + // GetTier 获取账户类型 // 优先返回 paidTier(付费订阅级别),否则返回 currentTier func (r *LoadCodeAssistResponse) GetTier() string { @@ -361,6 +378,117 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC return nil, nil, lastErr } +// OnboardUser 触发账号 onboarding,并返回 project_id +// 说明: +// 1) 部分账号 loadCodeAssist 不会立即返回 cloudaicompanionProject; +// 2) 这时需要调用 onboardUser 完成初始化,之后才能拿到 project_id。 +func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) { + tierID = strings.TrimSpace(tierID) + if tierID == "" { + return "", fmt.Errorf("tier_id 为空") + } + + reqBody := OnboardUserRequest{TierID: tierID} + reqBody.Metadata.IDEType = "ANTIGRAVITY" + reqBody.Metadata.Platform = "PLATFORM_UNSPECIFIED" + reqBody.Metadata.PluginType = "GEMINI" + + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("序列化请求失败: %w", err) + } + + availableURLs := BaseURLs + var lastErr error + + for urlIdx, baseURL := range availableURLs { + apiURL := baseURL + "/v1internal:onboardUser" + + for attempt := 1; attempt <= 5; attempt++ { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(bodyBytes)) + if err != nil { + lastErr = fmt.Errorf("创建请求失败: %w", err) + break + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", UserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("onboardUser 请求失败: %w", err) + if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 
{ + log.Printf("[antigravity] onboardUser URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + break + } + return "", lastErr + } + + respBodyBytes, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + return "", fmt.Errorf("读取响应失败: %w", err) + } + + if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + log.Printf("[antigravity] onboardUser URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + break + } + + if resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("onboardUser 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + return "", lastErr + } + + var onboardResp OnboardUserResponse + if err := json.Unmarshal(respBodyBytes, &onboardResp); err != nil { + lastErr = fmt.Errorf("onboardUser 响应解析失败: %w", err) + return "", lastErr + } + + if onboardResp.Done { + if projectID := extractProjectIDFromOnboardResponse(onboardResp.Response); projectID != "" { + DefaultURLAvailability.MarkSuccess(baseURL) + return projectID, nil + } + lastErr = fmt.Errorf("onboardUser 完成但未返回 project_id") + return "", lastErr + } + + // done=false 时等待后重试(与 CLIProxyAPI 行为一致) + select { + case <-time.After(2 * time.Second): + case <-ctx.Done(): + return "", ctx.Err() + } + } + } + + if lastErr != nil { + return "", lastErr + } + return "", fmt.Errorf("onboardUser 未返回 project_id") +} + +func extractProjectIDFromOnboardResponse(resp map[string]any) string { + if len(resp) == 0 { + return "" + } + + if v, ok := resp["cloudaicompanionProject"]; ok { + switch project := v.(type) { + case string: + return strings.TrimSpace(project) + case map[string]any: + if id, ok := project["id"].(string); ok { + return strings.TrimSpace(id) + } + } + } + + return "" +} + // ModelQuotaInfo 模型配额信息 type ModelQuotaInfo struct { RemainingFraction float64 `json:"remainingFraction"` diff --git a/backend/internal/pkg/antigravity/client_test.go b/backend/internal/pkg/antigravity/client_test.go new file mode 100644 
index 00000000..ac30093d --- /dev/null +++ b/backend/internal/pkg/antigravity/client_test.go @@ -0,0 +1,76 @@ +package antigravity + +import ( + "testing" +) + +func TestExtractProjectIDFromOnboardResponse(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + resp map[string]any + want string + }{ + { + name: "nil response", + resp: nil, + want: "", + }, + { + name: "empty response", + resp: map[string]any{}, + want: "", + }, + { + name: "project as string", + resp: map[string]any{ + "cloudaicompanionProject": "my-project-123", + }, + want: "my-project-123", + }, + { + name: "project as string with spaces", + resp: map[string]any{ + "cloudaicompanionProject": " my-project-123 ", + }, + want: "my-project-123", + }, + { + name: "project as map with id", + resp: map[string]any{ + "cloudaicompanionProject": map[string]any{ + "id": "proj-from-map", + }, + }, + want: "proj-from-map", + }, + { + name: "project as map without id", + resp: map[string]any{ + "cloudaicompanionProject": map[string]any{ + "name": "some-name", + }, + }, + want: "", + }, + { + name: "missing cloudaicompanionProject key", + resp: map[string]any{ + "otherField": "value", + }, + want: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := extractProjectIDFromOnboardResponse(tc.resp) + if got != tc.want { + t.Fatalf("extractProjectIDFromOnboardResponse() = %q, want %q", got, tc.want) + } + }) + } +} diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index c1cc998c..32495827 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -155,6 +155,7 @@ type GeminiUsageMetadata struct { CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"` CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"` TotalTokenCount int `json:"totalTokenCount,omitempty"` + ThoughtsTokenCount int 
`json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费) } // GeminiGroundingMetadata Gemini grounding 元数据(Web Search) diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 65f45cfc..3ba04b95 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -64,6 +64,10 @@ const MaxTokensBudgetPadding = 1000 // Gemini 2.5 Flash thinking budget 上限 const Gemini25FlashThinkingBudgetLimit = 24576 +// 对于 Antigravity 的 Claude(budget-only)模型,该语义最终等价为 thinkingBudget=24576。 +// 这里复用相同数值以保持行为一致。 +const ClaudeAdaptiveHighThinkingBudgetTokens = Gemini25FlashThinkingBudgetLimit + // ensureMaxTokensGreaterThanBudget 确保 max_tokens > budget_tokens // Claude API 要求启用 thinking 时,max_tokens 必须大于 thinking.budget_tokens // 返回调整后的 maxTokens 和是否进行了调整 @@ -96,7 +100,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map } // 检测是否启用 thinking - isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" + isThinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive") // 只有 Gemini 模型支持 dummy thought workaround // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures @@ -198,8 +202,7 @@ type modelInfo struct { // modelInfoMap 模型前缀 → 模型信息映射 // 只有在此映射表中的模型才会注入身份提示词 -// 注意:当前 claude-opus-4-6 会被映射到 claude-opus-4-5-thinking, -// 但保留此条目以便后续 Antigravity 上游支持 4.6 时快速切换 +// 注意:模型映射逻辑在网关层完成;这里仅用于按模型前缀判断是否注入身份提示词。 var modelInfoMap = map[string]modelInfo{ "claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"}, "claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"}, @@ -271,6 +274,21 @@ func filterOpenCodePrompt(text string) string { return "" } +// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表 +var systemBlockFilterPrefixes = []string{ + 
"x-anthropic-billing-header", +} + +// filterSystemBlockByPrefix 如果文本匹配过滤前缀,返回空字符串 +func filterSystemBlockByPrefix(text string) string { + for _, prefix := range systemBlockFilterPrefixes { + if strings.HasPrefix(text, prefix) { + return "" + } + } + return text +} + // buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致) func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent { var parts []GeminiPart @@ -287,8 +305,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans if strings.Contains(sysStr, "You are Antigravity") { userHasAntigravityIdentity = true } - // 过滤 OpenCode 默认提示词 - filtered := filterOpenCodePrompt(sysStr) + // 过滤 OpenCode 默认提示词和黑名单前缀 + filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(sysStr)) if filtered != "" { userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) } @@ -302,8 +320,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans if strings.Contains(block.Text, "You are Antigravity") { userHasAntigravityIdentity = true } - // 过滤 OpenCode 默认提示词 - filtered := filterOpenCodePrompt(block.Text) + // 过滤 OpenCode 默认提示词和黑名单前缀 + filtered := filterSystemBlockByPrefix(filterOpenCodePrompt(block.Text)) if filtered != "" { userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) } @@ -578,6 +596,10 @@ func maxOutputTokensLimit(model string) int { return maxOutputTokensUpperBound } +func isAntigravityOpus46Model(model string) bool { + return strings.HasPrefix(strings.ToLower(model), "claude-opus-4-6") +} + func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { maxLimit := maxOutputTokensLimit(req.Model) config := &GeminiGenerationConfig{ @@ -591,25 +613,36 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { } // Thinking 配置 - if req.Thinking != nil && req.Thinking.Type == "enabled" { + if req.Thinking != nil && 
(req.Thinking.Type == "enabled" || req.Thinking.Type == "adaptive") { config.ThinkingConfig = &GeminiThinkingConfig{ IncludeThoughts: true, } + + // - thinking.type=enabled:budget_tokens>0 用显式预算 + // - thinking.type=adaptive:仅在 Antigravity 的 Opus 4.6 上覆写为 (24576) + budget := -1 if req.Thinking.BudgetTokens > 0 { - budget := req.Thinking.BudgetTokens + budget = req.Thinking.BudgetTokens + } + if req.Thinking.Type == "adaptive" && isAntigravityOpus46Model(req.Model) { + budget = ClaudeAdaptiveHighThinkingBudgetTokens + } + + // 正预算需要做上限与 max_tokens 约束;动态预算(-1)直接透传给上游。 + if budget > 0 { // gemini-2.5-flash 上限 if strings.Contains(req.Model, "gemini-2.5-flash") && budget > Gemini25FlashThinkingBudgetLimit { budget = Gemini25FlashThinkingBudgetLimit } - config.ThinkingConfig.ThinkingBudget = budget - // 自动修正:max_tokens 必须大于 budget_tokens + // 自动修正:max_tokens 必须大于 budget_tokens(Claude 上游要求) if adjusted, ok := ensureMaxTokensGreaterThanBudget(config.MaxOutputTokens, budget); ok { log.Printf("[Antigravity] Auto-adjusted max_tokens from %d to %d (must be > budget_tokens=%d)", config.MaxOutputTokens, adjusted, budget) config.MaxOutputTokens = adjusted } } + config.ThinkingConfig.ThinkingBudget = budget } if config.MaxOutputTokens > maxLimit { diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index f938b47f..f267e0e1 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -259,3 +259,93 @@ func TestBuildTools_CustomTypeTools(t *testing.T) { }) } } + +func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) { + tests := []struct { + name string + model string + thinking *ThinkingConfig + wantBudget int + wantPresent bool + }{ + { + name: "enabled without budget defaults to dynamic (-1)", + model: "claude-opus-4-6-thinking", + thinking: &ThinkingConfig{Type: "enabled"}, + wantBudget: -1, + 
wantPresent: true, + }, + { + name: "enabled with budget uses the provided value", + model: "claude-opus-4-6-thinking", + thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1024}, + wantBudget: 1024, + wantPresent: true, + }, + { + name: "enabled with -1 budget uses dynamic (-1)", + model: "claude-opus-4-6-thinking", + thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: -1}, + wantBudget: -1, + wantPresent: true, + }, + { + name: "adaptive on opus4.6 maps to high budget (24576)", + model: "claude-opus-4-6-thinking", + thinking: &ThinkingConfig{Type: "adaptive", BudgetTokens: 20000}, + wantBudget: ClaudeAdaptiveHighThinkingBudgetTokens, + wantPresent: true, + }, + { + name: "adaptive on non-opus model keeps default dynamic (-1)", + model: "claude-sonnet-4-5-thinking", + thinking: &ThinkingConfig{Type: "adaptive"}, + wantBudget: -1, + wantPresent: true, + }, + { + name: "disabled does not emit thinkingConfig", + model: "claude-opus-4-6-thinking", + thinking: &ThinkingConfig{Type: "disabled", BudgetTokens: 1024}, + wantBudget: 0, + wantPresent: false, + }, + { + name: "nil thinking does not emit thinkingConfig", + model: "claude-opus-4-6-thinking", + thinking: nil, + wantBudget: 0, + wantPresent: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &ClaudeRequest{ + Model: tt.model, + Thinking: tt.thinking, + } + cfg := buildGenerationConfig(req) + if cfg == nil { + t.Fatalf("expected non-nil generationConfig") + } + + if tt.wantPresent { + if cfg.ThinkingConfig == nil { + t.Fatalf("expected thinkingConfig to be present") + } + if !cfg.ThinkingConfig.IncludeThoughts { + t.Fatalf("expected includeThoughts=true") + } + if cfg.ThinkingConfig.ThinkingBudget != tt.wantBudget { + t.Fatalf("expected thinkingBudget=%d, got %d", tt.wantBudget, cfg.ThinkingConfig.ThinkingBudget) + } + return + } + + if cfg.ThinkingConfig != nil { + t.Fatalf("expected thinkingConfig to be nil, got %+v", cfg.ThinkingConfig) + } + }) + } +} diff 
--git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index eb16f09d..463033f1 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -279,7 +279,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon if geminiResp.UsageMetadata != nil { cached := geminiResp.UsageMetadata.CachedContentTokenCount usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached - usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount usage.CacheReadInputTokens = cached } diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index b384658a..677435ad 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -85,7 +85,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { if geminiResp.UsageMetadata != nil { cached := geminiResp.UsageMetadata.CachedContentTokenCount p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached - p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount p.cacheReadTokens = cached } @@ -146,7 +146,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte if v1Resp.Response.UsageMetadata != nil { cached := v1Resp.Response.UsageMetadata.CachedContentTokenCount usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached - usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount 
usage.CacheReadInputTokens = cached } diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index eecee11e..423ad925 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -10,6 +10,7 @@ const ( BetaInterleavedThinking = "interleaved-thinking-2025-05-14" BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14" BetaTokenCounting = "token-counting-2024-11-01" + BetaContext1M = "context-1m-2025-08-07" ) // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header @@ -77,6 +78,12 @@ var DefaultModels = []Model{ DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-06T00:00:00Z", }, + { + ID: "claude-sonnet-4-6", + Type: "model", + DisplayName: "Claude Sonnet 4.6", + CreatedAt: "2026-02-18T00:00:00Z", + }, { ID: "claude-sonnet-4-5-20250929", Type: "model", diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 9bf563e7..0c4d82f7 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -28,4 +28,8 @@ const ( // IsMaxTokensOneHaikuRequest 标识当前请求是否为 max_tokens=1 + haiku 模型的探测请求 // 用于 ClaudeCodeOnly 验证绕过(绕过 system prompt 检查,但仍需验证 User-Agent) IsMaxTokensOneHaikuRequest Key = "ctx_is_max_tokens_one_haiku" + + // SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。 + // 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。 + SingleAccountRetry Key = "ctx_single_account_retry" ) diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index d73e0521..e3e70213 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -435,10 +435,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { } func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { - return r.ListWithFilters(ctx, params, "", "", "", "") + return 
r.ListWithFilters(ctx, params, "", "", "", "", 0) } -func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { +func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { q := r.client.Account.Query() if platform != "" { @@ -448,11 +448,19 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati q = q.Where(dbaccount.TypeEQ(accountType)) } if status != "" { - q = q.Where(dbaccount.StatusEQ(status)) + switch status { + case "rate_limited": + q = q.Where(dbaccount.RateLimitResetAtGT(time.Now())) + default: + q = q.Where(dbaccount.StatusEQ(status)) + } } if search != "" { q = q.Where(dbaccount.NameContainsFold(search)) } + if groupID > 0 { + q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID))) + } total, err := q.Count(ctx) if err != nil { diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index a054b6d6..4f9d0152 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { tt.setup(client) - accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search) + accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0) s.Require().NoError(err) s.Require().Len(accounts, tt.wantCount) if tt.validate != nil { @@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { s.Require().Len(got.Groups, 1, 
"expected Groups to be populated") s.Require().Equal(group.ID, got.Groups[0].ID) - accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc") + accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0) s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total) s.Require().Len(accounts, 1) diff --git a/backend/internal/repository/error_passthrough_repo.go b/backend/internal/repository/error_passthrough_repo.go index a58ab60f..ae989359 100644 --- a/backend/internal/repository/error_passthrough_repo.go +++ b/backend/internal/repository/error_passthrough_repo.go @@ -54,7 +54,8 @@ func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.Err SetPriority(rule.Priority). SetMatchMode(rule.MatchMode). SetPassthroughCode(rule.PassthroughCode). - SetPassthroughBody(rule.PassthroughBody) + SetPassthroughBody(rule.PassthroughBody). + SetSkipMonitoring(rule.SkipMonitoring) if len(rule.ErrorCodes) > 0 { builder.SetErrorCodes(rule.ErrorCodes) @@ -90,7 +91,8 @@ func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.Err SetPriority(rule.Priority). SetMatchMode(rule.MatchMode). SetPassthroughCode(rule.PassthroughCode). - SetPassthroughBody(rule.PassthroughBody) + SetPassthroughBody(rule.PassthroughBody). 
+ SetSkipMonitoring(rule.SkipMonitoring) // 处理可选字段 if len(rule.ErrorCodes) > 0 { @@ -149,6 +151,7 @@ func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model Platforms: e.Platforms, PassthroughCode: e.PassthroughCode, PassthroughBody: e.PassthroughBody, + SkipMonitoring: e.SkipMonitoring, CreatedAt: e.CreatedAt, UpdatedAt: e.UpdatedAt, } diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go index a3a048c3..934a3095 100644 --- a/backend/internal/repository/redeem_code_repo.go +++ b/backend/internal/repository/redeem_code_repo.go @@ -6,6 +6,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -106,7 +107,12 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin q = q.Where(redeemcode.StatusEQ(status)) } if search != "" { - q = q.Where(redeemcode.CodeContainsFold(search)) + q = q.Where( + redeemcode.Or( + redeemcode.CodeContainsFold(search), + redeemcode.HasUserWith(user.EmailContainsFold(search)), + ), + ) } total, err := q.Count(ctx) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 2db1764f..c3e5ae85 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, 
image_size, reasoning_effort, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, cache_ttl_overridden, created_at" type usageLogRepository struct { client *dbent.Client @@ -115,6 +115,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) image_count, image_size, reasoning_effort, + cache_ttl_overridden, created_at ) VALUES ( $1, $2, $3, $4, $5, @@ -122,7 +123,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31 + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -173,6 +174,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) log.ImageCount, imageSize, reasoningEffort, + log.CacheTTLOverridden, createdAt, } if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { @@ -2195,6 +2197,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e imageCount int imageSize sql.NullString reasoningEffort sql.NullString + cacheTTLOverridden bool createdAt time.Time ) @@ -2230,6 +2233,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &imageCount, &imageSize, &reasoningEffort, + &cacheTTLOverridden, &createdAt, ); err != nil { return nil, err @@ -2258,6 +2262,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e 
BillingType: int8(billingType), Stream: stream, ImageCount: imageCount, + CacheTTLOverridden: cacheTTLOverridden, CreatedAt: createdAt, } diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 654bd16b..17674291 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -10,6 +10,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/apikey" dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -191,6 +192,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. dbuser.EmailContainsFold(filters.Search), dbuser.UsernameContainsFold(filters.Search), dbuser.NotesContainsFold(filters.Search), + dbuser.HasAPIKeysWith(apikey.KeyContainsFold(filters.Search)), ), ) } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 68f001a2..6b607083 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -401,6 +401,7 @@ func TestAPIContracts(t *testing.T) { "first_token_ms": 50, "image_count": 0, "image_size": null, + "cache_ttl_overridden": false, "created_at": "2025-01-02T03:04:05Z", "user_agent": null } @@ -936,7 +937,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination return nil, nil, errors.New("not implemented") } -func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { +func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") 
} diff --git a/backend/internal/server/middleware/cors.go b/backend/internal/server/middleware/cors.go index 7d82f183..f1dd51af 100644 --- a/backend/internal/server/middleware/cors.go +++ b/backend/internal/server/middleware/cors.go @@ -70,7 +70,15 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { } } - c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") + allowHeaders := []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key"} + + // openai node sdk + openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout"} + for _, prop := range openAIProperties { + allowHeaders = append(allowHeaders, "x-stainless-"+prop) + } + + c.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(allowHeaders, ", ")) c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") // 处理预检请求 diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 39c5d2fc..4509b4bc 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -281,6 +281,7 @@ func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { antigravity.POST("/oauth/auth-url", h.Admin.AntigravityOAuth.GenerateAuthURL) antigravity.POST("/oauth/exchange-code", h.Admin.AntigravityOAuth.ExchangeCode) + antigravity.POST("/oauth/refresh-token", h.Admin.AntigravityOAuth.RefreshToken) } } diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 138d5bcb..fa3ce738 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -752,6 +752,38 @@ func 
(a *Account) IsSessionIDMaskingEnabled() bool { return false } +// IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换 +// 仅适用于 Anthropic OAuth/SetupToken 类型账号 +// 启用后将所有 cache creation tokens 归入指定的 TTL 类型(5m 或 1h) +func (a *Account) IsCacheTTLOverrideEnabled() bool { + if !a.IsAnthropicOAuthOrSetupToken() { + return false + } + if a.Extra == nil { + return false + } + if v, ok := a.Extra["cache_ttl_override_enabled"]; ok { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + +// GetCacheTTLOverrideTarget 获取缓存 TTL 强制替换的目标类型 +// 返回 "5m" 或 "1h",默认 "5m" +func (a *Account) GetCacheTTLOverrideTarget() string { + if a.Extra == nil { + return "5m" + } + if v, ok := a.Extra["cache_ttl_override_target"]; ok { + if target, ok := v.(string); ok && (target == "5m" || target == "1h") { + return target + } + } + return "5m" +} + // GetWindowCostLimit 获取 5h 窗口费用阈值(美元) // 返回 0 表示未启用 func (a *Account) GetWindowCostLimit() float64 { diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 6c0cca31..f192fba4 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -32,7 +32,7 @@ type AccountRepository interface { Delete(ctx context.Context, id int64) error List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListActive(ctx context.Context) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error) diff --git 
a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index 25bd0576..a420d46b 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -75,7 +75,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination panic("unexpected List call") } -func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { panic("unexpected ListWithFilters call") } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 06354e1e..1f6e91e5 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -39,7 +39,7 @@ type AdminService interface { UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error // Account management - ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) + ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) GetAccount(ctx context.Context, id int64) (*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) @@ -1021,9 +1021,9 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates [] } // Account management implementations -func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, 
int64, error) { +func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) + accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID) if err != nil { return nil, 0, err } diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go index d661b710..ff58fd01 100644 --- a/backend/internal/service/admin_service_search_test.go +++ b/backend/internal/service/admin_service_search_test.go @@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct { listWithFiltersErr error } -func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { s.listWithFiltersCalls++ s.listWithFiltersParams = params s.listWithFiltersPlatform = platform @@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc") + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0) require.NoError(t, err) require.Equal(t, int64(10), total) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) diff --git a/backend/internal/service/antigravity_gateway_service.go 
b/backend/internal/service/antigravity_gateway_service.go index 014b3c86..1d87f4b1 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -16,10 +16,12 @@ import ( "os" "strconv" "strings" + "sync" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/gin-gonic/gin" "github.com/google/uuid" ) @@ -39,6 +41,12 @@ const ( antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待) antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用) + // MODEL_CAPACITY_EXHAUSTED 专用重试参数 + // 模型容量不足时,所有账号共享同一容量池,切换账号无意义 + // 使用固定 1s 间隔重试,最多重试 60 次 + antigravityModelCapacityRetryMaxAttempts = 60 + antigravityModelCapacityRetryWait = 1 * time.Second + // Google RPC 状态和类型常量 googleRPCStatusResourceExhausted = "RESOURCE_EXHAUSTED" googleRPCStatusUnavailable = "UNAVAILABLE" @@ -46,6 +54,22 @@ const ( googleRPCTypeErrorInfo = "type.googleapis.com/google.rpc.ErrorInfo" googleRPCReasonModelCapacityExhausted = "MODEL_CAPACITY_EXHAUSTED" googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED" + + // 单账号 503 退避重试:Service 层原地重试的最大次数 + // 在 handleSmartRetry 中,对于 shouldRateLimitModel(长延迟 ≥ 7s)的情况, + // 多账号模式下会设限流+切换账号;但单账号模式下改为原地等待+重试。 + antigravitySingleAccountSmartRetryMaxAttempts = 3 + + // 单账号 503 退避重试:原地重试时单次最大等待时间 + // 防止上游返回过长的 retryDelay 导致请求卡住太久 + antigravitySingleAccountSmartRetryMaxWait = 15 * time.Second + + // 单账号 503 退避重试:原地重试的总累计等待时间上限 + // 超过此上限将不再重试,直接返回 503 + antigravitySingleAccountSmartRetryTotalMaxWait = 30 * time.Second + + // MODEL_CAPACITY_EXHAUSTED 全局去重:重试全部失败后的 cooldown 时间 + antigravityModelCapacityCooldown = 10 * time.Second ) // antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写) @@ -54,8 +78,15 @@ var antigravityPassthroughErrorMessages = []string{ "prompt is too long", } +// MODEL_CAPACITY_EXHAUSTED 全局去重:避免多个并发请求同时对同一模型进行容量耗尽重试 +var ( + modelCapacityExhaustedMu 
sync.RWMutex + modelCapacityExhaustedUntil = make(map[string]time.Time) // modelName -> cooldown until +) + const ( antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL" + antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL" antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS" ) @@ -117,6 +148,20 @@ type antigravityRetryLoopResult struct { resp *http.Response } +// resolveAntigravityForwardBaseURL 解析转发用 base URL。 +// 默认使用 daily(ForwardBaseURLs 的首个地址);当环境变量为 prod 时使用第二个地址。 +func resolveAntigravityForwardBaseURL() string { + baseURLs := antigravity.ForwardBaseURLs() + if len(baseURLs) == 0 { + return "" + } + mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv))) + if mode == "prod" && len(baseURLs) > 1 { + return baseURLs[1] + } + return baseURLs[0] +} + // smartRetryAction 智能重试的处理结果 type smartRetryAction int @@ -144,10 +189,17 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam } // 判断是否触发智能重试 - shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName := shouldTriggerAntigravitySmartRetry(p.account, respBody) + shouldSmartRetry, shouldRateLimitModel, waitDuration, modelName, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(p.account, respBody) // 情况1: retryDelay >= 阈值,限流模型并切换账号 if shouldRateLimitModel { + // 单账号 503 退避重试模式:不设限流、不切换账号,改为原地等待+重试 + // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 + // 多账号场景下切换账号是最优选择,但单账号场景下设限流毫无意义(只会导致双重等待)。 + if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) { + return s.handleSingleAccountRetryInPlace(p, resp, respBody, baseURL, waitDuration, modelName) + } + rateLimitDuration := waitDuration if rateLimitDuration <= 0 { rateLimitDuration = antigravityDefaultRateLimitDuration @@ -174,20 +226,48 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam } } - // 情况2: retryDelay < 阈值,智能重试(最多 antigravitySmartRetryMaxAttempts 次) 
+ // 情况2: retryDelay < 阈值(或 MODEL_CAPACITY_EXHAUSTED),智能重试 if shouldSmartRetry { var lastRetryResp *http.Response var lastRetryBody []byte - for attempt := 1; attempt <= antigravitySmartRetryMaxAttempts; attempt++ { - log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d", - p.prefix, resp.StatusCode, attempt, antigravitySmartRetryMaxAttempts, waitDuration, modelName, p.account.ID) + // MODEL_CAPACITY_EXHAUSTED 使用独立的重试参数(60 次,固定 1s 间隔) + maxAttempts := antigravitySmartRetryMaxAttempts + if isModelCapacityExhausted { + maxAttempts = antigravityModelCapacityRetryMaxAttempts + waitDuration = antigravityModelCapacityRetryWait + // 全局去重:如果其他 goroutine 已在重试同一模型且尚在 cooldown 中,直接返回 503 + if modelName != "" { + modelCapacityExhaustedMu.RLock() + cooldownUntil, exists := modelCapacityExhaustedUntil[modelName] + modelCapacityExhaustedMu.RUnlock() + if exists && time.Now().Before(cooldownUntil) { + log.Printf("%s status=%d model_capacity_exhausted_dedup model=%s account=%d cooldown_until=%v (skip retry)", + p.prefix, resp.StatusCode, modelName, p.account.ID, cooldownUntil.Format("15:04:05")) + return &smartRetryResult{ + action: smartRetryActionBreakWithResp, + resp: &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + }, + } + } + } + } + + for attempt := 1; attempt <= maxAttempts; attempt++ { + log.Printf("%s status=%d oauth_smart_retry attempt=%d/%d delay=%v model=%s account=%d", + p.prefix, resp.StatusCode, attempt, maxAttempts, waitDuration, modelName, p.account.ID) + + timer := time.NewTimer(waitDuration) select { case <-p.ctx.Done(): + timer.Stop() log.Printf("%s status=context_canceled_during_smart_retry", p.prefix) return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()} - case <-time.After(waitDuration): + case <-timer.C: } // 智能重试:创建新请求 @@ -207,13 +287,19 @@ func (s *AntigravityGatewayService) handleSmartRetry(p 
antigravityRetryLoopParam retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency) if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable { - log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, antigravitySmartRetryMaxAttempts) + log.Printf("%s status=%d smart_retry_success attempt=%d/%d", p.prefix, retryResp.StatusCode, attempt, maxAttempts) + // 重试成功,清除 MODEL_CAPACITY_EXHAUSTED cooldown + if isModelCapacityExhausted && modelName != "" { + modelCapacityExhaustedMu.Lock() + delete(modelCapacityExhaustedUntil, modelName) + modelCapacityExhaustedMu.Unlock() + } return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp} } // 网络错误时,继续重试 if retryErr != nil || retryResp == nil { - log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySmartRetryMaxAttempts, retryErr) + log.Printf("%s status=smart_retry_network_error attempt=%d/%d error=%v", p.prefix, attempt, maxAttempts, retryErr) continue } @@ -223,20 +309,20 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam } lastRetryResp = retryResp if retryResp != nil { - lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10)) _ = retryResp.Body.Close() } - // 解析新的重试信息,用于下次重试的等待时间 - if attempt < antigravitySmartRetryMaxAttempts && lastRetryBody != nil { - newShouldRetry, _, newWaitDuration, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody) + // 解析新的重试信息,用于下次重试的等待时间(MODEL_CAPACITY_EXHAUSTED 使用固定循环,跳过) + if !isModelCapacityExhausted && attempt < maxAttempts && lastRetryBody != nil { + newShouldRetry, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody) if newShouldRetry && newWaitDuration > 0 { waitDuration = 
newWaitDuration } } } - // 所有重试都失败,限流当前模型并切换账号 + // 所有重试都失败 rateLimitDuration := waitDuration if rateLimitDuration <= 0 { rateLimitDuration = antigravityDefaultRateLimitDuration @@ -245,8 +331,45 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam if retryBody == nil { retryBody = respBody } + + // MODEL_CAPACITY_EXHAUSTED:模型容量不足,切换账号无意义 + // 直接返回上游错误响应,不设置模型限流,不切换账号 + if isModelCapacityExhausted { + // 设置 cooldown,让后续请求快速失败,避免重复重试 + if modelName != "" { + modelCapacityExhaustedMu.Lock() + modelCapacityExhaustedUntil[modelName] = time.Now().Add(antigravityModelCapacityCooldown) + modelCapacityExhaustedMu.Unlock() + } + log.Printf("%s status=%d smart_retry_exhausted_model_capacity attempts=%d model=%s account=%d body=%s (model capacity exhausted, not switching account)", + p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200)) + return &smartRetryResult{ + action: smartRetryActionBreakWithResp, + resp: &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(retryBody)), + }, + } + } + + // 单账号 503 退避重试模式:智能重试耗尽后不设限流、不切换账号, + // 直接返回 503 让 Handler 层的单账号退避循环做最终处理。 + if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) { + log.Printf("%s status=%d smart_retry_exhausted_single_account attempts=%d model=%s account=%d body=%s (return 503 directly)", + p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200)) + return &smartRetryResult{ + action: smartRetryActionBreakWithResp, + resp: &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(retryBody)), + }, + } + } + log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)", - p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, 
rateLimitDuration, truncateForLog(retryBody, 200)) + p.prefix, resp.StatusCode, maxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200)) resetAt := time.Now().Add(rateLimitDuration) if p.accountRepo != nil && modelName != "" { @@ -279,25 +402,163 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam return &smartRetryResult{action: smartRetryActionContinue} } +// handleSingleAccountRetryInPlace 单账号 503 退避重试的原地重试逻辑。 +// +// 在多账号场景下,收到 503 + 长 retryDelay(≥ 7s)时会设置模型限流 + 切换账号; +// 但在单账号场景下,设限流毫无意义(因为切换回来的还是同一个账号,还要等限流过期)。 +// 此方法改为在 Service 层原地等待 + 重试,避免双重等待问题: +// +// 旧流程:Service 设限流 → Handler 退避等待 → Service 等限流过期 → 再请求(总耗时 = 退避 + 限流) +// 新流程:Service 直接等 retryDelay → 重试 → 成功/再等 → 重试...(总耗时 ≈ 实际 retryDelay × 重试次数) +// +// 约束: +// - 单次等待不超过 antigravitySingleAccountSmartRetryMaxWait +// - 总累计等待不超过 antigravitySingleAccountSmartRetryTotalMaxWait +// - 最多重试 antigravitySingleAccountSmartRetryMaxAttempts 次 +func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( + p antigravityRetryLoopParams, + resp *http.Response, + respBody []byte, + baseURL string, + waitDuration time.Duration, + modelName string, +) *smartRetryResult { + // 限制单次等待时间 + if waitDuration > antigravitySingleAccountSmartRetryMaxWait { + waitDuration = antigravitySingleAccountSmartRetryMaxWait + } + if waitDuration < antigravitySmartRetryMinWait { + waitDuration = antigravitySmartRetryMinWait + } + + log.Printf("%s status=%d single_account_503_retry_in_place model=%s account=%d upstream_retry_delay=%v (retrying in-place instead of rate-limiting)", + p.prefix, resp.StatusCode, modelName, p.account.ID, waitDuration) + + var lastRetryResp *http.Response + var lastRetryBody []byte + totalWaited := time.Duration(0) + + for attempt := 1; attempt <= antigravitySingleAccountSmartRetryMaxAttempts; attempt++ { + // 检查累计等待是否超限 + if totalWaited+waitDuration > antigravitySingleAccountSmartRetryTotalMaxWait { + remaining := 
antigravitySingleAccountSmartRetryTotalMaxWait - totalWaited + if remaining <= 0 { + log.Printf("%s single_account_503_retry: total_wait_exceeded total=%v max=%v, giving up", + p.prefix, totalWaited, antigravitySingleAccountSmartRetryTotalMaxWait) + break + } + waitDuration = remaining + } + + log.Printf("%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d", + p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID) + + timer := time.NewTimer(waitDuration) + select { + case <-p.ctx.Done(): + timer.Stop() + log.Printf("%s status=context_canceled_during_single_account_retry", p.prefix) + return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()} + case <-timer.C: + } + totalWaited += waitDuration + + // 创建新请求 + retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body) + if err != nil { + log.Printf("%s single_account_503_retry: request_build_failed error=%v", p.prefix, err) + break + } + + retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency) + if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable { + log.Printf("%s status=%d single_account_503_retry_success attempt=%d/%d total_waited=%v", + p.prefix, retryResp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited) + // 关闭之前的响应 + if lastRetryResp != nil { + _ = lastRetryResp.Body.Close() + } + return &smartRetryResult{action: smartRetryActionBreakWithResp, resp: retryResp} + } + + // 网络错误时继续重试 + if retryErr != nil || retryResp == nil { + log.Printf("%s single_account_503_retry: network_error attempt=%d/%d error=%v", + p.prefix, attempt, antigravitySingleAccountSmartRetryMaxAttempts, retryErr) + continue + } + + // 关闭之前的响应 + if lastRetryResp != nil { + _ = 
lastRetryResp.Body.Close() + } + lastRetryResp = retryResp + lastRetryBody, _ = io.ReadAll(io.LimitReader(retryResp.Body, 8<<10)) + _ = retryResp.Body.Close() + + // 解析新的重试信息,更新下次等待时间 + if attempt < antigravitySingleAccountSmartRetryMaxAttempts && lastRetryBody != nil { + _, _, newWaitDuration, _, _ := shouldTriggerAntigravitySmartRetry(p.account, lastRetryBody) + if newWaitDuration > 0 { + waitDuration = newWaitDuration + if waitDuration > antigravitySingleAccountSmartRetryMaxWait { + waitDuration = antigravitySingleAccountSmartRetryMaxWait + } + if waitDuration < antigravitySmartRetryMinWait { + waitDuration = antigravitySmartRetryMinWait + } + } + } + } + + // 所有重试都失败,不设限流,直接返回 503 + // Handler 层的单账号退避循环会做最终处理 + retryBody := lastRetryBody + if retryBody == nil { + retryBody = respBody + } + log.Printf("%s status=%d single_account_503_retry_exhausted attempts=%d total_waited=%v model=%s account=%d body=%s (return 503 directly)", + p.prefix, resp.StatusCode, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited, modelName, p.account.ID, truncateForLog(retryBody, 200)) + + return &smartRetryResult{ + action: smartRetryActionBreakWithResp, + resp: &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(retryBody)), + }, + } +} + // antigravityRetryLoop 执行带 URL fallback 的重试循环 func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { // 预检查:如果账号已限流,直接返回切换信号 if p.requestedModel != "" { if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 { - log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d", - p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) - return nil, &AntigravityAccountSwitchError{ - OriginalAccountID: p.account.ID, - RateLimitedModel: p.requestedModel, - IsStickySession: p.isStickySession, + // 单账号 503 退避重试模式:跳过限流预检查,直接发请求。 + 
// 首次请求设的限流是为了多账号调度器跳过该账号,在单账号模式下无意义。 + // 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace + // 会在 Service 层原地等待+重试,不需要在预检查这里等。 + if isSingleAccountRetry(p.ctx) { + log.Printf("%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)", + p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) + } else { + log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d", + p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) + return nil, &AntigravityAccountSwitchError{ + OriginalAccountID: p.account.ID, + RateLimitedModel: p.requestedModel, + IsStickySession: p.isStickySession, + } } } } - availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() - if len(availableURLs) == 0 { - availableURLs = antigravity.BaseURLs + baseURL := resolveAntigravityForwardBaseURL() + if baseURL == "" { + return nil, errors.New("no antigravity forward base url configured") } + availableURLs := []string{baseURL} var resp *http.Response var usedBaseURL string @@ -371,12 +632,12 @@ urlFallbackLoop: _ = resp.Body.Close() // ★ 统一入口:自定义错误码 + 临时不可调度 - if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled { + if handled, outStatus, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled { if policyErr != nil { return nil, policyErr } resp = &http.Response{ - StatusCode: resp.StatusCode, + StatusCode: outStatus, Header: resp.Header.Clone(), Body: io.NopCloser(bytes.NewReader(respBody)), } @@ -610,21 +871,22 @@ func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, accoun return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body) } -// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环 -func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) { +// 
applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环及应返回的状态码。 +// ErrorPolicySkipped 时 outStatus 为 500(前端约定:未命中的错误返回 500)。 +func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, outStatus int, retErr error) { switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) { case ErrorPolicySkipped: - return true, nil + return true, http.StatusInternalServerError, nil case ErrorPolicyMatched: _ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) - return true, nil + return true, statusCode, nil case ErrorPolicyTempUnscheduled: slog.Info("temp_unschedulable_matched", "prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID) - return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession} + return true, statusCode, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession} } - return false, nil + return false, statusCode, nil } // mapAntigravityModel 获取映射后的模型名 @@ -734,11 +996,11 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account proxyURL = account.Proxy.URL() } - // URL fallback 循环 - availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() - if len(availableURLs) == 0 { - availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 + baseURL := resolveAntigravityForwardBaseURL() + if baseURL == "" { + return nil, errors.New("no antigravity forward base url configured") } + availableURLs := []string{baseURL} var lastErr error for urlIdx, baseURL := range availableURLs { @@ -1047,7 +1309,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) } // 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 
claude-sonnet-4-5,自动改为 thinking 版本 - thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" + thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive") mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) // 获取 access_token @@ -1203,7 +1465,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, break } - retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 8<<10)) _ = retryResp.Body.Close() if retryResp.StatusCode == http.StatusTooManyRequests { retryBaseURL := "" @@ -1284,6 +1546,27 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession) + // 精确匹配服务端配置类 400 错误,触发同账号重试 + failover + if resp.StatusCode == http.StatusBadRequest { + msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody))) + if isGoogleProjectConfigError(msg) { + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractAntigravityErrorMessage(respBody))) + upstreamDetail := s.getUpstreamErrorDetail(respBody) + log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true} + } + } + if s.shouldFailoverUpstreamError(resp.StatusCode) { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = 
sanitizeUpstreamErrorMessage(upstreamMsg) @@ -1824,6 +2107,22 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // Always record upstream context for Ops error logs, even when we will failover. setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + // 精确匹配服务端配置类 400 错误,触发同账号重试 + failover + if resp.StatusCode == http.StatusBadRequest && isGoogleProjectConfigError(strings.ToLower(upstreamMsg)) { + log.Printf("%s status=400 google_config_error failover=true upstream_message=%q account=%d", prefix, upstreamMsg, account.ID) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps, RetryableOnSameAccount: true} + } + if s.shouldFailoverUpstreamError(resp.StatusCode) { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, @@ -1919,6 +2218,44 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) } } +// isGoogleProjectConfigError 判断(已提取的小写)错误消息是否属于 Google 服务端配置类问题。 +// 只精确匹配已知的服务端侧错误,避免对客户端请求错误做无意义重试。 +// 适用于所有走 Google 后端的平台(Antigravity、Gemini)。 +func isGoogleProjectConfigError(lowerMsg string) bool { + // Google 间歇性 Bug:Project ID 有效但被临时识别失败 + return strings.Contains(lowerMsg, "invalid project resource name") +} + +// googleConfigErrorCooldown 服务端配置类 400 错误的临时封禁时长 +const googleConfigErrorCooldown = 1 * time.Minute + +// tempUnscheduleGoogleConfigError 对服务端配置类 400 错误触发临时封禁, +// 避免短时间内反复调度到同一个有问题的账号。 +func tempUnscheduleGoogleConfigError(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) { + until := time.Now().Add(googleConfigErrorCooldown) + reason := "400: invalid project resource name (auto temp-unschedule 1m)" + if err := 
repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil { + log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err) + } else { + log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason) + } +} + +// emptyResponseCooldown 空流式响应的临时封禁时长 +const emptyResponseCooldown = 1 * time.Minute + +// tempUnscheduleEmptyResponse 对空流式响应触发临时封禁, +// 避免短时间内反复调度到同一个返回空响应的账号。 +func tempUnscheduleEmptyResponse(ctx context.Context, repo AccountRepository, accountID int64, logPrefix string) { + until := time.Now().Add(emptyResponseCooldown) + reason := "empty stream response (auto temp-unschedule 1m)" + if err := repo.SetTempUnschedulable(ctx, accountID, until, reason); err != nil { + log.Printf("%s temp_unschedule_failed account=%d error=%v", logPrefix, accountID, err) + } else { + log.Printf("%s temp_unscheduled account=%d until=%v reason=%q", logPrefix, accountID, until.Format("15:04:05"), reason) + } +} + // sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待 // 返回 true 表示正常完成等待,false 表示 context 已取消 func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { @@ -1935,14 +2272,22 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { sleepFor = 0 } + timer := time.NewTimer(sleepFor) select { case <-ctx.Done(): + timer.Stop() return false - case <-time.After(sleepFor): + case <-timer.C: return true } } +// isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记 +func isSingleAccountRetry(ctx context.Context) bool { + v, _ := ctx.Value(ctxkey.SingleAccountRetry).(bool) + return v +} + // setModelRateLimitByModelName 使用官方模型 ID 设置模型级限流 // 直接使用上游返回的模型 ID(如 claude-sonnet-4-5)作为限流 key // 返回是否已成功设置(若模型名为空或 repo 为 nil 将返回 false) @@ -1977,8 +2322,9 @@ func antigravityFallbackCooldownSeconds() (time.Duration, bool) { // antigravitySmartRetryInfo 智能重试所需的信息 type antigravitySmartRetryInfo struct { - RetryDelay time.Duration // 重试延迟时间 - ModelName 
string // 限流的模型名称(如 "claude-sonnet-4-5") + RetryDelay time.Duration // 重试延迟时间 + ModelName string // 限流的模型名称(如 "claude-sonnet-4-5") + IsModelCapacityExhausted bool // 是否为模型容量不足(MODEL_CAPACITY_EXHAUSTED) } // parseAntigravitySmartRetryInfo 解析 Google RPC RetryInfo 和 ErrorInfo 信息 @@ -2093,31 +2439,40 @@ func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo { } return &antigravitySmartRetryInfo{ - RetryDelay: retryDelay, - ModelName: modelName, + RetryDelay: retryDelay, + ModelName: modelName, + IsModelCapacityExhausted: hasModelCapacityExhausted, } } // shouldTriggerAntigravitySmartRetry 判断是否应该触发智能重试 // 返回: -// - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold) -// - shouldRateLimitModel: 是否应该限流模型(retryDelay >= antigravityRateLimitThreshold) -// - waitDuration: 等待时间(智能重试时使用,shouldRateLimitModel=true 时为 0) +// - shouldRetry: 是否应该智能重试(retryDelay < antigravityRateLimitThreshold,或 MODEL_CAPACITY_EXHAUSTED) +// - shouldRateLimitModel: 是否应该限流模型并切换账号(仅 RATE_LIMIT_EXCEEDED 且 retryDelay >= 阈值) +// - waitDuration: 等待时间 // - modelName: 限流的模型名称 -func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string) { +// - isModelCapacityExhausted: 是否为模型容量不足(MODEL_CAPACITY_EXHAUSTED) +func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shouldRetry bool, shouldRateLimitModel bool, waitDuration time.Duration, modelName string, isModelCapacityExhausted bool) { if account.Platform != PlatformAntigravity { - return false, false, 0, "" + return false, false, 0, "", false } info := parseAntigravitySmartRetryInfo(respBody) if info == nil { - return false, false, 0, "" + return false, false, 0, "", false } + // MODEL_CAPACITY_EXHAUSTED(模型容量不足):所有账号共享同一模型容量池 + // 切换账号无意义,使用固定 1s 间隔重试 + if info.IsModelCapacityExhausted { + return true, false, antigravityModelCapacityRetryWait, info.ModelName, true + } + + // RATE_LIMIT_EXCEEDED(账号级限流): 
// retryDelay >= 阈值:直接限流模型,不重试 // 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 30s if info.RetryDelay >= antigravityRateLimitThreshold { - return false, true, info.RetryDelay, info.ModelName + return false, true, info.RetryDelay, info.ModelName, false } // retryDelay < 阈值:智能重试 @@ -2126,7 +2481,7 @@ func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shou waitDuration = antigravitySmartRetryMinWait } - return true, false, waitDuration, info.ModelName + return true, false, waitDuration, info.ModelName, false } // handleModelRateLimitParams 模型级限流处理参数 @@ -2152,8 +2507,9 @@ type handleModelRateLimitResult struct { // handleModelRateLimit 处理模型级限流(在原有逻辑之前调用) // 仅处理 429/503,解析模型名和 retryDelay -// - retryDelay < antigravityRateLimitThreshold: 返回 ShouldRetry=true,由调用方等待后重试 -// - retryDelay >= antigravityRateLimitThreshold: 设置模型限流 + 清除粘性会话 + 返回 SwitchError +// - MODEL_CAPACITY_EXHAUSTED: 返回 Handled=true(实际重试由 handleSmartRetry 处理) +// - RATE_LIMIT_EXCEEDED + retryDelay < 阈值: 返回 ShouldRetry=true,由调用方等待后重试 +// - RATE_LIMIT_EXCEEDED + retryDelay >= 阈值: 设置模型限流 + 清除粘性会话 + 返回 SwitchError func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimitParams) *handleModelRateLimitResult { if p.statusCode != 429 && p.statusCode != 503 { return &handleModelRateLimitResult{Handled: false} @@ -2164,7 +2520,17 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit return &handleModelRateLimitResult{Handled: false} } - // < antigravityRateLimitThreshold: 等待后重试 + // MODEL_CAPACITY_EXHAUSTED:模型容量不足,所有账号共享同一容量池 + // 切换账号无意义,不设置模型限流(实际重试由 handleSmartRetry 处理) + if info.IsModelCapacityExhausted { + log.Printf("%s status=%d model_capacity_exhausted model=%s (not switching account, retry handled by smart retry)", + p.prefix, p.statusCode, info.ModelName) + return &handleModelRateLimitResult{ + Handled: true, + } + } + + // RATE_LIMIT_EXCEEDED: < antigravityRateLimitThreshold: 等待后重试 if info.RetryDelay < 
antigravityRateLimitThreshold { log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v", p.prefix, p.statusCode, info.ModelName, info.RetryDelay) @@ -2175,7 +2541,7 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit } } - // >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号 + // RATE_LIMIT_EXCEEDED: >= antigravityRateLimitThreshold: 设置限流 + 清除粘性会话 + 切换账号 s.setModelRateLimitAndClearSession(p, info) return &handleModelRateLimitResult{ @@ -2242,6 +2608,10 @@ func (s *AntigravityGatewayService) handleUpstreamError( requestedModel string, groupID int64, sessionHash string, isStickySession bool, ) *handleModelRateLimitResult { + // 遵守自定义错误码策略:未命中则跳过所有限流处理 + if !account.ShouldHandleErrorCode(statusCode) { + return nil + } // 模型级限流处理(优先) result := s.handleModelRateLimit(&handleModelRateLimitParams{ ctx: ctx, @@ -2719,9 +3089,14 @@ returnResponse: // 选择最后一个有效响应 finalResponse := pickGeminiCollectResult(last, lastWithParts) - // 处理空响应情况 + // 处理空响应情况 — 触发同账号重试 + failover 切换账号 if last == nil && lastWithParts == nil { - log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received") + log.Printf("[antigravity-Forward] warning: empty stream response (gemini non-stream), triggering failover") + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: []byte(`{"error":"empty stream response from upstream"}`), + RetryableOnSameAccount: true, + } } // 如果收集到了图片 parts,需要合并到最终响应中 @@ -2939,6 +3314,21 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes)) } + // 检查错误透传规则 + if ptStatus, ptErrType, ptErrMsg, matched := applyErrorPassthroughRule( + c, account.Platform, upstreamStatus, body, + 0, "", "", + ); matched { + c.JSON(ptStatus, gin.H{ + "type": "error", + "error": gin.H{"type": ptErrType, "message": ptErrMsg}, + }) + if upstreamMsg == 
"" { + return fmt.Errorf("upstream error: %d", upstreamStatus) + } + return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg) + } + var statusCode int var errType, errMsg string @@ -3134,10 +3524,14 @@ returnResponse: // 选择最后一个有效响应 finalResponse := pickGeminiCollectResult(last, lastWithParts) - // 处理空响应情况 + // 处理空响应情况 — 触发同账号重试 + failover 切换账号 if last == nil && lastWithParts == nil { - log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received") - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream") + log.Printf("[antigravity-Forward] warning: empty stream response (claude non-stream), triggering failover") + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: []byte(`{"error":"empty stream response from upstream"}`), + RetryableOnSameAccount: true, + } } // 将收集的所有 parts 合并到最终响应中 @@ -3717,6 +4111,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 { usage.CacheCreationInputTokens = int(v) } + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + if cc, ok := u["cache_creation"].(map[string]any); ok { + if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok { + usage.CacheCreation5mTokens = int(v) + } + if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok { + usage.CacheCreation1hTokens = int(v) + } + } } // extractClaudeUsage 从非流式 Claude 响应提取 usage @@ -3739,6 +4142,15 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage if v, ok := u["cache_creation_input_tokens"].(float64); ok { usage.CacheCreationInputTokens = int(v) } + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + if cc, ok := u["cache_creation"].(map[string]any); ok { + if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok { + usage.CacheCreation5mTokens = int(v) + } + if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok { + 
usage.CacheCreation1hTokens = int(v) + } + } } return usage } diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index a6a349c1..b312e5ca 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -553,6 +553,75 @@ func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) { require.NotContains(t, body, "event: error") } +// TestHandleGeminiStreamingResponse_ThoughtsTokenCount +// 验证:Gemini 流式转发时 thoughtsTokenCount 被计入 OutputTokens +func TestHandleGeminiStreamingResponse_ThoughtsTokenCount(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"thoughtsTokenCount":50}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":30,"thoughtsTokenCount":80,"cachedContentTokenCount":10}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + // promptTokenCount=100, cachedContentTokenCount=10 → InputTokens=90 + require.Equal(t, 90, result.usage.InputTokens) + // candidatesTokenCount=30 + thoughtsTokenCount=80 → OutputTokens=110 + require.Equal(t, 110, 
result.usage.OutputTokens) + require.Equal(t, 10, result.usage.CacheReadInputTokens) +} + +// TestHandleClaudeStreamingResponse_ThoughtsTokenCount +// 验证:Gemini→Claude 流式转换时 thoughtsTokenCount 被计入 OutputTokens +func TestHandleClaudeStreamingResponse_ThoughtsTokenCount(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":50,"candidatesTokenCount":10,"thoughtsTokenCount":25}}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "gemini-2.5-pro") + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + // promptTokenCount=50 → InputTokens=50 + require.Equal(t, 50, result.usage.InputTokens) + // candidatesTokenCount=10 + thoughtsTokenCount=25 → OutputTokens=35 + require.Equal(t, 35, result.usage.OutputTokens) +} + // --- 流式客户端断开检测测试 --- // TestStreamUpstreamResponse_ClientDisconnectDrainsUsage diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index fa8379ed..b67c7faf 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -192,6 +192,43 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr) } +// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id) +func (s 
*AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64) (*AntigravityTokenInfo, error) { + var proxyURL string + if proxyID != nil { + proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) + if err == nil && proxy != nil { + proxyURL = proxy.URL() + } + } + + // 刷新 token + tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL) + if err != nil { + return nil, err + } + + // 获取用户信息(email) + client := antigravity.NewClient(proxyURL) + userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken) + if err != nil { + fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err) + } else { + tokenInfo.Email = userInfo.Email + } + + // 获取 project_id(容错,失败不阻塞) + projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3) + if loadErr != nil { + fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr) + tokenInfo.ProjectIDMissing = true + } else { + tokenInfo.ProjectID = projectID + } + + return tokenInfo, nil +} + func isNonRetryableAntigravityOAuthError(err error) bool { msg := err.Error() nonRetryable := []string{ @@ -273,12 +310,21 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac } client := antigravity.NewClient(proxyURL) - loadResp, _, err := client.LoadCodeAssist(ctx, accessToken) + loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken) if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { return loadResp.CloudAICompanionProject, nil } + if err == nil { + if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" { + return projectID, nil + } else if onboardErr != nil { + lastErr = onboardErr + continue + } + } + // 记录错误 if err != nil { lastErr = err @@ -292,6 +338,65 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr) } +func 
tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) { + tierID := resolveDefaultTierID(loadRaw) + if tierID == "" { + return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier") + } + + projectID, err := client.OnboardUser(ctx, accessToken, tierID) + if err != nil { + return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err) + } + return projectID, nil +} + +func resolveDefaultTierID(loadRaw map[string]any) string { + if len(loadRaw) == 0 { + return "" + } + + rawTiers, ok := loadRaw["allowedTiers"] + if !ok { + return "" + } + + tiers, ok := rawTiers.([]any) + if !ok { + return "" + } + + for _, rawTier := range tiers { + tier, ok := rawTier.(map[string]any) + if !ok { + continue + } + if isDefault, _ := tier["isDefault"].(bool); !isDefault { + continue + } + if id, ok := tier["id"].(string); ok { + id = strings.TrimSpace(id) + if id != "" { + return id + } + } + } + + return "" +} + +// FillProjectID 仅获取 project_id,不刷新 OAuth token +func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Account, accessToken string) (string, error) { + var proxyURL string + if account.ProxyID != nil { + proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID) + if err == nil && proxy != nil { + proxyURL = proxy.URL() + } + } + return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3) +} + // BuildAccountCredentials 构建账户凭证 func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any { creds := map[string]any{ diff --git a/backend/internal/service/antigravity_oauth_service_test.go b/backend/internal/service/antigravity_oauth_service_test.go new file mode 100644 index 00000000..1d2d8235 --- /dev/null +++ b/backend/internal/service/antigravity_oauth_service_test.go @@ -0,0 +1,82 @@ +package service + +import ( + "testing" +) + +func TestResolveDefaultTierID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + loadRaw 
map[string]any + want string + }{ + { + name: "nil loadRaw", + loadRaw: nil, + want: "", + }, + { + name: "missing allowedTiers", + loadRaw: map[string]any{ + "paidTier": map[string]any{"id": "g1-pro-tier"}, + }, + want: "", + }, + { + name: "empty allowedTiers", + loadRaw: map[string]any{"allowedTiers": []any{}}, + want: "", + }, + { + name: "tier missing id field", + loadRaw: map[string]any{ + "allowedTiers": []any{ + map[string]any{"isDefault": true}, + }, + }, + want: "", + }, + { + name: "allowedTiers but no default", + loadRaw: map[string]any{ + "allowedTiers": []any{ + map[string]any{"id": "free-tier", "isDefault": false}, + map[string]any{"id": "standard-tier", "isDefault": false}, + }, + }, + want: "", + }, + { + name: "default tier found", + loadRaw: map[string]any{ + "allowedTiers": []any{ + map[string]any{"id": "free-tier", "isDefault": true}, + map[string]any{"id": "standard-tier", "isDefault": false}, + }, + }, + want: "free-tier", + }, + { + name: "default tier id with spaces", + loadRaw: map[string]any{ + "allowedTiers": []any{ + map[string]any{"id": " standard-tier ", "isDefault": true}, + }, + }, + want: "standard-tier", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := resolveDefaultTierID(tc.loadRaw) + if got != tc.want { + t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want) + } + }) + } +} diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index 59cc9331..0befa7d9 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -86,7 +86,9 @@ func (s *stubAntigravityAccountRepo) SetModelRateLimit(ctx context.Context, id i return nil } -func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { +func TestAntigravityRetryLoop_NoURLFallback_UsesConfiguredBaseURL(t *testing.T) { + t.Setenv(antigravityForwardBaseURLEnv, "") + 
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) oldAvailability := antigravity.DefaultURLAvailability defer func() { @@ -131,15 +133,16 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { require.NotNil(t, result) require.NotNil(t, result.resp) defer func() { _ = result.resp.Body.Close() }() - require.Equal(t, http.StatusOK, result.resp.StatusCode) - require.False(t, handleErrorCalled) - require.Len(t, upstream.calls, 2) - require.True(t, strings.HasPrefix(upstream.calls[0], base1)) - require.True(t, strings.HasPrefix(upstream.calls[1], base2)) + require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode) + require.True(t, handleErrorCalled) + require.Len(t, upstream.calls, antigravityMaxRetries) + for _, callURL := range upstream.calls { + require.True(t, strings.HasPrefix(callURL, base1)) + } available := antigravity.DefaultURLAvailability.GetAvailableURLs() require.NotEmpty(t, available) - require.Equal(t, base2, available[0]) + require.Equal(t, base1, available[0]) } // TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景 @@ -188,13 +191,14 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) { require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } -// TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景 -func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) { +// TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景 +// MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号 +func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) { repo := &stubAntigravityAccountRepo{} svc := &AntigravityGatewayService{accountRepo: repo} account := &Account{ID: 3, Name: "acc-3", Platform: PlatformAntigravity} - // 503 + MODEL_CAPACITY_EXHAUSTED → 模型限流 + // 503 + MODEL_CAPACITY_EXHAUSTED → 等待重试,不切换账号 body := []byte(`{ "error": { "status": "UNAVAILABLE", @@ -207,13 +211,13 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) { result := 
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false) - // 应该触发模型限流 + // MODEL_CAPACITY_EXHAUSTED 应该标记为已处理,不切换账号,不设置模型限流 + // 实际重试由 handleSmartRetry 处理 require.NotNil(t, result) require.True(t, result.Handled) - require.NotNil(t, result.SwitchError) - require.Equal(t, "gemini-3-pro-high", result.SwitchError.RateLimitedModel) - require.Len(t, repo.modelRateLimitCalls, 1) - require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey) + require.False(t, result.ShouldRetry, "MODEL_CAPACITY_EXHAUSTED should not trigger retry from handleModelRateLimit path") + require.Nil(t, result.SwitchError, "MODEL_CAPACITY_EXHAUSTED should not trigger account switch") + require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit") } // TestHandleUpstreamError_503_NonModelRateLimit 测试 503 非模型限流场景(不处理) @@ -301,11 +305,12 @@ func TestParseGeminiRateLimitResetTime_QuotaResetDelay_RoundsUp(t *testing.T) { func TestParseAntigravitySmartRetryInfo(t *testing.T) { tests := []struct { - name string - body string - expectedDelay time.Duration - expectedModel string - expectedNil bool + name string + body string + expectedDelay time.Duration + expectedModel string + expectedNil bool + expectedIsModelCapacityExhausted bool }{ { name: "valid complete response with RATE_LIMIT_EXCEEDED", @@ -368,8 +373,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) { "message": "No capacity available for model gemini-3-pro-high on the server" } }`, - expectedDelay: 39 * time.Second, - expectedModel: "gemini-3-pro-high", + expectedDelay: 39 * time.Second, + expectedModel: "gemini-3-pro-high", + expectedIsModelCapacityExhausted: true, }, { name: "503 UNAVAILABLE without MODEL_CAPACITY_EXHAUSTED - should return nil", @@ -480,6 +486,9 @@ func TestParseAntigravitySmartRetryInfo(t *testing.T) { if result.ModelName != tt.expectedModel { t.Errorf("ModelName = 
%q, want %q", result.ModelName, tt.expectedModel) } + if result.IsModelCapacityExhausted != tt.expectedIsModelCapacityExhausted { + t.Errorf("IsModelCapacityExhausted = %v, want %v", result.IsModelCapacityExhausted, tt.expectedIsModelCapacityExhausted) + } }) } } @@ -491,13 +500,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { apiKeyAccount := &Account{Type: AccountTypeAPIKey} tests := []struct { - name string - account *Account - body string - expectedShouldRetry bool - expectedShouldRateLimit bool - minWait time.Duration - modelName string + name string + account *Account + body string + expectedShouldRetry bool + expectedShouldRateLimit bool + expectedIsModelCapacityExhausted bool + minWait time.Duration + modelName string }{ { name: "OAuth account with short delay (< 7s) - smart retry", @@ -611,13 +621,14 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { ] } }`, - expectedShouldRetry: false, - expectedShouldRateLimit: true, - minWait: 39 * time.Second, - modelName: "gemini-3-pro-high", + expectedShouldRetry: true, + expectedShouldRateLimit: false, + expectedIsModelCapacityExhausted: true, + minWait: 1 * time.Second, + modelName: "gemini-3-pro-high", }, { - name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use default rate limit", + name: "503 UNAVAILABLE with MODEL_CAPACITY_EXHAUSTED - no retryDelay - use fixed wait", account: oauthAccount, body: `{ "error": { @@ -629,10 +640,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { "message": "No capacity available for model gemini-2.5-flash on the server" } }`, - expectedShouldRetry: false, - expectedShouldRateLimit: true, - minWait: 30 * time.Second, - modelName: "gemini-2.5-flash", + expectedShouldRetry: true, + expectedShouldRateLimit: false, + expectedIsModelCapacityExhausted: true, + minWait: 1 * time.Second, + modelName: "gemini-2.5-flash", }, { name: "429 RESOURCE_EXHAUSTED with RATE_LIMIT_EXCEEDED - no retryDelay - use default rate limit", @@ 
-656,13 +668,16 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - shouldRetry, shouldRateLimit, wait, model := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body)) + shouldRetry, shouldRateLimit, wait, model, isModelCapacityExhausted := shouldTriggerAntigravitySmartRetry(tt.account, []byte(tt.body)) if shouldRetry != tt.expectedShouldRetry { t.Errorf("shouldRetry = %v, want %v", shouldRetry, tt.expectedShouldRetry) } if shouldRateLimit != tt.expectedShouldRateLimit { t.Errorf("shouldRateLimit = %v, want %v", shouldRateLimit, tt.expectedShouldRateLimit) } + if isModelCapacityExhausted != tt.expectedIsModelCapacityExhausted { + t.Errorf("isModelCapacityExhausted = %v, want %v", isModelCapacityExhausted, tt.expectedIsModelCapacityExhausted) + } if shouldRetry { if wait < tt.minWait { t.Errorf("wait = %v, want >= %v", wait, tt.minWait) @@ -915,6 +930,22 @@ func TestIsAntigravityAccountSwitchError(t *testing.T) { } } +func TestResolveAntigravityForwardBaseURL_DefaultDaily(t *testing.T) { + t.Setenv(antigravityForwardBaseURLEnv, "") + + oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) 
+ defer func() { + antigravity.BaseURLs = oldBaseURLs + }() + + prodURL := "https://prod.test" + dailyURL := "https://daily.test" + antigravity.BaseURLs = []string{dailyURL, prodURL} + + resolved := resolveAntigravityForwardBaseURL() + require.Equal(t, dailyURL, resolved) +} + func TestAntigravityAccountSwitchError_Error(t *testing.T) { err := &AntigravityAccountSwitchError{ OriginalAccountID: 789, diff --git a/backend/internal/service/antigravity_single_account_retry_test.go b/backend/internal/service/antigravity_single_account_retry_test.go new file mode 100644 index 00000000..8b01cc31 --- /dev/null +++ b/backend/internal/service/antigravity_single_account_retry_test.go @@ -0,0 +1,904 @@ +//go:build unit + +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// 辅助函数:构造带 SingleAccountRetry 标记的 context +// --------------------------------------------------------------------------- + +func ctxWithSingleAccountRetry() context.Context { + return context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true) +} + +// --------------------------------------------------------------------------- +// 1. 
isSingleAccountRetry 测试 +// --------------------------------------------------------------------------- + +func TestIsSingleAccountRetry_True(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true) + require.True(t, isSingleAccountRetry(ctx)) +} + +func TestIsSingleAccountRetry_False_NoValue(t *testing.T) { + require.False(t, isSingleAccountRetry(context.Background())) +} + +func TestIsSingleAccountRetry_False_ExplicitFalse(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, false) + require.False(t, isSingleAccountRetry(ctx)) +} + +func TestIsSingleAccountRetry_False_WrongType(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, "true") + require.False(t, isSingleAccountRetry(ctx)) +} + +// --------------------------------------------------------------------------- +// 2. 常量验证 +// --------------------------------------------------------------------------- + +func TestSingleAccountRetryConstants(t *testing.T) { + require.Equal(t, 3, antigravitySingleAccountSmartRetryMaxAttempts, + "单账号原地重试最多 3 次") + require.Equal(t, 15*time.Second, antigravitySingleAccountSmartRetryMaxWait, + "单次最大等待 15s") + require.Equal(t, 30*time.Second, antigravitySingleAccountSmartRetryTotalMaxWait, + "总累计等待不超过 30s") +} + +// --------------------------------------------------------------------------- +// 3. 
handleSmartRetry + 503 + SingleAccountRetry → 走 handleSingleAccountRetryInPlace +// (而非设模型限流 + 切换账号) +// --------------------------------------------------------------------------- + +// TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace +// 核心场景:503 + retryDelay >= 7s + SingleAccountRetry 标记 +// → 不设模型限流、不切换账号,改为原地重试 +func TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace(t *testing.T) { + // 原地重试成功 + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{successResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 1, + Name: "acc-single", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + // 503 + 39s >= 7s 阈值 + MODEL_CAPACITY_EXHAUSTED + respBody := []byte(`{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"} + ], + "message": "No capacity available for model gemini-3-pro-high on the server" + } + }`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), // 关键:设置单账号标记 + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + 
}, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + // 关键断言:返回 resp(原地重试成功),而非 switchError(切换账号) + require.NotNil(t, result.resp, "should return successful response from in-place retry") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.Nil(t, result.switchError, "should NOT return switchError in single account mode") + require.Nil(t, result.err) + + // 验证未设模型限流(单账号模式不应设限流) + require.Len(t, repo.modelRateLimitCalls, 0, + "should NOT set model rate limit in single account retry mode") + + // 验证确实调用了 upstream(原地重试) + require.GreaterOrEqual(t, len(upstream.calls), 1, "should have made at least one retry call") +} + +// TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches +// 对照组:503 + retryDelay >= 7s + 无 SingleAccountRetry 标记 +// → 照常设模型限流 + 切换账号 +func TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 2, + Name: "acc-multi", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 503 + 39s >= 7s 阈值(使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED, + // 因为 MODEL_CAPACITY_EXHAUSTED 走独立的重试路径,不触发 shouldRateLimitModel) + respBody := []byte(`{ + "error": { + "code": 503, + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), // 关键:无单账号标记 + prefix: 
"[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + // 对照:多账号模式返回 switchError + require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503") + require.Nil(t, result.resp, "should not return resp when switchError is set") + + // 对照:多账号模式应设模型限流 + require.Len(t, repo.modelRateLimitCalls, 1, + "multi-account mode SHOULD set model rate limit") +} + +// TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches +// 边界情况:429(非 503)+ SingleAccountRetry 标记 +// → 单账号原地重试仅针对 503,429 依然走切换账号逻辑 +func TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 3, + Name: "acc-429", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 429 + 15s >= 7s 阈值 + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, // 429,不是 503 + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), // 有单账号标记 + prefix: 
"[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + // 429 即使有单账号标记,也应走切换账号 + require.NotNil(t, result.switchError, "429 should still return switchError even with SingleAccountRetry") + require.Len(t, repo.modelRateLimitCalls, 1, + "429 should still set model rate limit even with SingleAccountRetry") +} + +// --------------------------------------------------------------------------- +// 4. 
handleSmartRetry + 503 + 短延迟 + SingleAccountRetry → 智能重试耗尽后不设限流 +// --------------------------------------------------------------------------- + +// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit +// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503,不设限流 +func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) { + // 智能重试也返回 503 + failRespBody := `{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 4, + Name: "acc-short-503", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 0.1s < 7s 阈值 + respBody := []byte(`{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, 
body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + // 关键断言:单账号 503 模式下,智能重试耗尽后直接返回 503 响应,不切换 + require.NotNil(t, result.resp, "should return 503 response directly for single account mode") + require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode) + require.Nil(t, result.switchError, "should NOT switch account in single account mode") + + // 关键断言:不设模型限流 + require.Len(t, repo.modelRateLimitCalls, 0, + "should NOT set model rate limit for 503 in single account mode") +} + +// TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit +// 对照组:503 + retryDelay < 7s + 无 SingleAccountRetry → 智能重试耗尽后照常设限流 +// 使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED,因为后者走独立的 60 次重试路径 +func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *testing.T) { + failRespBody := `{ + "error": { + "code": 503, + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 5, + Name: "acc-multi-503", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "code": 503, + "status": 
"RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), // 无单账号标记 + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + // 对照:多账号模式应返回 switchError + require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503") + // 对照:多账号模式应设模型限流 + require.Len(t, repo.modelRateLimitCalls, 1, + "multi-account mode should set model rate limit") +} + +// --------------------------------------------------------------------------- +// 5. 
handleSingleAccountRetryInPlace 直接测试 +// --------------------------------------------------------------------------- + +// TestHandleSingleAccountRetryInPlace_Success 原地重试成功 +func TestHandleSingleAccountRetryInPlace_Success(t *testing.T) { + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{successResp}, + errors: []error{nil}, + } + + account := &Account{ + ID: 10, + Name: "acc-inplace-ok", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro") + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp, "should return successful response") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.Nil(t, result.switchError, "should not switch account on success") + require.Nil(t, result.err) +} + +// TestHandleSingleAccountRetryInPlace_AllRetriesFail 所有重试都失败,返回 503(不设限流) +func TestHandleSingleAccountRetryInPlace_AllRetriesFail(t *testing.T) { + // 构造 3 个 503 响应(对应 3 次原地重试) + var responses []*http.Response + var errors []error + for i := 0; i < antigravitySingleAccountSmartRetryMaxAttempts; i++ { + responses = append(responses, &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + 
"details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`)), + }) + errors = append(errors, nil) + } + upstream := &mockSmartRetryUpstream{ + responses: responses, + errors: errors, + } + + account := &Account{ + ID: 11, + Name: "acc-inplace-fail", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + origBody := []byte(`{"error":{"code":503,"status":"UNAVAILABLE"}}`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{"X-Test": {"original"}}, + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSingleAccountRetryInPlace(params, resp, origBody, "https://ag-1.test", 1*time.Second, "gemini-3-pro") + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + // 关键:返回 503 resp,不返回 switchError + require.NotNil(t, result.resp, "should return 503 response directly") + require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode) + require.Nil(t, result.switchError, "should NOT return switchError - let Handler handle it") + require.Nil(t, result.err) + + // 验证确实重试了指定次数 + require.Len(t, upstream.calls, antigravitySingleAccountSmartRetryMaxAttempts, + "should have made exactly maxAttempts retry calls") +} + +// TestHandleSingleAccountRetryInPlace_WaitDurationClamped 等待时间被限制在 [min, max] 范围 +func TestHandleSingleAccountRetryInPlace_WaitDurationClamped(t *testing.T) { + // 用短延迟的成功响应,只验证不 panic + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + 
} + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{successResp}, + errors: []error{nil}, + } + + account := &Account{ + ID: 12, + Name: "acc-clamp", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + } + + svc := &AntigravityGatewayService{} + + // 等待时间过大应被 clamp 到 antigravitySingleAccountSmartRetryMaxWait + result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 999*time.Second, "gemini-3-pro") + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp) + require.Equal(t, http.StatusOK, result.resp.StatusCode) +} + +// TestHandleSingleAccountRetryInPlace_ContextCanceled context 取消时立即返回 +func TestHandleSingleAccountRetryInPlace_ContextCanceled(t *testing.T) { + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{nil}, + errors: []error{nil}, + } + + account := &Account{ + ID: 13, + Name: "acc-cancel", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + } + + ctx, cancel := context.WithCancel(context.Background()) + ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true) + cancel() // 立即取消 + + params := antigravityRetryLoopParams{ + ctx: ctx, + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, 
"gemini-3-pro") + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.Error(t, result.err, "should return context error") + // 不应调用 upstream(因为在等待阶段就被取消了) + require.Len(t, upstream.calls, 0, "should not call upstream when context is canceled") +} + +// TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry 网络错误时继续重试 +func TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry(t *testing.T) { + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + upstream := &mockSmartRetryUpstream{ + // 第1次网络错误(nil resp),第2次成功 + responses: []*http.Response{nil, successResp}, + errors: []error{nil, nil}, + } + + account := &Account{ + ID: 14, + Name: "acc-net-retry", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro") + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp, "should return successful response after network error recovery") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.Len(t, upstream.calls, 2, "first call fails (network error), second succeeds") +} + +// --------------------------------------------------------------------------- +// 6. 
antigravityRetryLoop 预检查:单账号模式跳过限流 +// --------------------------------------------------------------------------- + +// TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit +// 预检查中,如果有 SingleAccountRetry 标记,即使账号已限流也跳过直接发请求 +func TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit(t *testing.T) { + // 创建一个已设模型限流的账号 + upstream := &recordingOKUpstream{} + account := &Account{ + ID: 20, + Name: "acc-rate-limited", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339), + }, + }, + }, + } + + svc := &AntigravityGatewayService{} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + requestedModel: "claude-sonnet-4-5", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + require.NoError(t, err, "should not return error") + require.NotNil(t, result, "should return result") + require.NotNil(t, result.resp, "should have response") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + // 关键:尽管限流了,有 SingleAccountRetry 标记时仍然到达了 upstream + require.Equal(t, 1, upstream.calls, "should have reached upstream despite rate limit") +} + +// TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit +// 对照组:无 SingleAccountRetry + 已限流 → 预检查返回 switchError +func TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit(t *testing.T) { + upstream := 
&recordingOKUpstream{} + account := &Account{ + ID: 21, + Name: "acc-rate-limited-multi", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Extra: map[string]any{ + modelRateLimitsKey: map[string]any{ + "claude-sonnet-4-5": map[string]any{ + "rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339), + }, + }, + }, + } + + svc := &AntigravityGatewayService{} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: context.Background(), // 无单账号标记 + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + requestedModel: "claude-sonnet-4-5", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + require.Nil(t, result, "should not return result on rate limit switch") + require.NotNil(t, err, "should return error") + + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr, "should return AntigravityAccountSwitchError") + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel) + + // upstream 不应被调用(预检查就短路了) + require.Equal(t, 0, upstream.calls, "upstream should NOT be called when pre-check blocks") +} + +// --------------------------------------------------------------------------- +// 7. 
端到端集成场景测试 +// --------------------------------------------------------------------------- + +// TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E +// 端到端场景:503 + 单账号 + 原地重试第2次成功 +func TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E(t *testing.T) { + // 第1次原地重试仍返回 503,第2次成功 + fail503Body := `{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + resp503 := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(fail503Body)), + } + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{resp503, successResp}, + errors: []error{nil, nil}, + } + + account := &Account{ + ID: 30, + Name: "acc-e2e", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Concurrency: 1, + } + + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + } + + params := antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro") + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp, "should return successful response after 2nd attempt") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.Nil(t, result.switchError) + require.Len(t, 
upstream.calls, 2, "first 503, second OK") +} + +// TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E +// 通过 antigravityRetryLoop → handleSmartRetry → handleSingleAccountRetryInPlace 完整链路 +func TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E(t *testing.T) { + // 初始请求返回 503 + 长延迟 + initial503Body := []byte(`{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "10s"} + ], + "message": "No capacity available" + } + }`) + initial503Resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(initial503Body)), + } + + // 原地重试成功 + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + + upstream := &mockSmartRetryUpstream{ + // 第1次调用(retryLoop 主循环)返回 503 + // 第2次调用(handleSingleAccountRetryInPlace 原地重试)返回 200 + responses: []*http.Response{initial503Resp, successResp}, + errors: []error{nil, nil}, + } + + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 31, + Name: "acc-e2e-loop", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + svc := &AntigravityGatewayService{} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctxWithSingleAccountRetry(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) 
*handleModelRateLimitResult { + return nil + }, + }) + + require.NoError(t, err, "should not return error on successful retry") + require.NotNil(t, result, "should return result") + require.NotNil(t, result.resp, "should return response") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + + // 验证未设模型限流 + require.Len(t, repo.modelRateLimitCalls, 0, + "should NOT set model rate limit in single account retry mode") +} diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go index a7e0d296..432c80e5 100644 --- a/backend/internal/service/antigravity_smart_retry_test.go +++ b/backend/internal/service/antigravity_smart_retry_test.go @@ -294,8 +294,9 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)") } -// TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError -func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testing.T) { +// TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess 测试 503 MODEL_CAPACITY_EXHAUSTED 重试成功 +// MODEL_CAPACITY_EXHAUSTED 使用固定 1s 间隔重试,不切换账号 +func TestHandleSmartRetry_503_ModelCapacityExhausted_RetrySuccess(t *testing.T) { repo := &stubAntigravityAccountRepo{} account := &Account{ ID: 3, @@ -304,7 +305,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi Platform: PlatformAntigravity, } - // 503 + MODEL_CAPACITY_EXHAUSTED + 39s >= 7s 阈值 + // 503 + MODEL_CAPACITY_EXHAUSTED + 39s(上游 retryDelay 应被忽略,使用固定 1s) respBody := []byte(`{ "error": { "code": 503, @@ -322,6 +323,14 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi Body: io.NopCloser(bytes.NewReader(respBody)), } + // mock: 第 1 次重试返回 200 成功 + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{ + {StatusCode: http.StatusOK, Header: 
http.Header{}, Body: io.NopCloser(strings.NewReader(`{"ok":true}`))}, + }, + errors: []error{nil}, + } + params := antigravityRetryLoopParams{ ctx: context.Background(), prefix: "[test]", @@ -330,6 +339,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi action: "generateContent", body: []byte(`{"input":"test"}`), accountRepo: repo, + httpUpstream: upstream, isStickySession: true, handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil @@ -343,16 +353,67 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi require.NotNil(t, result) require.Equal(t, smartRetryActionBreakWithResp, result.action) - require.Nil(t, result.resp) + require.NotNil(t, result.resp, "should return successful response") + require.Equal(t, http.StatusOK, result.resp.StatusCode) require.Nil(t, result.err) - require.NotNil(t, result.switchError, "should return switchError for 503 model capacity exhausted") - require.Equal(t, account.ID, result.switchError.OriginalAccountID) - require.Equal(t, "gemini-3-pro-high", result.switchError.RateLimitedModel) - require.True(t, result.switchError.IsStickySession) + require.Nil(t, result.switchError, "MODEL_CAPACITY_EXHAUSTED should not return switchError") - // 验证模型限流已设置 - require.Len(t, repo.modelRateLimitCalls, 1) - require.Equal(t, "gemini-3-pro-high", repo.modelRateLimitCalls[0].modelKey) + // 不应设置模型限流 + require.Empty(t, repo.modelRateLimitCalls, "MODEL_CAPACITY_EXHAUSTED should not set model rate limit") + require.Len(t, upstream.calls, 1, "should have made one retry call before success") +} + +// TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel 测试 MODEL_CAPACITY_EXHAUSTED 上下文取消 +func TestHandleSmartRetry_503_ModelCapacityExhausted_ContextCancel(t *testing.T) { + repo := 
&stubAntigravityAccountRepo{} + account := &Account{ + ID: 3, + Name: "acc-3", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + // 立即取消上下文,验证重试循环能正确退出 + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + params := antigravityRetryLoopParams{ + ctx: ctx, + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + svc := &AntigravityGatewayService{} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"}) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.Error(t, result.err, "should return context error") + require.Nil(t, result.switchError, "should not return switchError on context cancel") + require.Empty(t, repo.modelRateLimitCalls, "should not set model rate limit on context cancel") } // TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic 测试非 Antigravity 平台账号走默认逻辑 @@ -1129,20 +1190,20 @@ func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t } // TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession -// 503 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定 +// 429 + 短延迟 
+ 粘性会话 + 重试失败 → 清除粘性绑定 func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) { failRespBody := `{ "error": { - "code": 503, - "status": "UNAVAILABLE", + "code": 429, + "status": "RESOURCE_EXHAUSTED", "details": [ - {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"}, {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} ] } }` failResp := &http.Response{ - StatusCode: http.StatusServiceUnavailable, + StatusCode: http.StatusTooManyRequests, Header: http.Header{}, Body: io.NopCloser(strings.NewReader(failRespBody)), } @@ -1162,16 +1223,16 @@ func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession respBody := []byte(`{ "error": { - "code": 503, - "status": "UNAVAILABLE", + "code": 429, + "status": "RESOURCE_EXHAUSTED", "details": [ - {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "RATE_LIMIT_EXCEEDED"}, {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} ] } }`) resp := &http.Response{ - StatusCode: http.StatusServiceUnavailable, + StatusCode: http.StatusTooManyRequests, Header: http.Header{}, Body: io.NopCloser(bytes.NewReader(respBody)), } diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go index 1eb740f9..068d6a08 100644 --- a/backend/internal/service/antigravity_token_provider.go +++ b/backend/internal/service/antigravity_token_provider.go @@ -7,12 +7,14 @@ import ( "log/slog" "strconv" "strings" + "sync" "time" ) const ( antigravityTokenRefreshSkew = 3 * time.Minute antigravityTokenCacheSkew 
= 5 * time.Minute + antigravityBackfillCooldown = 5 * time.Minute ) // AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) @@ -23,6 +25,7 @@ type AntigravityTokenProvider struct { accountRepo AccountRepository tokenCache AntigravityTokenCache antigravityOAuthService *AntigravityOAuthService + backfillCooldown sync.Map // key: int64 (account.ID) → value: time.Time } func NewAntigravityTokenProvider( @@ -93,13 +96,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * if err != nil { return "", err } - newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo) - for k, v := range account.Credentials { - if _, exists := newCredentials[k]; !exists { - newCredentials[k] = v - } - } - account.Credentials = newCredentials + p.mergeCredentials(account, tokenInfo) if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr) } @@ -113,6 +110,21 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * return "", errors.New("access_token not found in credentials") } + // 如果账号还没有 project_id,尝试在线补齐,避免请求 daily/sandbox 时出现 + // "Invalid project resource name projects/"。 + // 仅调用 loadProjectIDWithRetry,不刷新 OAuth token;带冷却机制防止频繁重试。 + if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil { + if p.shouldAttemptBackfill(account.ID) { + p.markBackfillAttempted(account.ID) + if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" { + account.Credentials["project_id"] = projectID + if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { + log.Printf("[AntigravityTokenProvider] project_id 补齐持久化失败: %v", updateErr) + } + } + } + } + // 3. 
存入缓存(验证版本后再写入,避免异步刷新任务与请求线程的竞态条件) if p.tokenCache != nil { latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo) @@ -144,6 +156,31 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account * return accessToken, nil } +// mergeCredentials 将 tokenInfo 构建的凭证合并到 account 中,保留原有未覆盖的字段 +func (p *AntigravityTokenProvider) mergeCredentials(account *Account, tokenInfo *AntigravityTokenInfo) { + newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } + account.Credentials = newCredentials +} + +// shouldAttemptBackfill 检查是否应该尝试补齐 project_id(冷却期内不重复尝试) +func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool { + if v, ok := p.backfillCooldown.Load(accountID); ok { + if lastAttempt, ok := v.(time.Time); ok { + return time.Since(lastAttempt) > antigravityBackfillCooldown + } + } + return true +} + +func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) { + p.backfillCooldown.Store(accountID, time.Now()) +} + func AntigravityTokenCacheKey(account *Account) string { projectID := strings.TrimSpace(account.GetCredential("project_id")) if projectID != "" { diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index db5a9708..6934bc64 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -31,8 +31,8 @@ type ModelPricing struct { OutputPricePerToken float64 // 每token输出价格 (USD) CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) - CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退 - CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退 + CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD) + CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD) SupportsCacheBreakdown bool // 
是否支持详细的缓存分类 } @@ -172,12 +172,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { if s.pricingService != nil { litellmPricing := s.pricingService.GetModelPricing(model) if litellmPricing != nil { + // 启用 5m/1h 分类计费的条件: + // 1. 存在 1h 价格 + // 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费) + price5m := litellmPricing.CacheCreationInputTokenCost + price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr + enableBreakdown := price1h > 0 && price1h > price5m return &ModelPricing{ InputPricePerToken: litellmPricing.InputCostPerToken, OutputPricePerToken: litellmPricing.OutputCostPerToken, CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, - SupportsCacheBreakdown: false, + CacheCreation5mPrice: price5m, + CacheCreation1hPrice: price1h, + SupportsCacheBreakdown: enableBreakdown, }, nil } } @@ -209,9 +217,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul // 计算缓存费用 if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { - // 支持详细缓存分类的模型(5分钟/1小时缓存) - breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice + - float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice + // 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token) + if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 { + // API 未返回 ephemeral 明细,回退到全部按 5m 单价计费 + breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice + } else { + breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice + + float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice + } } else { // 标准缓存创建价格(per-token) breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken @@ -280,10 +293,12 @@ func (s 
*BillingService) CalculateCostWithLongContext(model string, tokens Usage // 范围内部分:正常计费 inRangeTokens := UsageTokens{ - InputTokens: inRangeInputTokens, - OutputTokens: tokens.OutputTokens, // 输出只算一次 - CacheCreationTokens: tokens.CacheCreationTokens, - CacheReadTokens: inRangeCacheTokens, + InputTokens: inRangeInputTokens, + OutputTokens: tokens.OutputTokens, // 输出只算一次 + CacheCreationTokens: tokens.CacheCreationTokens, + CacheReadTokens: inRangeCacheTokens, + CacheCreation5mTokens: tokens.CacheCreation5mTokens, + CacheCreation1hTokens: tokens.CacheCreation1hTokens, } inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) if err != nil { diff --git a/backend/internal/service/error_passthrough_runtime.go b/backend/internal/service/error_passthrough_runtime.go index 65085d6f..011c3ce4 100644 --- a/backend/internal/service/error_passthrough_runtime.go +++ b/backend/internal/service/error_passthrough_runtime.go @@ -61,6 +61,11 @@ func applyErrorPassthroughRule( errMsg = *rule.CustomMessage } + // 命中 skip_monitoring 时在 context 中标记,供 ops_error_logger 跳过记录。 + if rule.SkipMonitoring { + c.Set(OpsSkipPassthroughKey, true) + } + // 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。 errType = "upstream_error" return status, errType, errMsg, true diff --git a/backend/internal/service/error_passthrough_runtime_test.go b/backend/internal/service/error_passthrough_runtime_test.go index 393e6e59..0a45e57a 100644 --- a/backend/internal/service/error_passthrough_runtime_test.go +++ b/backend/internal/service/error_passthrough_runtime_test.go @@ -194,6 +194,63 @@ func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) { assert.Equal(t, "Gemini上游失败", errField["message"]) } +func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, 
"上下文超限") + rule.SkipMonitoring = true + + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule}) + BindErrorPassthroughService(c, ruleSvc) + + _, _, _, matched := applyErrorPassthroughRule( + c, + PlatformAnthropic, + http.StatusBadRequest, + []byte(`{"error":{"message":"prompt is too long"}}`), + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ) + + assert.True(t, matched) + v, exists := c.Get(OpsSkipPassthroughKey) + assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true") + boolVal, ok := v.(bool) + assert.True(t, ok, "value should be bool") + assert.True(t, boolVal) +} + +func TestApplyErrorPassthroughRule_NoSkipMonitoringDoesNotSetContextKey(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + rule := newNonFailoverPassthroughRule(http.StatusBadRequest, "prompt is too long", http.StatusBadRequest, "上下文超限") + rule.SkipMonitoring = false + + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{rule}) + BindErrorPassthroughService(c, ruleSvc) + + _, _, _, matched := applyErrorPassthroughRule( + c, + PlatformAnthropic, + http.StatusBadRequest, + []byte(`{"error":{"message":"prompt is too long"}}`), + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ) + + assert.True(t, matched) + _, exists := c.Get(OpsSkipPassthroughKey) + assert.False(t, exists, "OpsSkipPassthroughKey should NOT be set when skip_monitoring=false") +} + func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule { return &model.ErrorPassthroughRule{ ID: 1, diff --git a/backend/internal/service/error_passthrough_service.go b/backend/internal/service/error_passthrough_service.go index c3e0f630..da8c9ccf 100644 --- a/backend/internal/service/error_passthrough_service.go +++ 
b/backend/internal/service/error_passthrough_service.go @@ -45,10 +45,20 @@ type ErrorPassthroughService struct { cache ErrorPassthroughCache // 本地内存缓存,用于快速匹配 - localCache []*model.ErrorPassthroughRule + localCache []*cachedPassthroughRule localCacheMu sync.RWMutex } +// cachedPassthroughRule 预计算的规则缓存,避免运行时重复 ToLower +type cachedPassthroughRule struct { + *model.ErrorPassthroughRule + lowerKeywords []string // 预计算的小写关键词 + lowerPlatforms []string // 预计算的小写平台 + errorCodeSet map[int]struct{} // 预计算的 error code set +} + +const maxBodyMatchLen = 8 << 10 // 8KB,错误信息不会在 8KB 之后才出现 + // NewErrorPassthroughService 创建错误透传规则服务 func NewErrorPassthroughService( repo ErrorPassthroughRepository, @@ -150,17 +160,19 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod return nil } - bodyStr := strings.ToLower(string(body)) + lowerPlatform := strings.ToLower(platform) + var bodyLower string // 延迟初始化,只在需要关键词匹配时计算 + var bodyLowerDone bool for _, rule := range rules { if !rule.Enabled { continue } - if !s.platformMatches(rule, platform) { + if !s.platformMatchesCached(rule, lowerPlatform) { continue } - if s.ruleMatches(rule, statusCode, bodyStr) { - return rule + if s.ruleMatchesOptimized(rule, statusCode, body, &bodyLower, &bodyLowerDone) { + return rule.ErrorPassthroughRule } } @@ -168,7 +180,7 @@ func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, bod } // getCachedRules 获取缓存的规则列表(按优先级排序) -func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule { +func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule { s.localCacheMu.RLock() rules := s.localCache s.localCacheMu.RUnlock() @@ -223,17 +235,39 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error { return nil } -// setLocalCache 设置本地缓存 +// setLocalCache 设置本地缓存,预计算小写值和 set 以避免运行时重复计算 func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) { + cached := 
make([]*cachedPassthroughRule, len(rules)) + for i, r := range rules { + cr := &cachedPassthroughRule{ErrorPassthroughRule: r} + if len(r.Keywords) > 0 { + cr.lowerKeywords = make([]string, len(r.Keywords)) + for j, kw := range r.Keywords { + cr.lowerKeywords[j] = strings.ToLower(kw) + } + } + if len(r.Platforms) > 0 { + cr.lowerPlatforms = make([]string, len(r.Platforms)) + for j, p := range r.Platforms { + cr.lowerPlatforms[j] = strings.ToLower(p) + } + } + if len(r.ErrorCodes) > 0 { + cr.errorCodeSet = make(map[int]struct{}, len(r.ErrorCodes)) + for _, code := range r.ErrorCodes { + cr.errorCodeSet[code] = struct{}{} + } + } + cached[i] = cr + } + // 按优先级排序 - sorted := make([]*model.ErrorPassthroughRule, len(rules)) - copy(sorted, rules) - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].Priority < sorted[j].Priority + sort.Slice(cached, func(i, j int) bool { + return cached[i].Priority < cached[j].Priority }) s.localCacheMu.Lock() - s.localCache = sorted + s.localCache = cached s.localCacheMu.Unlock() } @@ -273,62 +307,79 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { } } -// platformMatches 检查平台是否匹配 -func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool { - // 如果没有配置平台限制,则匹配所有平台 - if len(rule.Platforms) == 0 { +// ensureBodyLower 延迟初始化 body 的小写版本,只做一次转换,限制 8KB +func ensureBodyLower(body []byte, bodyLower *string, done *bool) string { + if *done { + return *bodyLower + } + b := body + if len(b) > maxBodyMatchLen { + b = b[:maxBodyMatchLen] + } + *bodyLower = strings.ToLower(string(b)) + *done = true + return *bodyLower +} + +// platformMatchesCached 使用预计算的小写平台检查是否匹配 +func (s *ErrorPassthroughService) platformMatchesCached(rule *cachedPassthroughRule, lowerPlatform string) bool { + if len(rule.lowerPlatforms) == 0 { return true } - - platform = strings.ToLower(platform) - for _, p := range rule.Platforms { - if strings.ToLower(p) == platform { + for _, p := range 
rule.lowerPlatforms { + if p == lowerPlatform { return true } } - return false } -// ruleMatches 检查规则是否匹配 -func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool { - hasErrorCodes := len(rule.ErrorCodes) > 0 - hasKeywords := len(rule.Keywords) > 0 +// ruleMatchesOptimized 优化的规则匹配,支持短路和延迟 body 转换 +func (s *ErrorPassthroughService) ruleMatchesOptimized(rule *cachedPassthroughRule, statusCode int, body []byte, bodyLower *string, bodyLowerDone *bool) bool { + hasErrorCodes := len(rule.errorCodeSet) > 0 + hasKeywords := len(rule.lowerKeywords) > 0 - // 如果没有配置任何条件,不匹配 if !hasErrorCodes && !hasKeywords { return false } - codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode) - keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords) + codeMatch := !hasErrorCodes || s.containsIntSet(rule.errorCodeSet, statusCode) if rule.MatchMode == model.MatchModeAll { - // "all" 模式:所有配置的条件都必须满足 - return codeMatch && keywordMatch + // "all" 模式:所有配置的条件都必须满足,短路 + if hasErrorCodes && !codeMatch { + return false + } + if hasKeywords { + return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords) + } + return codeMatch } - // "any" 模式:任一条件满足即可 + // "any" 模式:任一条件满足即可,短路 if hasErrorCodes && hasKeywords { - return codeMatch || keywordMatch + if codeMatch { + return true + } + return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords) } - return codeMatch && keywordMatch + // 只配置了一种条件 + if hasKeywords { + return s.containsAnyKeywordCached(ensureBodyLower(body, bodyLower, bodyLowerDone), rule.lowerKeywords) + } + return codeMatch } -// containsInt 检查切片是否包含指定整数 -func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool { - for _, v := range slice { - if v == val { - return true - } - } - return false -} - -// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写) -func (s *ErrorPassthroughService) 
containsAnyKeyword(bodyLower string, keywords []string) bool { - for _, kw := range keywords { - if strings.Contains(bodyLower, strings.ToLower(kw)) { +// containsIntSet 使用 map 查找替代线性扫描 +func (s *ErrorPassthroughService) containsIntSet(set map[int]struct{}, val int) bool { + _, ok := set[val] + return ok +} + +// containsAnyKeywordCached 使用预计算的小写关键词检查匹配 +func (s *ErrorPassthroughService) containsAnyKeywordCached(bodyLower string, lowerKeywords []string) bool { + for _, kw := range lowerKeywords { + if strings.Contains(bodyLower, kw) { return true } } diff --git a/backend/internal/service/error_passthrough_service_test.go b/backend/internal/service/error_passthrough_service_test.go index 74c98d86..96ddd637 100644 --- a/backend/internal/service/error_passthrough_service_test.go +++ b/backend/internal/service/error_passthrough_service_test.go @@ -145,32 +145,58 @@ func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughServic return svc } +// newCachedRuleForTest 从 model.ErrorPassthroughRule 创建 cachedPassthroughRule(测试用) +func newCachedRuleForTest(rule *model.ErrorPassthroughRule) *cachedPassthroughRule { + cr := &cachedPassthroughRule{ErrorPassthroughRule: rule} + if len(rule.Keywords) > 0 { + cr.lowerKeywords = make([]string, len(rule.Keywords)) + for j, kw := range rule.Keywords { + cr.lowerKeywords[j] = strings.ToLower(kw) + } + } + if len(rule.Platforms) > 0 { + cr.lowerPlatforms = make([]string, len(rule.Platforms)) + for j, p := range rule.Platforms { + cr.lowerPlatforms[j] = strings.ToLower(p) + } + } + if len(rule.ErrorCodes) > 0 { + cr.errorCodeSet = make(map[int]struct{}, len(rule.ErrorCodes)) + for _, code := range rule.ErrorCodes { + cr.errorCodeSet[code] = struct{}{} + } + } + return cr +} + // ============================================================================= -// 测试 ruleMatches 核心匹配逻辑 +// 测试 ruleMatchesOptimized 核心匹配逻辑 // ============================================================================= func 
TestRuleMatches_NoConditions(t *testing.T) { // 没有配置任何条件时,不应该匹配 svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{}, Keywords: []string{}, MatchMode: model.MatchModeAny, - } + }) - assert.False(t, svc.ruleMatches(rule, 422, "some error message"), + var bodyLower string + var bodyLowerDone bool + assert.False(t, svc.ruleMatchesOptimized(rule, 422, []byte("some error message"), &bodyLower, &bodyLowerDone), "没有配置条件时不应该匹配") } func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{422, 400}, Keywords: []string{}, MatchMode: model.MatchModeAny, - } + }) tests := []struct { name string @@ -186,7 +212,9 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := svc.ruleMatches(rule, tt.statusCode, tt.body) + var bodyLower string + var bodyLowerDone bool + result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone) assert.Equal(t, tt.expected, result) }) } @@ -194,12 +222,12 @@ func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{}, Keywords: []string{"context limit", "model not supported"}, MatchMode: model.MatchModeAny, - } + }) tests := []struct { name string @@ -210,16 +238,14 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { {"关键词匹配 context limit", 500, "error: context limit reached", true}, {"关键词匹配 model not supported", 400, "the model not supported here", true}, {"关键词不匹配", 422, "some other error", false}, - // 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的 - 
// 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches - {"关键词大小写 - 输入已小写", 500, "context limit exceeded", true}, + {"关键词大小写 - 自动转换", 500, "Context Limit exceeded", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 模拟 MatchRule 的行为:先转换为小写 - bodyLower := strings.ToLower(tt.body) - result := svc.ruleMatches(rule, tt.statusCode, bodyLower) + var bodyLower string + var bodyLowerDone bool + result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone) assert.Equal(t, tt.expected, result) }) } @@ -228,12 +254,12 @@ func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { // any 模式:错误码 OR 关键词 svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{422, 400}, Keywords: []string{"context limit"}, MatchMode: model.MatchModeAny, - } + }) tests := []struct { name string @@ -274,7 +300,9 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := svc.ruleMatches(rule, tt.statusCode, tt.body) + var bodyLower string + var bodyLowerDone bool + result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone) assert.Equal(t, tt.expected, result, tt.reason) }) } @@ -283,12 +311,12 @@ func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { func TestRuleMatches_BothConditions_AllMode(t *testing.T) { // all 模式:错误码 AND 关键词 svc := newTestService(nil) - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{422, 400}, Keywords: []string{"context limit"}, MatchMode: model.MatchModeAll, - } + }) tests := []struct { name string @@ -329,14 +357,16 @@ func TestRuleMatches_BothConditions_AllMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := 
svc.ruleMatches(rule, tt.statusCode, tt.body) + var bodyLower string + var bodyLowerDone bool + result := svc.ruleMatchesOptimized(rule, tt.statusCode, []byte(tt.body), &bodyLower, &bodyLowerDone) assert.Equal(t, tt.expected, result, tt.reason) }) } } // ============================================================================= -// 测试 platformMatches 平台匹配逻辑 +// 测试 platformMatchesCached 平台匹配逻辑 // ============================================================================= func TestPlatformMatches(t *testing.T) { @@ -394,10 +424,10 @@ func TestPlatformMatches(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - rule := &model.ErrorPassthroughRule{ + rule := newCachedRuleForTest(&model.ErrorPassthroughRule{ Platforms: tt.rulePlatforms, - } - result := svc.platformMatches(rule, tt.requestPlatform) + }) + result := svc.platformMatchesCached(rule, strings.ToLower(tt.requestPlatform)) assert.Equal(t, tt.expected, result) }) } diff --git a/backend/internal/service/error_policy_integration_test.go b/backend/internal/service/error_policy_integration_test.go index 9f8ad938..a8b42a2c 100644 --- a/backend/internal/service/error_policy_integration_test.go +++ b/backend/internal/service/error_policy_integration_test.go @@ -116,7 +116,7 @@ func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) { customCodes: []any{float64(500)}, expectHandleError: 0, expectUpstream: 1, - expectStatusCode: 429, + expectStatusCode: 500, }, { name: "500_in_custom_codes_matched", @@ -364,3 +364,109 @@ func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) { require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries") require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted") } + +// --------------------------------------------------------------------------- +// epTrackingRepo — records SetRateLimited / SetError calls for verification. 
+// --------------------------------------------------------------------------- + +type epTrackingRepo struct { + mockAccountRepoForGemini + rateLimitedCalls int + rateLimitedID int64 + setErrCalls int + setErrID int64 + tempCalls int +} + +func (r *epTrackingRepo) SetRateLimited(_ context.Context, id int64, _ time.Time) error { + r.rateLimitedCalls++ + r.rateLimitedID = id + return nil +} + +func (r *epTrackingRepo) SetError(_ context.Context, id int64, _ string) error { + r.setErrCalls++ + r.setErrID = id + return nil +} + +func (r *epTrackingRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error { + r.tempCalls++ + return nil +} + +// --------------------------------------------------------------------------- +// TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit +// +// 核心场景:自定义错误码设为 [599](一个不会真正出现的错误码), +// 当上游返回 429/500/503/401 时: +// - 返回给客户端的状态码必须是 500(而不是透传原始状态码) +// - 不调用 SetRateLimited(不进入限流状态) +// - 不调用 SetError(不停止调度) +// - 不调用 handleError +// --------------------------------------------------------------------------- + +func TestCustomErrorCode599_SkippedErrors_Return500_NoRateLimit(t *testing.T) { + errorCodes := []int{429, 500, 503, 401, 403} + + for _, upstreamStatus := range errorCodes { + t.Run(http.StatusText(upstreamStatus), func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{ + statusCode: upstreamStatus, + body: `{"error":"some upstream error"}`, + } + repo := &epTrackingRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := &Account{ + ID: 500, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(599)}, + }, + } + + var handleErrorCount int + p := newRetryParams(account, upstream, func(_ context.Context, _ 
string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + // 不应返回 error(Skipped 不触发账号切换) + require.NoError(t, err, "should not return error") + require.NotNil(t, result, "result should not be nil") + require.NotNil(t, result.resp, "response should not be nil") + defer func() { _ = result.resp.Body.Close() }() + + // 状态码必须是 500(不透传原始状态码) + require.Equal(t, http.StatusInternalServerError, result.resp.StatusCode, + "skipped error should return 500, not %d", upstreamStatus) + + // 不调用 handleError + require.Equal(t, 0, handleErrorCount, + "handleError should NOT be called for skipped errors") + + // 不标记限流 + require.Equal(t, 0, repo.rateLimitedCalls, + "SetRateLimited should NOT be called for skipped errors") + + // 不停止调度 + require.Equal(t, 0, repo.setErrCalls, + "SetError should NOT be called for skipped errors") + + // 只调用一次上游(不重试) + require.Equal(t, 1, upstream.calls, + "should call upstream exactly once (no retry)") + }) + } +} diff --git a/backend/internal/service/error_policy_test.go b/backend/internal/service/error_policy_test.go index a8b69c22..9d7d025e 100644 --- a/backend/internal/service/error_policy_test.go +++ b/backend/internal/service/error_policy_test.go @@ -158,6 +158,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode int body []byte expectedHandled bool + expectedStatus int // expected outStatus expectedSwitchErr bool // expect *AntigravityAccountSwitchError handleErrorCalls int }{ @@ -171,6 +172,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode: 500, body: []byte(`"error"`), expectedHandled: false, + expectedStatus: 500, // passthrough handleErrorCalls: 0, }, { @@ -187,6 +189,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode: 500, // not in custom codes body: []byte(`"error"`), expectedHandled: true, + expectedStatus: http.StatusInternalServerError, // skipped → 500 
handleErrorCalls: 0, }, { @@ -203,6 +206,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode: 500, body: []byte(`"error"`), expectedHandled: true, + expectedStatus: 500, // matched → original status handleErrorCalls: 1, }, { @@ -225,6 +229,7 @@ func TestApplyErrorPolicy(t *testing.T) { statusCode: 503, body: []byte(`overloaded`), expectedHandled: true, + expectedStatus: 503, // temp_unscheduled → original status expectedSwitchErr: true, handleErrorCalls: 0, }, @@ -250,9 +255,10 @@ func TestApplyErrorPolicy(t *testing.T) { isStickySession: true, } - handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body) + handled, outStatus, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body) require.Equal(t, tt.expectedHandled, handled, "handled mismatch") + require.Equal(t, tt.expectedStatus, outStatus, "outStatus mismatch") require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch") if tt.expectedSwitchErr { diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go index dd58c183..d7108c8d 100644 --- a/backend/internal/service/gateway_beta_test.go +++ b/backend/internal/service/gateway_beta_test.go @@ -21,3 +21,72 @@ func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) { ) require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got) } + +func TestStripBetaToken(t *testing.T) { + tests := []struct { + name string + header string + token string + want string + }{ + { + name: "token in middle", + header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "token at start", + header: "context-1m-2025-08-07,oauth-2025-04-20,interleaved-thinking-2025-05-14", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "token at end", + header: 
"oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "token not present", + header: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "empty header", + header: "", + token: "context-1m-2025-08-07", + want: "", + }, + { + name: "with spaces", + header: "oauth-2025-04-20, context-1m-2025-08-07 , interleaved-thinking-2025-05-14", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "only token", + header: "context-1m-2025-08-07", + token: "context-1m-2025-08-07", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripBetaToken(tt.header, tt.token) + require.Equal(t, tt.want, got) + }) + } +} + +func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) { + required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"} + incoming := "context-1m-2025-08-07,foo-beta,oauth-2025-04-20" + drop := map[string]struct{}{"context-1m-2025-08-07": {}} + + got := mergeAnthropicBetaDropping(required, incoming, drop) + require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got) + require.NotContains(t, got, "context-1m-2025-08-07") +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index b4b93ace..09fda60e 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -87,7 +87,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m 
*mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index c039f030..743dd738 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -101,9 +101,9 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { } } - // thinking: {type: "enabled"} + // thinking: {type: "enabled" | "adaptive"} if rawThinking, ok := req["thinking"].(map[string]any); ok { - if t, ok := rawThinking["type"].(string); ok && t == "enabled" { + if t, ok := rawThinking["type"].(string); ok && (t == "enabled" || t == "adaptive") { parsed.ThinkingEnabled = true } } @@ -161,9 +161,9 @@ func parseIntegralNumber(raw any) (int, bool) { // Returns filtered body or original body if filtering fails (fail-safe) // This prevents 400 errors from invalid thinking block signatures // -// Strategy: -// - When thinking.type != "enabled": Remove all thinking blocks -// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures +// 策略: +// - 当 thinking.type 不是 "enabled"/"adaptive":移除所有 thinking 相关块 +// - 当 thinking.type 是 "enabled"/"adaptive":仅移除缺失/无效 signature 的 thinking 块(避免 400) // (blocks with missing/empty/dummy signatures that would cause 400 errors) func FilterThinkingBlocks(body []byte) []byte { return filterThinkingBlocksInternal(body, false) @@ -489,9 +489,9 @@ func 
FilterSignatureSensitiveBlocksForRetry(body []byte) []byte { } // filterThinkingBlocksInternal removes invalid thinking blocks from request -// Strategy: -// - When thinking.type != "enabled": Remove all thinking blocks -// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures +// 策略: +// - 当 thinking.type 不是 "enabled"/"adaptive":移除所有 thinking 相关块 +// - 当 thinking.type 是 "enabled"/"adaptive":仅移除缺失/无效 signature 的 thinking 块 func filterThinkingBlocksInternal(body []byte, _ bool) []byte { // Fast path: if body doesn't contain "thinking", skip parsing if !bytes.Contains(body, []byte(`"type":"thinking"`)) && @@ -511,7 +511,7 @@ func filterThinkingBlocksInternal(body []byte, _ bool) []byte { // Check if thinking is enabled thinkingEnabled := false if thinking, ok := req["thinking"].(map[string]any); ok { - if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" { + if thinkType, ok := thinking["type"].(string); ok && (thinkType == "enabled" || thinkType == "adaptive") { thinkingEnabled = true } } diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index cef41c91..5b85e752 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -29,6 +29,14 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { require.True(t, parsed.ThinkingEnabled) } +func TestParseGatewayRequest_ThinkingAdaptiveEnabled(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"adaptive"},"messages":[{"content":"hi"}]}`) + parsed, err := ParseGatewayRequest(body, "") + require.NoError(t, err) + require.Equal(t, "claude-sonnet-4-5", parsed.Model) + require.True(t, parsed.ThinkingEnabled) +} + func TestParseGatewayRequest_MaxTokens(t *testing.T) { body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`) parsed, err := ParseGatewayRequest(body, "") @@ -209,6 +217,16 @@ func 
TestFilterThinkingBlocks(t *testing.T) { input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"},{"type":"thinking","thinking":"internal","signature":"invalid"},{"type":"text","text":"World"}]}]}`, shouldFilter: true, }, + { + name: "does not filter signed thinking blocks when thinking adaptive", + input: `{"thinking":{"type":"adaptive"},"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"ok","signature":"sig_real_123"},{"type":"text","text":"B"}]}]}`, + shouldFilter: false, + }, + { + name: "filters unsigned thinking blocks when thinking adaptive", + input: `{"thinking":{"type":"adaptive"},"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"internal","signature":""},{"type":"text","text":"B"}]}]}`, + shouldFilter: true, + }, { name: "handles no thinking blocks", input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`, diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 4e723232..4d1dbad0 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -243,6 +243,12 @@ var ( } ) +// systemBlockFilterPrefixes 需要从 system 中过滤的文本前缀列表 +// OAuth/SetupToken 账号转发时,匹配这些前缀的 system 元素会被移除 +var systemBlockFilterPrefixes = []string{ + "x-anthropic-billing-header", +} + // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") @@ -343,6 +349,8 @@ type ClaudeUsage struct { OutputTokens int `json:"output_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"` + CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象) + CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象) } // ForwardResult 转发结果 @@ -362,15 +370,31 @@ type ForwardResult struct { // 
UpstreamFailoverError indicates an upstream error that should trigger account failover. type UpstreamFailoverError struct { - StatusCode int - ResponseBody []byte // 上游响应体,用于错误透传规则匹配 - ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true + StatusCode int + ResponseBody []byte // 上游响应体,用于错误透传规则匹配 + ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true + RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 } func (e *UpstreamFailoverError) Error() string { return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode) } +// TempUnscheduleRetryableError 对 RetryableOnSameAccount 类型的 failover 错误触发临时封禁。 +// 由 handler 层在同账号重试全部用尽、切换账号时调用。 +func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *UpstreamFailoverError) { + if failoverErr == nil || !failoverErr.RetryableOnSameAccount { + return + } + // 根据状态码选择封禁策略 + switch failoverErr.StatusCode { + case http.StatusBadRequest: + tempUnscheduleGoogleConfigError(ctx, s.accountRepo, accountID, "[handler]") + case http.StatusBadGateway: + tempUnscheduleEmptyResponse(ctx, s.accountRepo, accountID, "[handler]") + } +} + // GatewayService handles API gateway operations type GatewayService struct { accountRepo AccountRepository @@ -1683,6 +1707,17 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i return accounts, useMixed, nil } +// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 +// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, +// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 +func (s *GatewayService) IsSingleAntigravityAccountGroup(ctx context.Context, groupID *int64) bool { + accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAntigravity, true) + if err != nil { + return false + } + return len(accounts) == 1 +} + func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool { if account == nil { return false @@ -2673,6 +2708,60 @@ func 
hasClaudeCodePrefix(text string) bool { return false } +// matchesFilterPrefix 检查文本是否匹配任一过滤前缀 +func matchesFilterPrefix(text string) bool { + for _, prefix := range systemBlockFilterPrefixes { + if strings.HasPrefix(text, prefix) { + return true + } + } + return false +} + +// filterSystemBlocksByPrefix 从 body 的 system 中移除文本匹配 systemBlockFilterPrefixes 前缀的元素 +// 直接从 body 解析 system,不依赖外部传入的 parsed.System(因为前置步骤可能已修改 body 中的 system) +func filterSystemBlocksByPrefix(body []byte) []byte { + sys := gjson.GetBytes(body, "system") + if !sys.Exists() { + return body + } + + switch { + case sys.Type == gjson.String: + if matchesFilterPrefix(sys.Str) { + result, err := sjson.DeleteBytes(body, "system") + if err != nil { + return body + } + return result + } + case sys.IsArray(): + var parsed []any + if err := json.Unmarshal([]byte(sys.Raw), &parsed); err != nil { + return body + } + filtered := make([]any, 0, len(parsed)) + changed := false + for _, item := range parsed { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && matchesFilterPrefix(text) { + changed = true + continue + } + } + filtered = append(filtered, item) + } + if changed { + result, err := sjson.SetBytes(body, "system", filtered) + if err != nil { + return body + } + return result + } + } + return body +} + // injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 // 处理 null、字符串、数组三种格式 func injectClaudeCodePrompt(body []byte, system any) []byte { @@ -2952,6 +3041,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } + // OAuth/SetupToken 账号:移除黑名单前缀匹配的 system 元素(如客户端注入的计费元数据) + // 放在 inject/normalize 之后,确保不会被覆盖 + if account.IsOAuth() { + body = filterSystemBlocksByPrefix(body) + } + // 强制执行 cache_control 块数量限制(最多 4 个) body = enforceCacheControlLimit(body) @@ -3458,12 +3553,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 
messages requests typically use only oauth + interleaved-thinking. // Also drop claude-code beta if a downstream client added it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - drop := map[string]struct{}{claude.BetaClaudeCode: {}} + drop := map[string]struct{}{claude.BetaClaudeCode: {}, claude.BetaContext1M: {}} req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := req.Header.Get("anthropic-beta") - req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader)) + req.Header.Set("anthropic-beta", stripBetaToken(s.getBetaHeader(modelID, clientBetaHeader), claude.BetaContext1M)) } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) @@ -3538,7 +3633,8 @@ func requestNeedsBetaFeatures(body []byte) bool { if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 { return true } - if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") { + thinkingType := gjson.GetBytes(body, "thinking.type").String() + if strings.EqualFold(thinkingType, "enabled") || strings.EqualFold(thinkingType, "adaptive") { return true } return false @@ -3616,6 +3712,23 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str return strings.Join(out, ",") } +// stripBetaToken removes a single beta token from a comma-separated header value. +// It short-circuits when the token is not present to avoid unnecessary allocations. 
+func stripBetaToken(header, token string) string { + if !strings.Contains(header, token) { + return header + } + out := make([]string, 0, 8) + for _, p := range strings.Split(header, ",") { + p = strings.TrimSpace(p) + if p == "" || p == token { + continue + } + out = append(out, p) + } + return strings.Join(out, ",") +} + // applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. // This mirrors opencode-anthropic-auth behavior: do not trust downstream // headers when using Claude Code-scoped OAuth credentials. @@ -4180,6 +4293,23 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } } + // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 + if account.IsCacheTTLOverrideEnabled() { + overrideTarget := account.GetCacheTTLOverrideTarget() + if eventType == "message_start" { + if msg, ok := event["message"].(map[string]any); ok { + if u, ok := msg["usage"].(map[string]any); ok { + rewriteCacheCreationJSON(u, overrideTarget) + } + } + } + if eventType == "message_delta" { + if u, ok := event["usage"].(map[string]any); ok { + rewriteCacheCreationJSON(u, overrideTarget) + } + } + } + if needModelReplace { if msg, ok := event["message"].(map[string]any); ok { if model, ok := msg["model"].(string); ok && model == mappedModel { @@ -4307,6 +4437,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { usage.InputTokens = msgStart.Message.Usage.InputTokens usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens + + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens") + cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() || cc1h.Exists() { + usage.CacheCreation5mTokens = int(cc5m.Int()) + usage.CacheCreation1hTokens = int(cc1h.Int()) + } } // 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API) 
@@ -4335,6 +4473,66 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { if msgDelta.Usage.CacheReadInputTokens > 0 { usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens } + + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens") + cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() || cc1h.Exists() { + usage.CacheCreation5mTokens = int(cc5m.Int()) + usage.CacheCreation1hTokens = int(cc1h.Int()) + } + } +} + +// applyCacheTTLOverride 将所有 cache creation tokens 归入指定的 TTL 类型。 +// target 为 "5m" 或 "1h"。返回 true 表示发生了变更。 +func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool { + // Fallback: 如果只有聚合字段但无 5m/1h 明细,将聚合字段归入 5m 默认类别 + if usage.CacheCreation5mTokens == 0 && usage.CacheCreation1hTokens == 0 && usage.CacheCreationInputTokens > 0 { + usage.CacheCreation5mTokens = usage.CacheCreationInputTokens + } + + total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens + if total == 0 { + return false + } + switch target { + case "1h": + if usage.CacheCreation1hTokens == total { + return false // 已经全是 1h + } + usage.CacheCreation1hTokens = total + usage.CacheCreation5mTokens = 0 + default: // "5m" + if usage.CacheCreation5mTokens == total { + return false // 已经全是 5m + } + usage.CacheCreation5mTokens = total + usage.CacheCreation1hTokens = 0 + } + return true +} + +// rewriteCacheCreationJSON 在 JSON usage 对象中重写 cache_creation 嵌套对象的 TTL 分类。 +// usageObj 是 usage JSON 对象(map[string]any)。 +func rewriteCacheCreationJSON(usageObj map[string]any, target string) { + ccObj, ok := usageObj["cache_creation"].(map[string]any) + if !ok { + return + } + v5m, _ := ccObj["ephemeral_5m_input_tokens"].(float64) + v1h, _ := ccObj["ephemeral_1h_input_tokens"].(float64) + total := v5m + v1h + if total == 0 { + return + } + switch target { + case "1h": + ccObj["ephemeral_1h_input_tokens"] = total + ccObj["ephemeral_5m_input_tokens"] = 
float64(0) + default: // "5m" + ccObj["ephemeral_5m_input_tokens"] = total + ccObj["ephemeral_1h_input_tokens"] = float64(0) } } @@ -4355,6 +4553,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h return nil, fmt.Errorf("parse response: %w", err) } + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens") + cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() || cc1h.Exists() { + response.Usage.CacheCreation5mTokens = int(cc5m.Int()) + response.Usage.CacheCreation1hTokens = int(cc1h.Int()) + } + // 兼容 Kimi cached_tokens → cache_read_input_tokens if response.Usage.CacheReadInputTokens == 0 { cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() @@ -4366,6 +4572,20 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } + // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 + if account.IsCacheTTLOverrideEnabled() { + overrideTarget := account.GetCacheTTLOverrideTarget() + if applyCacheTTLOverride(&response.Usage, overrideTarget) { + // 同步更新 body JSON 中的嵌套 cache_creation 对象 + if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil { + body = newBody + } + if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens", response.Usage.CacheCreation1hTokens); err == nil { + body = newBody + } + } + } + // 如果有模型映射,替换响应中的model字段 if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) @@ -4442,6 +4662,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu result.Usage.InputTokens = 0 } + // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + cacheTTLOverridden := false + if account.IsCacheTTLOverrideEnabled() { + applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) + 
cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 + } + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { @@ -4472,10 +4699,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } else { // Token 计费 tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) @@ -4509,6 +4738,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, InputCost: cost.InputCost, OutputCost: cost.OutputCost, CacheCreationCost: cost.CacheCreationCost, @@ -4523,6 +4754,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: imageSize, + CacheTTLOverridden: cacheTTLOverridden, CreatedAt: time.Now(), } @@ -4623,6 +4855,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * result.Usage.InputTokens = 0 } + // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + cacheTTLOverridden := false + if account.IsCacheTTLOverrideEnabled() { + 
applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) + cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 + } + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { @@ -4653,10 +4892,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } else { // Token 计费(使用长上下文计费方法) tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) @@ -4690,6 +4931,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, InputCost: cost.InputCost, OutputCost: cost.OutputCost, CacheCreationCost: cost.CacheCreationCost, @@ -4704,6 +4947,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: imageSize, + CacheTTLOverridden: cacheTTLOverridden, CreatedAt: time.Now(), } @@ -5009,7 +5253,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con 
incomingBeta := req.Header.Get("anthropic-beta") requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} - req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta)) + drop := map[string]struct{}{claude.BetaContext1M: {}} + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) } else { clientBetaHeader := req.Header.Get("anthropic-beta") if clientBetaHeader == "" { @@ -5019,7 +5264,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if !strings.Contains(beta, claude.BetaTokenCounting) { beta = beta + "," + claude.BetaTokenCounting } - req.Header.Set("anthropic-beta", beta) + req.Header.Set("anthropic-beta", stripBetaToken(beta, claude.BetaContext1M)) } } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index d77f6f92..f3abd1dc 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -770,6 +770,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex break } + // 错误策略优先:匹配则跳过重试直接处理。 + if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched { + resp = rebuilt + break + } else { + resp = rebuilt + } + if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() @@ -839,7 +847,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex if upstreamReqID == "" { upstreamReqID = resp.Header.Get("x-goog-request-id") } - return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody) + return nil, s.writeGeminiMappedError(c, 
account, http.StatusInternalServerError, upstreamReqID, respBody) case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) upstreamReqID := resp.Header.Get(requestIDHeader) @@ -872,6 +880,37 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex // ErrorPolicyNone → 原有逻辑 s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + // 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁 + if resp.StatusCode == http.StatusBadRequest { + msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if isGoogleProjectConfigError(msg400) { + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: true} + } + } if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { upstreamReqID := resp.Header.Get(requestIDHeader) if upstreamReqID == "" { @@ -1176,6 +1215,14 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. 
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries: "+safeErr) } + // 错误策略优先:匹配则跳过重试直接处理。 + if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched { + resp = rebuilt + break + } else { + resp = rebuilt + } + if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() @@ -1283,7 +1330,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. if contentType == "" { contentType = "application/json" } - c.Data(resp.StatusCode, contentType, respBody) + c.Data(http.StatusInternalServerError, contentType, respBody) return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode) case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) @@ -1314,6 +1361,34 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. 
// ErrorPolicyNone → 原有逻辑 s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + // 精确匹配服务端配置类 400 错误,触发 failover + 临时封禁 + if resp.StatusCode == http.StatusBadRequest { + msg400 := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if isGoogleProjectConfigError(msg400) { + evBody := unwrapIfNeeded(isOAuth, respBody) + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(evBody))) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(evBody), maxBytes) + } + log.Printf("[Gemini] status=400 google_config_error failover=true upstream_message=%q account=%d", upstreamMsg, account.ID) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody, RetryableOnSameAccount: true} + } + } if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { evBody := unwrapIfNeeded(isOAuth, respBody) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) @@ -1425,6 +1500,26 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. 
}, nil } +// checkErrorPolicyInLoop 在重试循环内预检查错误策略。 +// 返回 true 表示策略已匹配(调用者应 break),resp 已重建可直接使用。 +// 返回 false 表示 ErrorPolicyNone,resp 已重建,调用者继续走重试逻辑。 +func (s *GeminiMessagesCompatService) checkErrorPolicyInLoop( + ctx context.Context, account *Account, resp *http.Response, +) (matched bool, rebuilt *http.Response) { + if resp.StatusCode < 400 || s.rateLimitService == nil { + return false, resp + } + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + rebuilt = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(body)), + } + policy := s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, body) + return policy != ErrorPolicyNone, rebuilt +} + func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *Account, statusCode int) bool { switch statusCode { case 429, 500, 502, 503, 504, 529: @@ -2568,11 +2663,12 @@ func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage { prompt, _ := asInt(usageMeta["promptTokenCount"]) cand, _ := asInt(usageMeta["candidatesTokenCount"]) cached, _ := asInt(usageMeta["cachedContentTokenCount"]) + thoughts, _ := asInt(usageMeta["thoughtsTokenCount"]) // 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount, // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 return &ClaudeUsage{ InputTokens: prompt - cached, - OutputTokens: cand, + OutputTokens: cand + thoughts, CacheReadInputTokens: cached, } } @@ -2597,6 +2693,10 @@ func asInt(v any) (int, bool) { } func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) { + // 遵守自定义错误码策略:未命中则跳过所有限流处理 + if !account.ShouldHandleErrorCode(statusCode) { + return + } if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) { s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) return diff --git 
a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go index f31b40ec..5bc26973 100644 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -4,6 +4,8 @@ import ( "encoding/json" "strings" "testing" + + "github.com/stretchr/testify/require" ) // TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 @@ -203,3 +205,70 @@ func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s) } } + +func TestExtractGeminiUsage_ThoughtsTokenCount(t *testing.T) { + tests := []struct { + name string + resp map[string]any + wantInput int + wantOutput int + wantCacheRead int + wantNil bool + }{ + { + name: "with thoughtsTokenCount", + resp: map[string]any{ + "usageMetadata": map[string]any{ + "promptTokenCount": float64(100), + "candidatesTokenCount": float64(20), + "thoughtsTokenCount": float64(50), + }, + }, + wantInput: 100, + wantOutput: 70, + }, + { + name: "with thoughtsTokenCount and cache", + resp: map[string]any{ + "usageMetadata": map[string]any{ + "promptTokenCount": float64(100), + "candidatesTokenCount": float64(20), + "cachedContentTokenCount": float64(30), + "thoughtsTokenCount": float64(50), + }, + }, + wantInput: 70, + wantOutput: 70, + wantCacheRead: 30, + }, + { + name: "without thoughtsTokenCount (old model)", + resp: map[string]any{ + "usageMetadata": map[string]any{ + "promptTokenCount": float64(100), + "candidatesTokenCount": float64(20), + }, + }, + wantInput: 100, + wantOutput: 20, + }, + { + name: "no usageMetadata", + resp: map[string]any{}, + wantNil: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + usage := extractGeminiUsage(tt.resp) + if tt.wantNil { + require.Nil(t, usage) + return + } + require.NotNil(t, usage) + require.Equal(t, 
tt.wantInput, usage.InputTokens) + require.Equal(t, tt.wantOutput, usage.OutputTokens) + require.Equal(t, tt.wantCacheRead, usage.CacheReadInputTokens) + }) + } +} diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 080352ba..6b1fcecc 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -74,7 +74,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index cea81693..a57f0f99 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -112,13 +112,19 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran result.Modified = true } - if _, ok := reqBody["max_output_tokens"]; ok { - delete(reqBody, "max_output_tokens") - result.Modified = true - } - if _, ok := reqBody["max_completion_tokens"]; ok { - delete(reqBody, "max_completion_tokens") - result.Modified = true + // Strip parameters unsupported by codex models via the Responses API. 
+ for _, key := range []string{ + "max_output_tokens", + "max_completion_tokens", + "temperature", + "top_p", + "frequency_penalty", + "presence_penalty", + } { + if _, ok := reqBody[key]; ok { + delete(reqBody, key) + result.Modified = true + } } if normalizeCodexTools(reqBody) { diff --git a/backend/internal/service/ops_concurrency.go b/backend/internal/service/ops_concurrency.go index f6541d08..92b37e73 100644 --- a/backend/internal/service/ops_concurrency.go +++ b/backend/internal/service/ops_concurrency.go @@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{ Page: page, PageSize: opsAccountsPageSize, - }, platformFilter, "", "", "") + }, platformFilter, "", "", "", 0) if err != nil { return nil, err } diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go index 96bcc9fe..3514df79 100644 --- a/backend/internal/service/ops_upstream_context.go +++ b/backend/internal/service/ops_upstream_context.go @@ -20,6 +20,10 @@ const ( // retry the specific upstream attempt (not just the client request). // This value is sanitized+trimmed before being persisted. 
OpsUpstreamRequestBodyKey = "ops_upstream_request_body" + + // OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。 + // ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。 + OpsSkipPassthroughKey = "ops_skip_passthrough" ) func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) { @@ -103,6 +107,37 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) { evCopy := ev existing = append(existing, &evCopy) c.Set(OpsUpstreamErrorsKey, existing) + + checkSkipMonitoringForUpstreamEvent(c, &evCopy) +} + +// checkSkipMonitoringForUpstreamEvent checks whether the upstream error event +// matches a passthrough rule with skip_monitoring=true and, if so, sets the +// OpsSkipPassthroughKey on the context. This ensures intermediate retry / +// failover errors (which never go through the final applyErrorPassthroughRule +// path) can still suppress ops_error_logs recording. +func checkSkipMonitoringForUpstreamEvent(c *gin.Context, ev *OpsUpstreamErrorEvent) { + if ev.UpstreamStatusCode == 0 { + return + } + + svc := getBoundErrorPassthroughService(c) + if svc == nil { + return + } + + // Use the best available body representation for keyword matching. + // Even when body is empty, MatchRule can still match rules that only + // specify ErrorCodes (no Keywords), so we always call it. 
+ body := ev.Detail + if body == "" { + body = ev.Message + } + + rule := svc.MatchRule(ev.Platform, ev.UpstreamStatusCode, []byte(body)) + if rule != nil && rule.SkipMonitoring { + c.Set(OpsSkipPassthroughKey, true) + } } func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string { diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index d8db0d67..a3a94189 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -27,14 +27,15 @@ var ( // LiteLLMModelPricing LiteLLM价格数据结构 // 只保留我们需要的字段,使用指针来处理可能缺失的值 type LiteLLMModelPricing struct { - InputCostPerToken float64 `json:"input_cost_per_token"` - OutputCostPerToken float64 `json:"output_cost_per_token"` - CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` - CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` - LiteLLMProvider string `json:"litellm_provider"` - Mode string `json:"mode"` - SupportsPromptCaching bool `json:"supports_prompt_caching"` - OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 + InputCostPerToken float64 `json:"input_cost_per_token"` + OutputCostPerToken float64 `json:"output_cost_per_token"` + CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` + CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"` + CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` + LiteLLMProvider string `json:"litellm_provider"` + Mode string `json:"mode"` + SupportsPromptCaching bool `json:"supports_prompt_caching"` + OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 } // PricingRemoteClient 远程价格数据获取接口 @@ -45,14 +46,15 @@ type PricingRemoteClient interface { // LiteLLMRawEntry 用于解析原始JSON数据 type LiteLLMRawEntry struct { - InputCostPerToken *float64 `json:"input_cost_per_token"` - OutputCostPerToken *float64 `json:"output_cost_per_token"` - 
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` - CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` - LiteLLMProvider string `json:"litellm_provider"` - Mode string `json:"mode"` - SupportsPromptCaching bool `json:"supports_prompt_caching"` - OutputCostPerImage *float64 `json:"output_cost_per_image"` + InputCostPerToken *float64 `json:"input_cost_per_token"` + OutputCostPerToken *float64 `json:"output_cost_per_token"` + CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` + CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"` + CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` + LiteLLMProvider string `json:"litellm_provider"` + Mode string `json:"mode"` + SupportsPromptCaching bool `json:"supports_prompt_caching"` + OutputCostPerImage *float64 `json:"output_cost_per_image"` } // PricingService 动态价格服务 @@ -318,6 +320,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel if entry.CacheCreationInputTokenCost != nil { pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost } + if entry.CacheCreationInputTokenCostAbove1hr != nil { + pricing.CacheCreationInputTokenCostAbove1hr = *entry.CacheCreationInputTokenCostAbove1hr + } if entry.CacheReadInputTokenCost != nil { pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost } diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 63732dee..b1d767fc 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -381,10 +381,31 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head } } - // 2. 尝试从响应头解析重置时间(Anthropic) + // 2. 
Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口 + if result := calculateAnthropic429ResetTime(headers); result != nil { + if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil { + slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) + return + } + + // 更新 session window:优先使用 5h-reset 头精确计算,否则从 resetAt 反推 + windowEnd := result.resetAt + if result.fiveHourReset != nil { + windowEnd = *result.fiveHourReset + } + windowStart := windowEnd.Add(-5 * time.Hour) + if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil { + slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err) + } + + slog.Info("anthropic_account_rate_limited", "account_id", account.ID, "reset_at", result.resetAt, "reset_in", time.Until(result.resetAt).Truncate(time.Second)) + return + } + + // 3. 尝试从响应头解析重置时间(Anthropic 聚合头,向后兼容) resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset") - // 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini) + // 4. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini) if resetTimestamp == "" { switch account.Platform { case PlatformOpenAI: @@ -497,6 +518,112 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim return nil } +// anthropic429Result holds the parsed Anthropic 429 rate-limit information. +type anthropic429Result struct { + resetAt time.Time // The correct reset time to use for SetRateLimited + fiveHourReset *time.Time // 5h window reset timestamp (for session window calculation), nil if not available +} + +// calculateAnthropic429ResetTime parses Anthropic's per-window rate-limit headers +// to determine which window (5h or 7d) actually triggered the 429. 
+// +// Headers used: +// - anthropic-ratelimit-unified-5h-utilization / anthropic-ratelimit-unified-5h-surpassed-threshold +// - anthropic-ratelimit-unified-5h-reset +// - anthropic-ratelimit-unified-7d-utilization / anthropic-ratelimit-unified-7d-surpassed-threshold +// - anthropic-ratelimit-unified-7d-reset +// +// Returns nil when the per-window headers are absent (caller should fall back to +// the aggregated anthropic-ratelimit-unified-reset header). +func calculateAnthropic429ResetTime(headers http.Header) *anthropic429Result { + reset5hStr := headers.Get("anthropic-ratelimit-unified-5h-reset") + reset7dStr := headers.Get("anthropic-ratelimit-unified-7d-reset") + + if reset5hStr == "" && reset7dStr == "" { + return nil + } + + var reset5h, reset7d *time.Time + if ts, err := strconv.ParseInt(reset5hStr, 10, 64); err == nil { + t := time.Unix(ts, 0) + reset5h = &t + } + if ts, err := strconv.ParseInt(reset7dStr, 10, 64); err == nil { + t := time.Unix(ts, 0) + reset7d = &t + } + + is5hExceeded := isAnthropicWindowExceeded(headers, "5h") + is7dExceeded := isAnthropicWindowExceeded(headers, "7d") + + slog.Info("anthropic_429_window_analysis", + "is_5h_exceeded", is5hExceeded, + "is_7d_exceeded", is7dExceeded, + "reset_5h", reset5hStr, + "reset_7d", reset7dStr, + ) + + // Select the correct reset time based on which window(s) are exceeded. + var chosen *time.Time + switch { + case is5hExceeded && is7dExceeded: + // Both exceeded → prefer 7d (longer cooldown), fall back to 5h + chosen = reset7d + if chosen == nil { + chosen = reset5h + } + case is5hExceeded: + chosen = reset5h + case is7dExceeded: + chosen = reset7d + default: + // Neither flag clearly exceeded — pick the sooner reset as best guess + chosen = pickSooner(reset5h, reset7d) + } + + if chosen == nil { + return nil + } + return &anthropic429Result{resetAt: *chosen, fiveHourReset: reset5h} +} + +// isAnthropicWindowExceeded checks whether a given Anthropic rate-limit window +// (e.g. 
"5h" or "7d") has been exceeded, using utilization and surpassed-threshold headers. +func isAnthropicWindowExceeded(headers http.Header, window string) bool { + prefix := "anthropic-ratelimit-unified-" + window + "-" + + // Check surpassed-threshold first (most explicit signal) + if st := headers.Get(prefix + "surpassed-threshold"); strings.EqualFold(st, "true") { + return true + } + + // Fall back to utilization >= 1.0 + if utilStr := headers.Get(prefix + "utilization"); utilStr != "" { + if util, err := strconv.ParseFloat(utilStr, 64); err == nil && util >= 1.0-1e-9 { + // Use a small epsilon to handle floating point: treat 0.9999999... as >= 1.0 + return true + } + } + + return false +} + +// pickSooner returns whichever of the two time pointers is earlier. +// If only one is non-nil, it is returned. If both are nil, returns nil. +func pickSooner(a, b *time.Time) *time.Time { + switch { + case a != nil && b != nil: + if a.Before(*b) { + return a + } + return b + case a != nil: + return a + default: + return b + } +} + // parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳 // OpenAI 的 usage_limit_reached 错误格式: // @@ -623,6 +750,10 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err) } } + // 同时清除模型级别限流 + if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil { + slog.Warn("clear_model_rate_limits_on_temp_unsched_reset_failed", "account_id", accountID, "error", err) + } return nil } diff --git a/backend/internal/service/ratelimit_service_anthropic_test.go b/backend/internal/service/ratelimit_service_anthropic_test.go new file mode 100644 index 00000000..eaeaf30e --- /dev/null +++ b/backend/internal/service/ratelimit_service_anthropic_test.go @@ -0,0 +1,202 @@ +package service + +import ( + "net/http" + "testing" + "time" +) + +func TestCalculateAnthropic429ResetTime_Only5hExceeded(t *testing.T) { + headers := 
http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.02") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.32") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) + + if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) { + t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset) + } +} + +func TestCalculateAnthropic429ResetTime_Only7dExceeded(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.50") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.05") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1771549200) + + // fiveHourReset should still be populated for session window calculation + if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) { + t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset) + } +} + +func TestCalculateAnthropic429ResetTime_BothExceeded(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.10") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.02") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1771549200) +} + +func TestCalculateAnthropic429ResetTime_NoPerWindowHeaders(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + if 
result != nil { + t.Errorf("expected nil result when no per-window headers, got resetAt=%v", result.resetAt) + } +} + +func TestCalculateAnthropic429ResetTime_NoHeaders(t *testing.T) { + result := calculateAnthropic429ResetTime(http.Header{}) + if result != nil { + t.Errorf("expected nil result for empty headers, got resetAt=%v", result.resetAt) + } +} + +func TestCalculateAnthropic429ResetTime_SurpassedThreshold(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "false") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_UtilizationExactlyOne(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.0") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.5") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_NeitherExceeded_UsesShorter(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.95") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") // sooner + headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.80") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") // later + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_Only5hResetHeader(t *testing.T) { + headers := http.Header{} + 
headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.05") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_Only7dResetHeader(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.03") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1771549200) + + if result.fiveHourReset != nil { + t.Errorf("expected fiveHourReset=nil when no 5h headers, got %v", result.fiveHourReset) + } +} + +func TestIsAnthropicWindowExceeded(t *testing.T) { + tests := []struct { + name string + headers http.Header + window string + expected bool + }{ + { + name: "utilization above 1.0", + headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.02"), + window: "5h", + expected: true, + }, + { + name: "utilization exactly 1.0", + headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.0"), + window: "5h", + expected: true, + }, + { + name: "utilization below 1.0", + headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "0.99"), + window: "5h", + expected: false, + }, + { + name: "surpassed-threshold true", + headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "true"), + window: "7d", + expected: true, + }, + { + name: "surpassed-threshold True (case insensitive)", + headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "True"), + window: "7d", + expected: true, + }, + { + name: "surpassed-threshold false", + headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "false"), + window: "7d", + expected: false, + }, + { + name: "no headers", + headers: http.Header{}, + window: "5h", + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) 
{ + got := isAnthropicWindowExceeded(tc.headers, tc.window) + if got != tc.expected { + t.Errorf("expected %v, got %v", tc.expected, got) + } + }) + } +} + +// assertAnthropicResult is a test helper that verifies the result is non-nil and +// has the expected resetAt unix timestamp. +func assertAnthropicResult(t *testing.T, result *anthropic429Result, wantUnix int64) { + t.Helper() + if result == nil { + t.Fatal("expected non-nil result") + return // unreachable, but satisfies staticcheck SA5011 + } + want := time.Unix(wantUnix, 0) + if !result.resetAt.Equal(want) { + t.Errorf("expected resetAt=%v, got %v", want, result.resetAt) + } +} + +func makeHeader(key, value string) http.Header { + h := http.Header{} + h.Set(key, value) + return h +} diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index a9721d7f..aea19b35 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -26,8 +26,8 @@ type UsageLog struct { CacheCreationTokens int CacheReadTokens int - CacheCreation5mTokens int - CacheCreation1hTokens int + CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"` + CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"` InputCost float64 OutputCost float64 @@ -46,6 +46,9 @@ type UsageLog struct { UserAgent *string IPAddress *string + // Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费) + CacheTTLOverridden bool + // 图片生成字段 ImageCount int ImageSize *string diff --git a/backend/migrations/053_add_skip_monitoring_to_error_passthrough.sql b/backend/migrations/053_add_skip_monitoring_to_error_passthrough.sql new file mode 100644 index 00000000..71dbf181 --- /dev/null +++ b/backend/migrations/053_add_skip_monitoring_to_error_passthrough.sql @@ -0,0 +1,4 @@ +-- Add skip_monitoring field to error_passthrough_rules table +-- When true, errors matching this rule will not be recorded in ops_error_logs +ALTER TABLE error_passthrough_rules +ADD COLUMN IF NOT EXISTS skip_monitoring 
BOOLEAN NOT NULL DEFAULT false; diff --git a/backend/migrations/054_drop_legacy_cache_columns.sql b/backend/migrations/054_drop_legacy_cache_columns.sql new file mode 100644 index 00000000..ac73cd28 --- /dev/null +++ b/backend/migrations/054_drop_legacy_cache_columns.sql @@ -0,0 +1,14 @@ +-- Drop legacy cache token columns that lack the underscore separator. +-- These were created by GORM's automatic snake_case conversion: +-- CacheCreation5mTokens → cache_creation5m_tokens (incorrect) +-- CacheCreation1hTokens → cache_creation1h_tokens (incorrect) +-- +-- The canonical columns are: +-- cache_creation_5m_tokens (defined in 001_init.sql) +-- cache_creation_1h_tokens (defined in 001_init.sql) +-- +-- Migration 009 already copied data from legacy → canonical columns. +-- This migration drops the legacy columns to avoid confusion. + +ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation5m_tokens; +ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation1h_tokens; diff --git a/backend/migrations/055_add_cache_ttl_overridden.sql b/backend/migrations/055_add_cache_ttl_overridden.sql new file mode 100644 index 00000000..0d42fcf7 --- /dev/null +++ b/backend/migrations/055_add_cache_ttl_overridden.sql @@ -0,0 +1,2 @@ +-- Add cache_ttl_overridden flag to usage_logs for tracking cache TTL override per account. 
+ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS cache_ttl_overridden BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 62629371..256f712f 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -120,6 +120,7 @@ services: - POSTGRES_USER=${POSTGRES_USER:-sub2api} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} - POSTGRES_DB=${POSTGRES_DB:-sub2api} + - PGDATA=/var/lib/postgresql/data - TZ=${TZ:-Asia/Shanghai} networks: - sub2api-network diff --git a/frontend/package.json b/frontend/package.json index 325eba60..1b380b17 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -17,7 +17,7 @@ "dependencies": { "@lobehub/icons": "^4.0.2", "@vueuse/core": "^10.7.0", - "axios": "^1.6.2", + "axios": "^1.13.5", "chart.js": "^4.4.1", "dompurify": "^3.3.1", "driver.js": "^1.4.0", diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index 9af2d7af..37c384b4 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -15,8 +15,8 @@ importers: specifier: ^10.7.0 version: 10.11.1(vue@3.5.26(typescript@5.6.3)) axios: - specifier: ^1.6.2 - version: 1.13.2 + specifier: ^1.13.5 + version: 1.13.5 chart.js: specifier: ^4.4.1 version: 4.5.1 @@ -1257,56 +1257,67 @@ packages: resolution: {integrity: sha512-EHMUcDwhtdRGlXZsGSIuXSYwD5kOT9NVnx9sqzYiwAc91wfYOE1g1djOEDseZJKKqtHAHGwnGPQu3kytmfaXLQ==} cpu: [arm] os: [linux] + libc: [glibc] '@rollup/rollup-linux-arm-musleabihf@4.54.0': resolution: {integrity: sha512-+pBrqEjaakN2ySv5RVrj/qLytYhPKEUwk+e3SFU5jTLHIcAtqh2rLrd/OkbNuHJpsBgxsD8ccJt5ga/SeG0JmA==} cpu: [arm] os: [linux] + libc: [musl] '@rollup/rollup-linux-arm64-gnu@4.54.0': resolution: {integrity: sha512-NSqc7rE9wuUaRBsBp5ckQ5CVz5aIRKCwsoa6WMF7G01sX3/qHUw/z4pv+D+ahL1EIKy6Enpcnz1RY8pf7bjwng==} cpu: [arm64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-arm64-musl@4.54.0': resolution: {integrity: 
sha512-gr5vDbg3Bakga5kbdpqx81m2n9IX8M6gIMlQQIXiLTNeQW6CucvuInJ91EuCJ/JYvc+rcLLsDFcfAD1K7fMofg==} cpu: [arm64] os: [linux] + libc: [musl] '@rollup/rollup-linux-loong64-gnu@4.54.0': resolution: {integrity: sha512-gsrtB1NA3ZYj2vq0Rzkylo9ylCtW/PhpLEivlgWe0bpgtX5+9j9EZa0wtZiCjgu6zmSeZWyI/e2YRX1URozpIw==} cpu: [loong64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-ppc64-gnu@4.54.0': resolution: {integrity: sha512-y3qNOfTBStmFNq+t4s7Tmc9hW2ENtPg8FeUD/VShI7rKxNW7O4fFeaYbMsd3tpFlIg1Q8IapFgy7Q9i2BqeBvA==} cpu: [ppc64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-riscv64-gnu@4.54.0': resolution: {integrity: sha512-89sepv7h2lIVPsFma8iwmccN7Yjjtgz0Rj/Ou6fEqg3HDhpCa+Et+YSufy27i6b0Wav69Qv4WBNl3Rs6pwhebQ==} cpu: [riscv64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-riscv64-musl@4.54.0': resolution: {integrity: sha512-ZcU77ieh0M2Q8Ur7D5X7KvK+UxbXeDHwiOt/CPSBTI1fBmeDMivW0dPkdqkT4rOgDjrDDBUed9x4EgraIKoR2A==} cpu: [riscv64] os: [linux] + libc: [musl] '@rollup/rollup-linux-s390x-gnu@4.54.0': resolution: {integrity: sha512-2AdWy5RdDF5+4YfG/YesGDDtbyJlC9LHmL6rZw6FurBJ5n4vFGupsOBGfwMRjBYH7qRQowT8D/U4LoSvVwOhSQ==} cpu: [s390x] os: [linux] + libc: [glibc] '@rollup/rollup-linux-x64-gnu@4.54.0': resolution: {integrity: sha512-WGt5J8Ij/rvyqpFexxk3ffKqqbLf9AqrTBbWDk7ApGUzaIs6V+s2s84kAxklFwmMF/vBNGrVdYgbblCOFFezMQ==} cpu: [x64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-x64-musl@4.54.0': resolution: {integrity: sha512-JzQmb38ATzHjxlPHuTH6tE7ojnMKM2kYNzt44LO/jJi8BpceEC8QuXYA908n8r3CNuG/B3BV8VR3Hi1rYtmPiw==} cpu: [x64] os: [linux] + libc: [musl] '@rollup/rollup-openharmony-arm64@4.54.0': resolution: {integrity: sha512-huT3fd0iC7jigGh7n3q/+lfPcXxBi+om/Rs3yiFxjvSxbSB6aohDFXbWvlspaqjeOh+hx7DDHS+5Es5qRkWkZg==} @@ -1805,8 +1816,8 @@ packages: peerDependencies: postcss: ^8.1.0 - axios@1.13.2: - resolution: {integrity: sha512-VPk9ebNqPcy5lRGuSlKx752IlDatOjT9paPlm8A7yOuW2Fbvp4X3JznJtT4f0GzGLLiWE9W8onz51SqLYwzGaA==} + axios@1.13.5: + resolution: {integrity: 
sha512-cz4ur7Vb0xS4/KUN0tPWe44eqxrIu31me+fbang3ijiNscE129POzipJJA6zniq2C/Z6sJCjMimjS8Lc/GAs8Q==} babel-plugin-macros@3.1.0: resolution: {integrity: sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==} @@ -6387,7 +6398,7 @@ snapshots: postcss: 8.5.6 postcss-value-parser: 4.2.0 - axios@1.13.2: + axios@1.13.5: dependencies: follow-redirects: 1.15.11 form-data: 4.0.5 diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 4cb1a6f2..e1299595 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -32,6 +32,7 @@ export async function list( platform?: string type?: string status?: string + group?: string search?: string }, options?: { diff --git a/frontend/src/api/admin/antigravity.ts b/frontend/src/api/admin/antigravity.ts index 0392da6f..779fa9c1 100644 --- a/frontend/src/api/admin/antigravity.ts +++ b/frontend/src/api/admin/antigravity.ts @@ -53,4 +53,18 @@ export async function exchangeCode( return data } -export default { generateAuthUrl, exchangeCode } +export async function refreshAntigravityToken( + refreshToken: string, + proxyId?: number | null +): Promise { + const payload: Record = { refresh_token: refreshToken } + if (proxyId) payload.proxy_id = proxyId + + const { data } = await apiClient.post( + '/admin/antigravity/oauth/refresh-token', + payload + ) + return data +} + +export default { generateAuthUrl, exchangeCode, refreshAntigravityToken } diff --git a/frontend/src/api/admin/errorPassthrough.ts b/frontend/src/api/admin/errorPassthrough.ts index 4c545ad5..e27c5be6 100644 --- a/frontend/src/api/admin/errorPassthrough.ts +++ b/frontend/src/api/admin/errorPassthrough.ts @@ -21,6 +21,7 @@ export interface ErrorPassthroughRule { response_code: number | null passthrough_body: boolean custom_message: string | null + skip_monitoring: boolean description: string | null created_at: string updated_at: string @@ -41,6 +42,7 @@ export interface 
CreateRuleRequest { response_code?: number | null passthrough_body?: boolean custom_message?: string | null + skip_monitoring?: boolean description?: string | null } @@ -59,6 +61,7 @@ export interface UpdateRuleRequest { response_code?: number | null passthrough_body?: boolean custom_message?: string | null + skip_monitoring?: boolean description?: string | null } diff --git a/frontend/src/components/account/AccountGroupsCell.vue b/frontend/src/components/account/AccountGroupsCell.vue index 512383a5..37771275 100644 --- a/frontend/src/components/account/AccountGroupsCell.vue +++ b/frontend/src/components/account/AccountGroupsCell.vue @@ -41,7 +41,7 @@ >
- {{ t('admin.accounts.allGroups', { count: groups.length }) }} + {{ t('admin.accounts.groupCountTotal', { count: groups.length }) }}
- {{ t('admin.accounts.types.upstream') }} - {{ t('admin.accounts.types.upstreamDesc') }} + API Key + {{ t('admin.accounts.types.antigravityApikey') }}
@@ -681,7 +681,7 @@ type="text" required class="input" - placeholder="https://s.konstants.xyz" + placeholder="https://cloudcode-pa.googleapis.com" />

{{ t('admin.accounts.upstream.baseUrlHint') }}

@@ -816,8 +816,8 @@ - -
+ +
{{ t('admin.accounts.gemini.tier.aiStudioHint') }}

- +
@@ -1527,6 +1527,46 @@
+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.cacheTTLOverride.hint') }} +

+
+ +
+
+ + +

+ {{ t('admin.accounts.quotaControl.cacheTTLOverride.targetHint') }} +

+
+
@@ -1647,12 +1687,12 @@ :show-proxy-warning="form.platform !== 'openai' && !!form.proxy_id" :allow-multiple="form.platform === 'anthropic'" :show-cookie-option="form.platform === 'anthropic'" - :show-refresh-token-option="form.platform === 'openai'" + :show-refresh-token-option="form.platform === 'openai' || form.platform === 'antigravity'" :platform="form.platform" :show-project-id="geminiOAuthType === 'code_assist'" @generate-url="handleGenerateUrl" @cookie-auth="handleCookieAuth" - @validate-refresh-token="handleOpenAIValidateRT" + @validate-refresh-token="handleValidateRefreshToken" />
@@ -2146,6 +2186,8 @@ const maxSessions = ref(null) const sessionIdleTimeout = ref(null) const tlsFingerprintEnabled = ref(false) const sessionIdMaskingEnabled = ref(false) +const cacheTTLOverrideEnabled = ref(false) +const cacheTTLOverrideTarget = ref('5m') // Gemini tier selection (used as fallback when auto-detection is unavailable/fails) const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free') @@ -2597,6 +2639,8 @@ const resetForm = () => { sessionIdleTimeout.value = null tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false + cacheTTLOverrideEnabled.value = false + cacheTTLOverrideTarget.value = '5m' antigravityAccountType.value = 'oauth' upstreamBaseUrl.value = '' upstreamApiKey.value = '' @@ -2802,6 +2846,14 @@ const handleGenerateUrl = async () => { } } +const handleValidateRefreshToken = (rt: string) => { + if (form.platform === 'openai') { + handleOpenAIValidateRT(rt) + } else if (form.platform === 'antigravity') { + handleAntigravityValidateRT(rt) + } +} + const formatDateTimeLocal = formatDateTimeLocalInput const parseDateTimeLocal = parseDateTimeLocalInput @@ -2950,6 +3002,95 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { } } +// Antigravity 手动 RT 批量验证和创建 +const handleAntigravityValidateRT = async (refreshTokenInput: string) => { + if (!refreshTokenInput.trim()) return + + // Parse multiple refresh tokens (one per line) + const refreshTokens = refreshTokenInput + .split('\n') + .map((rt) => rt.trim()) + .filter((rt) => rt) + + if (refreshTokens.length === 0) { + antigravityOAuth.error.value = t('admin.accounts.oauth.antigravity.pleaseEnterRefreshToken') + return + } + + antigravityOAuth.loading.value = true + antigravityOAuth.error.value = '' + + let successCount = 0 + let failedCount = 0 + const errors: string[] = [] + + try { + for (let i = 0; i < refreshTokens.length; i++) { + try { + const tokenInfo = await antigravityOAuth.validateRefreshToken( + 
refreshTokens[i], + form.proxy_id + ) + if (!tokenInfo) { + failedCount++ + errors.push(`#${i + 1}: ${antigravityOAuth.error.value || 'Validation failed'}`) + antigravityOAuth.error.value = '' + continue + } + + const credentials = antigravityOAuth.buildCredentials(tokenInfo) + + // Generate account name with index for batch + const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name + + // Note: Antigravity doesn't have buildExtraInfo, so we pass empty extra or rely on credentials + await adminAPI.accounts.create({ + name: accountName, + notes: form.notes, + platform: 'antigravity', + type: 'oauth', + credentials, + extra: {}, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + successCount++ + } catch (error: any) { + failedCount++ + const errMsg = error.response?.data?.detail || error.message || 'Unknown error' + errors.push(`#${i + 1}: ${errMsg}`) + } + } + + // Show results + if (successCount > 0 && failedCount === 0) { + appStore.showSuccess( + refreshTokens.length > 1 + ? 
t('admin.accounts.oauth.batchSuccess', { count: successCount }) + : t('admin.accounts.accountCreated') + ) + emit('created') + handleClose() + } else if (successCount > 0 && failedCount > 0) { + appStore.showWarning( + t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount }) + ) + antigravityOAuth.error.value = errors.join('\n') + emit('created') + } else { + antigravityOAuth.error.value = errors.join('\n') + appStore.showError(t('admin.accounts.oauth.batchFailed')) + } + } finally { + antigravityOAuth.loading.value = false + } +} + // Gemini OAuth 授权码兑换 const handleGeminiExchange = async (authCode: string) => { if (!authCode.trim() || !geminiOAuth.sessionId.value) return @@ -3077,6 +3218,12 @@ const handleAnthropicExchange = async (authCode: string) => { extra.session_id_masking_enabled = true } + // Add cache TTL override settings + if (cacheTTLOverrideEnabled.value) { + extra.cache_ttl_override_enabled = true + extra.cache_ttl_override_target = cacheTTLOverrideTarget.value + } + const credentials = { ...tokenInfo, ...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {}) @@ -3170,6 +3317,12 @@ const handleCookieAuth = async (sessionKey: string) => { extra.session_id_masking_enabled = true } + // Add cache TTL override settings + if (cacheTTLOverrideEnabled.value) { + extra.cache_ttl_override_enabled = true + extra.cache_ttl_override_target = cacheTTLOverrideTarget.value + } + const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name // Merge interceptWarmupRequests into credentials diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 986bd297..ed243276 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -39,7 +39,9 @@ ? 'https://api.openai.com' : account.platform === 'gemini' ? 
'https://generativelanguage.googleapis.com' - : 'https://api.anthropic.com' + : account.platform === 'antigravity' + ? 'https://cloudcode-pa.googleapis.com' + : 'https://api.anthropic.com' " />

{{ baseUrlHint }}

@@ -55,14 +57,16 @@ ? 'sk-proj-...' : account.platform === 'gemini' ? 'AIza...' - : 'sk-ant-...' + : account.platform === 'antigravity' + ? 'sk-...' + : 'sk-ant-...' " />

{{ t('admin.accounts.leaveEmptyToKeep') }}

- -
+ +
@@ -372,7 +376,7 @@ v-model="editBaseUrl" type="text" class="input" - placeholder="https://s.konstants.xyz" + placeholder="https://cloudcode-pa.googleapis.com" />

{{ t('admin.accounts.upstream.baseUrlHint') }}

@@ -900,6 +904,46 @@
+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.cacheTTLOverride.hint') }} +

+
+ +
+
+ + +

+ {{ t('admin.accounts.quotaControl.cacheTTLOverride.targetHint') }} +

+
+
@@ -1098,6 +1142,8 @@ const maxSessions = ref(null) const sessionIdleTimeout = ref(null) const tlsFingerprintEnabled = ref(false) const sessionIdMaskingEnabled = ref(false) +const cacheTTLOverrideEnabled = ref(false) +const cacheTTLOverrideTarget = ref('5m') // Computed: current preset mappings based on platform const presetMappings = computed(() => getPresetMappingsByPlatform(props.account?.platform || 'anthropic')) @@ -1485,6 +1531,8 @@ function loadQuotaControlSettings(account: Account) { sessionIdleTimeout.value = null tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false + cacheTTLOverrideEnabled.value = false + cacheTTLOverrideTarget.value = '5m' // Only applies to Anthropic OAuth/SetupToken accounts if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) { @@ -1513,6 +1561,12 @@ function loadQuotaControlSettings(account: Account) { if (account.session_id_masking_enabled === true) { sessionIdMaskingEnabled.value = true } + + // Load cache TTL override setting + if (account.cache_ttl_override_enabled === true) { + cacheTTLOverrideEnabled.value = true + cacheTTLOverrideTarget.value = account.cache_ttl_override_target || '5m' + } } function formatTempUnschedKeywords(value: unknown) { @@ -1719,6 +1773,15 @@ const handleSubmit = async () => { delete newExtra.session_id_masking_enabled } + // Cache TTL override setting + if (cacheTTLOverrideEnabled.value) { + newExtra.cache_ttl_override_enabled = true + newExtra.cache_ttl_override_target = cacheTTLOverrideTarget.value + } else { + delete newExtra.cache_ttl_override_enabled + delete newExtra.cache_ttl_override_target + } + updatePayload.extra = newExtra } diff --git a/frontend/src/components/account/OAuthAuthorizationFlow.vue b/frontend/src/components/account/OAuthAuthorizationFlow.vue index 78f488c1..22e179ba 100644 --- a/frontend/src/components/account/OAuthAuthorizationFlow.vue +++ b/frontend/src/components/account/OAuthAuthorizationFlow.vue @@ 
-45,19 +45,19 @@ class="text-blue-600 focus:ring-blue-500" /> {{ - t('admin.accounts.oauth.openai.refreshTokenAuth') + t(getOAuthKey('refreshTokenAuth')) }}
- +

- {{ t('admin.accounts.oauth.openai.refreshTokenDesc') }} + {{ t(getOAuthKey('refreshTokenDesc')) }}

@@ -78,7 +78,7 @@ v-model="refreshTokenInput" rows="3" class="input w-full resize-y font-mono text-sm" - :placeholder="t('admin.accounts.oauth.openai.refreshTokenPlaceholder')" + :placeholder="t(getOAuthKey('refreshTokenPlaceholder'))" >

{{ loading - ? t('admin.accounts.oauth.openai.validating') - : t('admin.accounts.oauth.openai.validateAndCreate') + ? t(getOAuthKey('validating')) + : t(getOAuthKey('validateAndCreate')) }}

diff --git a/frontend/src/components/admin/ErrorPassthroughRulesModal.vue b/frontend/src/components/admin/ErrorPassthroughRulesModal.vue index b93319c5..2ed6ded3 100644 --- a/frontend/src/components/admin/ErrorPassthroughRulesModal.vue +++ b/frontend/src/components/admin/ErrorPassthroughRulesModal.vue @@ -148,6 +148,16 @@ {{ rule.passthrough_body ? t('admin.errorPassthrough.passthrough') : t('admin.errorPassthrough.custom') }}
+
+ + + {{ t('admin.errorPassthrough.skipMonitoring') }} + +
@@ -366,6 +376,19 @@ + +
+ + + {{ t('admin.errorPassthrough.form.skipMonitoring') }} + +
+

{{ t('admin.errorPassthrough.form.skipMonitoringHint') }}

+
{ form.response_code = null form.passthrough_body = true form.custom_message = null + form.skip_monitoring = false form.description = null errorCodesInput.value = '' keywordsInput.value = '' @@ -520,6 +545,7 @@ const handleEdit = (rule: ErrorPassthroughRule) => { form.response_code = rule.response_code form.passthrough_body = rule.passthrough_body form.custom_message = rule.custom_message + form.skip_monitoring = rule.skip_monitoring form.description = rule.description errorCodesInput.value = rule.error_codes.join(', ') keywordsInput.value = rule.keywords.join('\n') @@ -575,6 +601,7 @@ const handleSubmit = async () => { response_code: form.passthrough_code ? null : form.response_code, passthrough_body: form.passthrough_body, custom_message: form.passthrough_body ? null : form.custom_message, + skip_monitoring: form.skip_monitoring, description: form.description?.trim() || null } diff --git a/frontend/src/components/admin/account/AccountActionMenu.vue b/frontend/src/components/admin/account/AccountActionMenu.vue index bb753faa..2325f4b4 100644 --- a/frontend/src/components/admin/account/AccountActionMenu.vue +++ b/frontend/src/components/admin/account/AccountActionMenu.vue @@ -53,7 +53,19 @@ import type { Account } from '@/types' const props = defineProps<{ show: boolean; account: Account | null; position: { top: number; left: number } | null }>() const emit = defineEmits(['close', 'test', 'stats', 'reauth', 'refresh-token', 'reset-status', 'clear-rate-limit']) const { t } = useI18n() -const isRateLimited = computed(() => props.account?.rate_limit_reset_at && new Date(props.account.rate_limit_reset_at) > new Date()) +const isRateLimited = computed(() => { + if (props.account?.rate_limit_reset_at && new Date(props.account.rate_limit_reset_at) > new Date()) { + return true + } + const modelLimits = (props.account?.extra as Record | undefined)?.model_rate_limits as + | Record + | undefined + if (modelLimits) { + const now = new Date() + return 
Object.values(modelLimits).some(info => new Date(info.rate_limit_reset_at) > now) + } + return false +}) const isOverloaded = computed(() => props.account?.overload_until && new Date(props.account.overload_until) > new Date()) const handleKeydown = (event: KeyboardEvent) => { diff --git a/frontend/src/components/admin/account/AccountTableFilters.vue b/frontend/src/components/admin/account/AccountTableFilters.vue index 47ceedd7..b37f0359 100644 --- a/frontend/src/components/admin/account/AccountTableFilters.vue +++ b/frontend/src/components/admin/account/AccountTableFilters.vue @@ -10,16 +10,21 @@
diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index fbb1942a..a6420f1c 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -70,6 +70,8 @@
{{ formatCacheTokens(row.cache_creation_tokens) }} + 1h + R
@@ -157,9 +159,36 @@ {{ t('admin.usage.outputTokens') }} {{ tokenTooltipData.output_tokens.toLocaleString() }} -
- {{ t('admin.usage.cacheCreationTokens') }} - {{ tokenTooltipData.cache_creation_tokens.toLocaleString() }} +
+ + + +
+ {{ t('admin.usage.cacheCreationTokens') }} + {{ tokenTooltipData.cache_creation_tokens.toLocaleString() }} +
+
+
+ + {{ t('usage.cacheTtlOverriddenLabel') }} + R-{{ tokenTooltipData.cache_creation_1h_tokens > 0 ? '5m' : '1H' }} + + {{ tokenTooltipData.cache_creation_1h_tokens > 0 ? t('usage.cacheTtlOverridden1h') : t('usage.cacheTtlOverridden5m') }}
{{ t('admin.usage.cacheReadTokens') }} diff --git a/frontend/src/components/common/StatCard.vue b/frontend/src/components/common/StatCard.vue index 203a2fa8..d7c40a2e 100644 --- a/frontend/src/components/common/StatCard.vue +++ b/frontend/src/components/common/StatCard.vue @@ -6,7 +6,7 @@

{{ title }}

-

{{ formattedValue }}

+

{{ formattedValue }}

- Logo + Logo
@@ -167,6 +167,7 @@ const isDark = ref(document.documentElement.classList.contains('dark')) const siteName = computed(() => appStore.siteName) const siteLogo = computed(() => appStore.siteLogo) const siteVersion = computed(() => appStore.siteVersion) +const settingsLoaded = computed(() => appStore.publicSettingsLoaded) // SVG Icon Components const DashboardIcon = { diff --git a/frontend/src/components/layout/AuthLayout.vue b/frontend/src/components/layout/AuthLayout.vue index bd20b3c4..e2d9d2aa 100644 --- a/frontend/src/components/layout/AuthLayout.vue +++ b/frontend/src/components/layout/AuthLayout.vue @@ -29,17 +29,19 @@
-
- Logo -
-

- {{ siteName }} -

-

- {{ siteSubtitle }} -

+
@@ -61,25 +63,21 @@ diff --git a/frontend/src/composables/useAntigravityOAuth.ts b/frontend/src/composables/useAntigravityOAuth.ts index 2c1a4cfe..cf60fd09 100644 --- a/frontend/src/composables/useAntigravityOAuth.ts +++ b/frontend/src/composables/useAntigravityOAuth.ts @@ -83,6 +83,35 @@ export function useAntigravityOAuth() { } } + const validateRefreshToken = async ( + refreshToken: string, + proxyId?: number | null + ): Promise => { + if (!refreshToken.trim()) { + error.value = t('admin.accounts.oauth.antigravity.pleaseEnterRefreshToken') + return null + } + + loading.value = true + error.value = '' + + try { + const tokenInfo = await adminAPI.antigravity.refreshAntigravityToken( + refreshToken.trim(), + proxyId + ) + return tokenInfo as AntigravityTokenInfo + } catch (err: any) { + error.value = + err.response?.data?.detail || t('admin.accounts.oauth.antigravity.failedToValidateRT') + // Don't show global error toast for batch validation to avoid spamming + // appStore.showError(error.value) + return null + } finally { + loading.value = false + } + } + const buildCredentials = (tokenInfo: AntigravityTokenInfo): Record => { let expiresAt: string | undefined if (typeof tokenInfo.expires_at === 'number' && Number.isFinite(tokenInfo.expires_at)) { @@ -110,6 +139,7 @@ export function useAntigravityOAuth() { resetState, generateAuthUrl, exchangeAuthCode, + validateRefreshToken, buildCredentials } } diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index 0ef80431..98c668f0 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -39,6 +39,7 @@ export const claudeModels = [ 'claude-sonnet-4-5-20250929', 'claude-haiku-4-5-20251001', 'claude-opus-4-5-20251101', 'claude-opus-4-6', + 'claude-sonnet-4-6', 'claude-2.1', 'claude-2.0', 'claude-instant-1.2' ] @@ -233,6 +234,7 @@ export const allModels = allModelsList.map(m => ({ value: m, label: m })) const 
anthropicPresetMappings = [ { label: 'Sonnet 4', from: 'claude-sonnet-4-20250514', to: 'claude-sonnet-4-20250514', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' }, { label: 'Sonnet 4.5', from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4-5-20250929', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' }, + { label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'claude-sonnet-4-6', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' }, { label: 'Opus 4.5', from: 'claude-opus-4-5-20251101', to: 'claude-opus-4-5-20251101', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, { label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, { label: 'Haiku 3.5', from: 'claude-3-5-haiku-20241022', to: 'claude-3-5-haiku-20241022', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' }, diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 77b56233..ec53ed5b 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -576,6 +576,10 @@ export default { description: 'View and analyze your API usage history', costDetails: 'Cost Breakdown', tokenDetails: 'Token Breakdown', + cacheTtlOverriddenHint: 'Cache TTL Override enabled', + cacheTtlOverriddenLabel: 'TTL Override', + cacheTtlOverridden5m: 'Billed as 5m', + cacheTtlOverridden1h: 'Billed as 1h', totalRequests: 'Total Requests', totalTokens: 'Total Tokens', totalCost: 'Total Cost', @@ -841,7 +845,7 @@ export default { createUser: 'Create User', editUser: 'Edit User', deleteUser: 'Delete User', - searchUsers: 'Search users...', + searchUsers: 'Search by email, username, notes, or API key...', allRoles: 'All Roles', 
allStatus: 'All Status', admin: 'Admin', @@ -1335,6 +1339,7 @@ export default { allPlatforms: 'All Platforms', allTypes: 'All Types', allStatus: 'All Status', + allGroups: 'All Groups', oauthType: 'OAuth', setupToken: 'Setup Token', apiKey: 'API Key', @@ -1344,7 +1349,7 @@ export default { schedulableEnabled: 'Scheduling enabled', schedulableDisabled: 'Scheduling disabled', failedToToggleSchedulable: 'Failed to toggle scheduling status', - allGroups: '{count} groups total', + groupCountTotal: '{count} groups total', platforms: { anthropic: 'Anthropic', claude: 'Claude', @@ -1359,6 +1364,7 @@ export default { googleOauth: 'Google OAuth', codeAssist: 'Code Assist', antigravityOauth: 'Antigravity OAuth', + antigravityApikey: 'Connect via Base URL + API Key', upstream: 'Upstream', upstreamDesc: 'Connect via Base URL + API Key' }, @@ -1593,6 +1599,12 @@ export default { sessionIdMasking: { label: 'Session ID Masking', hint: 'When enabled, fixes the session ID in metadata.user_id for 15 minutes, making upstream think requests come from the same session' + }, + cacheTTLOverride: { + label: 'Cache TTL Override', + hint: 'Force all cache creation tokens to be billed as the selected TTL tier (5m or 1h)', + target: 'Target TTL', + targetHint: 'Select the TTL tier for billing' } }, expired: 'Expired', @@ -1625,7 +1637,7 @@ export default { // Upstream type upstream: { baseUrl: 'Upstream Base URL', - baseUrlHint: 'The address of the upstream Antigravity service, e.g., https://s.konstants.xyz', + baseUrlHint: 'The address of the upstream Antigravity service, e.g., https://cloudcode-pa.googleapis.com', apiKey: 'Upstream API Key', apiKeyHint: 'API Key for the upstream service', pleaseEnterBaseUrl: 'Please enter upstream Base URL', @@ -1773,13 +1785,20 @@ export default { authCode: 'Authorization URL or Code', authCodePlaceholder: 'Option 1: Copy the complete URL\n(http://localhost:xxx/auth/callback?code=...)\nOption 2: Copy only the code parameter value', - authCodeHint: 'You can 
copy the entire URL or just the code parameter value, the system will auto-detect', - failedToGenerateUrl: 'Failed to generate Antigravity auth URL', - missingExchangeParams: 'Missing code, session ID, or state', - failedToExchangeCode: 'Failed to exchange Antigravity auth code' - } - }, - // Gemini specific (platform-wide) + authCodeHint: 'You can copy the entire URL or just the code parameter value, the system will auto-detect', + failedToGenerateUrl: 'Failed to generate Antigravity auth URL', + missingExchangeParams: 'Missing code, session ID, or state', + failedToExchangeCode: 'Failed to exchange Antigravity auth code', + // Refresh Token auth + refreshTokenAuth: 'Manual RT', + refreshTokenDesc: 'Enter your existing Antigravity Refresh Token. Supports batch input (one per line). The system will automatically validate and create accounts.', + refreshTokenPlaceholder: 'Paste your Antigravity Refresh Token...\nSupports multiple tokens, one per line', + validating: 'Validating...', + validateAndCreate: 'Validate & Create', + pleaseEnterRefreshToken: 'Please enter Refresh Token', + failedToValidateRT: 'Failed to validate Refresh Token' + } + }, // Gemini specific (platform-wide) gemini: { helpButton: 'Help', helpDialog: { @@ -2128,7 +2147,7 @@ export default { title: 'Redeem Code Management', description: 'Generate and manage redeem codes', generateCodes: 'Generate Codes', - searchCodes: 'Search codes...', + searchCodes: 'Search codes or email...', allTypes: 'All Types', allStatus: 'All Status', balance: 'Balance', @@ -2351,6 +2370,8 @@ export default { inputTokens: 'Input Tokens', outputTokens: 'Output Tokens', cacheCreationTokens: 'Cache Creation Tokens', + cacheCreation5mTokens: 'Cache Write', + cacheCreation1hTokens: 'Cache Write', cacheReadTokens: 'Cache Read Tokens', failedToLoad: 'Failed to load usage records', billingType: 'Billing Type', @@ -3352,6 +3373,7 @@ export default { custom: 'Custom', code: 'Code', body: 'Body', + skipMonitoring: 'Skip Monitoring', 
// Columns columns: { @@ -3396,6 +3418,8 @@ export default { passthroughBody: 'Passthrough upstream error message', customMessage: 'Custom error message', customMessagePlaceholder: 'Error message to return to client...', + skipMonitoring: 'Skip monitoring', + skipMonitoringHint: 'When enabled, errors matching this rule will not be recorded in ops monitoring', enabled: 'Enable this rule' }, diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index b70d1d5d..1dda2692 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -582,6 +582,10 @@ export default { description: '查看和分析您的 API 使用历史', costDetails: '成本明细', tokenDetails: 'Token 明细', + cacheTtlOverriddenHint: '缓存 TTL Override 已启用', + cacheTtlOverriddenLabel: 'TTL 替换', + cacheTtlOverridden5m: '按 5m 计费', + cacheTtlOverridden1h: '按 1h 计费', totalRequests: '总请求数', totalTokens: '总 Token', totalCost: '总消费', @@ -865,8 +869,8 @@ export default { editUser: '编辑用户', deleteUser: '删除用户', deleteConfirmMessage: "确定要删除用户 '{email}' 吗?此操作无法撤销。", - searchPlaceholder: '搜索用户邮箱或用户名、备注、支持模糊查询...', - searchUsers: '搜索用户邮箱或用户名、备注、支持模糊查询', + searchPlaceholder: '邮箱/用户名/备注/API Key 模糊搜索...', + searchUsers: '邮箱/用户名/备注/API Key 模糊搜索', roleFilter: '角色筛选', allRoles: '全部角色', allStatus: '全部状态', @@ -1426,6 +1430,7 @@ export default { allPlatforms: '全部平台', allTypes: '全部类型', allStatus: '全部状态', + allGroups: '全部分组', oauthType: 'OAuth', // Schedulable toggle schedulable: '参与调度', @@ -1433,7 +1438,7 @@ export default { schedulableEnabled: '调度已开启', schedulableDisabled: '调度已关闭', failedToToggleSchedulable: '切换调度状态失败', - allGroups: '共 {count} 个分组', + groupCountTotal: '共 {count} 个分组', columns: { name: '名称', platformType: '平台/类型', @@ -1493,6 +1498,7 @@ export default { googleOauth: 'Google OAuth', codeAssist: 'Code Assist', antigravityOauth: 'Antigravity OAuth', + antigravityApikey: '通过 Base URL + API Key 连接', upstream: '对接上游', upstreamDesc: '通过 Base URL + API Key 连接上游', api_key: 'API Key', @@ -1739,6 +1745,12 
@@ export default { sessionIdMasking: { label: '会话 ID 伪装', hint: '启用后将在 15 分钟内固定 metadata.user_id 中的 session ID,使上游认为请求来自同一会话' + }, + cacheTTLOverride: { + label: '缓存 TTL 强制替换', + hint: '将所有缓存创建 token 强制按指定的 TTL 类型(5分钟或1小时)计费', + target: '目标 TTL', + targetHint: '选择计费使用的 TTL 类型' } }, expired: '已过期', @@ -1771,7 +1783,7 @@ export default { // Upstream type upstream: { baseUrl: '上游 Base URL', - baseUrlHint: '上游 Antigravity 服务的地址,例如:https://s.konstants.xyz', + baseUrlHint: '上游 Antigravity 服务的地址,例如:https://cloudcode-pa.googleapis.com', apiKey: '上游 API Key', apiKeyHint: '上游服务的 API Key', pleaseEnterBaseUrl: '请输入上游 Base URL', @@ -1912,7 +1924,15 @@ export default { authCodeHint: '您可以直接复制整个链接或仅复制 code 参数值,系统会自动识别', failedToGenerateUrl: '生成 Antigravity 授权链接失败', missingExchangeParams: '缺少 code / session_id / state', - failedToExchangeCode: 'Antigravity 授权码兑换失败' + failedToExchangeCode: 'Antigravity 授权码兑换失败', + // Refresh Token auth + refreshTokenAuth: '手动输入 RT', + refreshTokenDesc: '输入您已有的 Antigravity Refresh Token,支持批量输入(每行一个),系统将自动验证并创建账号。', + refreshTokenPlaceholder: '粘贴您的 Antigravity Refresh Token...\n支持多个,每行一个', + validating: '验证中...', + validateAndCreate: '验证并创建账号', + pleaseEnterRefreshToken: '请输入 Refresh Token', + failedToValidateRT: '验证 Refresh Token 失败' } }, // Gemini specific (platform-wide) @@ -2291,7 +2311,7 @@ export default { allStatus: '全部状态', unused: '未使用', used: '已使用', - searchCodes: '搜索兑换码...', + searchCodes: '搜索兑换码或邮箱...', exportCsv: '导出 CSV', deleteAllUnused: '删除全部未使用', deleteCodeConfirm: '确定要删除此兑换码吗?此操作无法撤销。', @@ -2517,6 +2537,8 @@ export default { inputTokens: '输入 Token', outputTokens: '输出 Token', cacheCreationTokens: '缓存创建 Token', + cacheCreation5mTokens: '缓存创建', + cacheCreation1hTokens: '缓存创建', cacheReadTokens: '缓存读取 Token', failedToLoad: '加载使用记录失败', billingType: '计费类型', @@ -3526,6 +3548,7 @@ export default { custom: '自定义', code: '状态码', body: '消息体', + skipMonitoring: '跳过监控', // Columns columns: { @@ -3570,6 +3593,8 @@ export default { passthroughBody: 
'透传上游错误信息', customMessage: '自定义错误信息', customMessagePlaceholder: '返回给客户端的错误信息...', + skipMonitoring: '跳过运维监控记录', + skipMonitoringHint: '开启后,匹配此规则的错误不会被记录到运维监控中', enabled: '启用此规则' }, diff --git a/frontend/src/style.css b/frontend/src/style.css index c1ee8ea5..25631aaf 100644 --- a/frontend/src/style.css +++ b/frontend/src/style.css @@ -243,7 +243,7 @@ } .stat-value { - @apply text-2xl font-bold text-gray-900 dark:text-white; + @apply text-2xl font-bold text-gray-900 dark:text-white truncate; } .stat-label { diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index a2f12ff3..7f781a73 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -614,6 +614,10 @@ export interface Account { // 启用后将在15分钟内固定 metadata.user_id 中的 session ID session_id_masking_enabled?: boolean | null + // 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效) + cache_ttl_override_enabled?: boolean | null + cache_ttl_override_target?: string | null + // 运行时状态(仅当启用对应限制时返回) current_window_cost?: number | null // 当前窗口费用 active_sessions?: number | null // 当前活跃会话数 @@ -827,6 +831,9 @@ export interface UsageLog { // User-Agent user_agent: string | null + // Cache TTL Override + cache_ttl_overridden: boolean + created_at: string user?: User diff --git a/frontend/src/utils/url.ts b/frontend/src/utils/url.ts index a4dc0351..57c6487f 100644 --- a/frontend/src/utils/url.ts +++ b/frontend/src/utils/url.ts @@ -6,6 +6,7 @@ */ type SanitizeOptions = { allowRelative?: boolean + allowDataUrl?: boolean } export function sanitizeUrl(value: string, options: SanitizeOptions = {}): string { @@ -18,6 +19,11 @@ export function sanitizeUrl(value: string, options: SanitizeOptions = {}): strin return trimmed } + // 允许 data:image/ 开头的 data URL(仅限图片类型) + if (options.allowDataUrl && trimmed.startsWith('data:image/')) { + return trimmed + } + // 只接受绝对 URL,不使用 base URL 来避免相对路径被解析为当前域名 // 检查是否以 http:// 或 https:// 开头 if (!trimmed.match(/^https?:\/\//i)) { diff --git 
a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index 456fc8d9..fb2e6c1a 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -6,6 +6,7 @@ !hiddenColumns.has(key) const { items: accounts, loading, params, pagination, load, reload, debouncedReload, handlePageChange, handlePageSizeChange } = useTableLoader({ fetchFn: adminAPI.accounts.list, - initialParams: { platform: '', type: '', status: '', search: '' } + initialParams: { platform: '', type: '', status: '', group: '', search: '' } }) const isAnyModalOpen = computed(() => { diff --git a/frontend/src/views/admin/RedeemView.vue b/frontend/src/views/admin/RedeemView.vue index d5ba9d3e..17e612c5 100644 --- a/frontend/src/views/admin/RedeemView.vue +++ b/frontend/src/views/admin/RedeemView.vue @@ -117,9 +117,9 @@ -