diff --git a/backend/ent/group.go b/backend/ent/group.go index dca64cec..4a31442a 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -51,6 +51,10 @@ type Group struct { ImagePrice2k *float64 `json:"image_price_2k,omitempty"` // ImagePrice4k holds the value of the "image_price_4k" field. ImagePrice4k *float64 `json:"image_price_4k,omitempty"` + // 是否仅允许 Claude Code 客户端 + ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` + // 非 Claude Code 请求降级使用的分组 ID + FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -157,11 +161,11 @@ func (*Group) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case group.FieldIsExclusive: + case group.FieldIsExclusive, group.FieldClaudeCodeOnly: values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: values[i] = new(sql.NullString) @@ -298,6 +302,19 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.ImagePrice4k = new(float64) *_m.ImagePrice4k = value.Float64 } + case group.FieldClaudeCodeOnly: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field claude_code_only", values[i]) + } else if value.Valid { + _m.ClaudeCodeOnly = value.Bool + } + case group.FieldFallbackGroupID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field fallback_group_id", values[i]) + } else if value.Valid { + _m.FallbackGroupID = new(int64) + *_m.FallbackGroupID = value.Int64 + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -440,6 +457,14 @@ func (_m *Group) String() string { builder.WriteString("image_price_4k=") builder.WriteString(fmt.Sprintf("%v", *v)) } + builder.WriteString(", ") + builder.WriteString("claude_code_only=") + builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly)) + builder.WriteString(", ") + if v := _m.FallbackGroupID; v != nil { + builder.WriteString("fallback_group_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 1c5ed343..c4317f00 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -49,6 +49,10 @@ const ( FieldImagePrice2k = "image_price_2k" // FieldImagePrice4k holds the string denoting the image_price_4k field in the database. FieldImagePrice4k = "image_price_4k" + // FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database. + FieldClaudeCodeOnly = "claude_code_only" + // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. + FieldFallbackGroupID = "fallback_group_id" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -141,6 +145,8 @@ var Columns = []string{ FieldImagePrice1k, FieldImagePrice2k, FieldImagePrice4k, + FieldClaudeCodeOnly, + FieldFallbackGroupID, } var ( @@ -196,6 +202,8 @@ var ( SubscriptionTypeValidator func(string) error // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. DefaultDefaultValidityDays int + // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. + DefaultClaudeCodeOnly bool ) // OrderOption defines the ordering options for the Group queries. @@ -291,6 +299,16 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc() } +// ByClaudeCodeOnly orders the results by the claude_code_only field. +func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc() +} + +// ByFallbackGroupID orders the results by the fallback_group_id field. +func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 7bce1fe6..fb2f942f 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -140,6 +140,16 @@ func ImagePrice4k(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v)) } +// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ. +func ClaudeCodeOnly(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) +} + +// FallbackGroupID applies equality check predicate on the "fallback_group_id" field. It's identical to FallbackGroupIDEQ. +func FallbackGroupID(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -995,6 +1005,66 @@ func ImagePrice4kNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldImagePrice4k)) } +// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field. +func ClaudeCodeOnlyEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) +} + +// ClaudeCodeOnlyNEQ applies the NEQ predicate on the "claude_code_only" field. +func ClaudeCodeOnlyNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldClaudeCodeOnly, v)) +} + +// FallbackGroupIDEQ applies the EQ predicate on the "fallback_group_id" field. +func FallbackGroupIDEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v)) +} + +// FallbackGroupIDNEQ applies the NEQ predicate on the "fallback_group_id" field. +func FallbackGroupIDNEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldFallbackGroupID, v)) +} + +// FallbackGroupIDIn applies the In predicate on the "fallback_group_id" field. +func FallbackGroupIDIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldFallbackGroupID, vs...)) +} + +// FallbackGroupIDNotIn applies the NotIn predicate on the "fallback_group_id" field. +func FallbackGroupIDNotIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldFallbackGroupID, vs...)) +} + +// FallbackGroupIDGT applies the GT predicate on the "fallback_group_id" field. +func FallbackGroupIDGT(v int64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldFallbackGroupID, v)) +} + +// FallbackGroupIDGTE applies the GTE predicate on the "fallback_group_id" field. +func FallbackGroupIDGTE(v int64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldFallbackGroupID, v)) +} + +// FallbackGroupIDLT applies the LT predicate on the "fallback_group_id" field. +func FallbackGroupIDLT(v int64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldFallbackGroupID, v)) +} + +// FallbackGroupIDLTE applies the LTE predicate on the "fallback_group_id" field. +func FallbackGroupIDLTE(v int64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldFallbackGroupID, v)) +} + +// FallbackGroupIDIsNil applies the IsNil predicate on the "fallback_group_id" field. +func FallbackGroupIDIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldFallbackGroupID)) +} + +// FallbackGroupIDNotNil applies the NotNil predicate on the "fallback_group_id" field. +func FallbackGroupIDNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 6a928af6..59229402 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -258,6 +258,34 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate { return _c } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate { + _c.mutation.SetClaudeCodeOnly(v) + return _c +} + +// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil. +func (_c *GroupCreate) SetNillableClaudeCodeOnly(v *bool) *GroupCreate { + if v != nil { + _c.SetClaudeCodeOnly(*v) + } + return _c +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (_c *GroupCreate) SetFallbackGroupID(v int64) *GroupCreate { + _c.mutation.SetFallbackGroupID(v) + return _c +} + +// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil. +func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate { + if v != nil { + _c.SetFallbackGroupID(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -423,6 +451,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultDefaultValidityDays _c.mutation.SetDefaultValidityDays(v) } + if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { + v := group.DefaultClaudeCodeOnly + _c.mutation.SetClaudeCodeOnly(v) + } return nil } @@ -475,6 +507,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.DefaultValidityDays(); !ok { return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} } + if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { + return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)} + } return nil } @@ -570,6 +605,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value) _node.ImagePrice4k = &value } + if value, ok := _c.mutation.ClaudeCodeOnly(); ok { + _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) + _node.ClaudeCodeOnly = value + } + if value, ok := _c.mutation.FallbackGroupID(); ok { + _spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value) + _node.FallbackGroupID = &value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1014,6 +1057,42 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert { return u } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert { + u.Set(group.FieldClaudeCodeOnly, v) + return u +} + +// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create. +func (u *GroupUpsert) UpdateClaudeCodeOnly() *GroupUpsert { + u.SetExcluded(group.FieldClaudeCodeOnly) + return u +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (u *GroupUpsert) SetFallbackGroupID(v int64) *GroupUpsert { + u.Set(group.FieldFallbackGroupID, v) + return u +} + +// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create. +func (u *GroupUpsert) UpdateFallbackGroupID() *GroupUpsert { + u.SetExcluded(group.FieldFallbackGroupID) + return u +} + +// AddFallbackGroupID adds v to the "fallback_group_id" field. +func (u *GroupUpsert) AddFallbackGroupID(v int64) *GroupUpsert { + u.Add(group.FieldFallbackGroupID, v) + return u +} + +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert { + u.SetNull(group.FieldFallbackGroupID) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1395,6 +1474,48 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne { }) } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetClaudeCodeOnly(v) + }) +} + +// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateClaudeCodeOnly() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateClaudeCodeOnly() + }) +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (u *GroupUpsertOne) SetFallbackGroupID(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetFallbackGroupID(v) + }) +} + +// AddFallbackGroupID adds v to the "fallback_group_id" field. +func (u *GroupUpsertOne) AddFallbackGroupID(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddFallbackGroupID(v) + }) +} + +// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateFallbackGroupID() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateFallbackGroupID() + }) +} + +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearFallbackGroupID() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1942,6 +2063,48 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk { }) } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetClaudeCodeOnly(v) + }) +} + +// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateClaudeCodeOnly() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateClaudeCodeOnly() + }) +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (u *GroupUpsertBulk) SetFallbackGroupID(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetFallbackGroupID(v) + }) +} + +// AddFallbackGroupID adds v to the "fallback_group_id" field. +func (u *GroupUpsertBulk) AddFallbackGroupID(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddFallbackGroupID(v) + }) +} + +// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateFallbackGroupID() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateFallbackGroupID() + }) +} + +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearFallbackGroupID() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 43555ce2..1a6f15ec 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -354,6 +354,47 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate { return _u } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate { + _u.mutation.SetClaudeCodeOnly(v) + return _u +} + +// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableClaudeCodeOnly(v *bool) *GroupUpdate { + if v != nil { + _u.SetClaudeCodeOnly(*v) + } + return _u +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (_u *GroupUpdate) SetFallbackGroupID(v int64) *GroupUpdate { + _u.mutation.ResetFallbackGroupID() + _u.mutation.SetFallbackGroupID(v) + return _u +} + +// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableFallbackGroupID(v *int64) *GroupUpdate { + if v != nil { + _u.SetFallbackGroupID(*v) + } + return _u +} + +// AddFallbackGroupID adds value to the "fallback_group_id" field. +func (_u *GroupUpdate) AddFallbackGroupID(v int64) *GroupUpdate { + _u.mutation.AddFallbackGroupID(v) + return _u +} + +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate { + _u.mutation.ClearFallbackGroupID() + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -750,6 +791,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.ClaudeCodeOnly(); ok { + _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) + } + if value, ok := _u.mutation.FallbackGroupID(); ok { + _spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedFallbackGroupID(); ok { + _spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value) + } + if _u.mutation.FallbackGroupIDCleared() { + _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1384,6 +1437,47 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne { return _u } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne { + _u.mutation.SetClaudeCodeOnly(v) + return _u +} + +// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableClaudeCodeOnly(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetClaudeCodeOnly(*v) + } + return _u +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (_u *GroupUpdateOne) SetFallbackGroupID(v int64) *GroupUpdateOne { + _u.mutation.ResetFallbackGroupID() + _u.mutation.SetFallbackGroupID(v) + return _u +} + +// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableFallbackGroupID(v *int64) *GroupUpdateOne { + if v != nil { + _u.SetFallbackGroupID(*v) + } + return _u +} + +// AddFallbackGroupID adds value to the "fallback_group_id" field. +func (_u *GroupUpdateOne) AddFallbackGroupID(v int64) *GroupUpdateOne { + _u.mutation.AddFallbackGroupID(v) + return _u +} + +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne { + _u.mutation.ClearFallbackGroupID() + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1810,6 +1904,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.ClaudeCodeOnly(); ok { + _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) + } + if value, ok := _u.mutation.FallbackGroupID(); ok { + _spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedFallbackGroupID(); ok { + _spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value) + } + if _u.mutation.FallbackGroupIDCleared() { + _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index e48201f3..13081e31 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -221,6 +221,8 @@ var ( {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "claude_code_only", Type: field.TypeBool, Default: false}, + {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index a809e858..4e01e12b 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -3590,6 +3590,9 @@ type GroupMutation struct { addimage_price_2k *float64 image_price_4k *float64 addimage_price_4k *float64 + claude_code_only *bool + fallback_group_id *int64 + addfallback_group_id *int64 clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -4594,6 +4597,112 @@ func (m *GroupMutation) ResetImagePrice4k() { delete(m.clearedFields, group.FieldImagePrice4k) } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (m *GroupMutation) SetClaudeCodeOnly(b bool) { + m.claude_code_only = &b +} + +// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation. +func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) { + v := m.claude_code_only + if v == nil { + return + } + return *v, true +} + +// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err) + } + return oldValue.ClaudeCodeOnly, nil +} + +// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field. +func (m *GroupMutation) ResetClaudeCodeOnly() { + m.claude_code_only = nil +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (m *GroupMutation) SetFallbackGroupID(i int64) { + m.fallback_group_id = &i + m.addfallback_group_id = nil +} + +// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation. +func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) { + v := m.fallback_group_id + if v == nil { + return + } + return *v, true +} + +// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFallbackGroupID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err) + } + return oldValue.FallbackGroupID, nil +} + +// AddFallbackGroupID adds i to the "fallback_group_id" field. +func (m *GroupMutation) AddFallbackGroupID(i int64) { + if m.addfallback_group_id != nil { + *m.addfallback_group_id += i + } else { + m.addfallback_group_id = &i + } +} + +// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation. +func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) { + v := m.addfallback_group_id + if v == nil { + return + } + return *v, true +} + +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (m *GroupMutation) ClearFallbackGroupID() { + m.fallback_group_id = nil + m.addfallback_group_id = nil + m.clearedFields[group.FieldFallbackGroupID] = struct{}{} +} + +// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation. +func (m *GroupMutation) FallbackGroupIDCleared() bool { + _, ok := m.clearedFields[group.FieldFallbackGroupID] + return ok +} + +// ResetFallbackGroupID resets all changes to the "fallback_group_id" field. +func (m *GroupMutation) ResetFallbackGroupID() { + m.fallback_group_id = nil + m.addfallback_group_id = nil + delete(m.clearedFields, group.FieldFallbackGroupID) +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -4952,7 +5061,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 17) + fields := make([]string, 0, 19) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -5004,6 +5113,12 @@ func (m *GroupMutation) Fields() []string { if m.image_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.claude_code_only != nil { + fields = append(fields, group.FieldClaudeCodeOnly) + } + if m.fallback_group_id != nil { + fields = append(fields, group.FieldFallbackGroupID) + } return fields } @@ -5046,6 +5161,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ImagePrice2k() case group.FieldImagePrice4k: return m.ImagePrice4k() + case group.FieldClaudeCodeOnly: + return m.ClaudeCodeOnly() + case group.FieldFallbackGroupID: + return m.FallbackGroupID() } return nil, false } @@ -5089,6 +5208,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldImagePrice2k(ctx) case group.FieldImagePrice4k: return m.OldImagePrice4k(ctx) + case group.FieldClaudeCodeOnly: + return m.OldClaudeCodeOnly(ctx) + case group.FieldFallbackGroupID: + return m.OldFallbackGroupID(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -5217,6 +5340,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetImagePrice4k(v) return nil + case group.FieldClaudeCodeOnly: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClaudeCodeOnly(v) + return nil + case group.FieldFallbackGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFallbackGroupID(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -5249,6 +5386,9 @@ func (m *GroupMutation) AddedFields() []string { if m.addimage_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.addfallback_group_id != nil { + fields = append(fields, group.FieldFallbackGroupID) + } return fields } @@ -5273,6 +5413,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedImagePrice2k() case group.FieldImagePrice4k: return m.AddedImagePrice4k() + case group.FieldFallbackGroupID: + return m.AddedFallbackGroupID() } return nil, false } @@ -5338,6 +5480,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddImagePrice4k(v) return nil + case group.FieldFallbackGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFallbackGroupID(v) + return nil } return fmt.Errorf("unknown Group numeric field %s", name) } @@ -5370,6 +5519,9 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldImagePrice4k) { fields = append(fields, group.FieldImagePrice4k) } + if m.FieldCleared(group.FieldFallbackGroupID) { + fields = append(fields, group.FieldFallbackGroupID) + } return fields } @@ -5408,6 +5560,9 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldImagePrice4k: m.ClearImagePrice4k() return nil + case group.FieldFallbackGroupID: + m.ClearFallbackGroupID() + return nil } return fmt.Errorf("unknown Group nullable field %s", name) } @@ -5467,6 +5622,12 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldImagePrice4k: m.ResetImagePrice4k() return nil + case group.FieldClaudeCodeOnly: + m.ResetClaudeCodeOnly() + return nil + case group.FieldFallbackGroupID: + m.ResetFallbackGroupID() + return nil } return fmt.Errorf("unknown Group field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 6ccfc6d2..fb1c948c 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -270,6 +270,10 @@ func init() { groupDescDefaultValidityDays := groupFields[10].Descriptor() // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) + // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. + groupDescClaudeCodeOnly := groupFields[14].Descriptor() + // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. + group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) proxyMixin := schema.Proxy{}.Mixin() proxyMixinHooks1 := proxyMixin[1].Hooks() proxy.Hooks[0] = proxyMixinHooks1[0] diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 7b5f77b1..d38925b1 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -86,6 +86,15 @@ func (Group) Fields() []ent.Field { Optional(). Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + + // Claude Code 客户端限制 (added by migration 029) + field.Bool("claude_code_only"). + Default(false). + Comment("是否仅允许 Claude Code 客户端"), + field.Int64("fallback_group_id"). + Optional(). + Nillable(). + Comment("非 Claude Code 请求降级使用的分组 ID"), } } @@ -101,6 +110,8 @@ func (Group) Edges() []ent.Edge { edge.From("allowed_users", User.Type). Ref("allowed_groups"). Through("user_allowed_groups", UserAllowedGroup.Type), + // 注意:fallback_group_id 直接作为字段使用,不定义 edge + // 这样允许多个分组指向同一个降级分组(M2O 关系) } } diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 182d26d0..e2f2408b 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -37,6 +37,8 @@ type CreateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` } // UpdateGroupRequest represents update group request @@ -55,6 +57,8 @@ type UpdateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + ClaudeCodeOnly *bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` } // List handles listing all groups with pagination @@ -150,6 +154,8 @@ func (h *GroupHandler) Create(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, }) if err != nil { response.ErrorFrom(c, err) @@ -188,6 +194,8 @@ func (h *GroupHandler) Update(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 79394a50..9a672064 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -85,6 +85,8 @@ func GroupFromServiceShallow(g *service.Group) *Group { ImagePrice1K: g.ImagePrice1K, ImagePrice2K: g.ImagePrice2K, ImagePrice4K: g.ImagePrice4K, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, AccountCount: g.AccountCount, diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 140c020b..03f7080b 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -52,6 +52,10 @@ type Group struct { ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + // Claude Code 客户端限制 + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 2d8ff957..48a827f3 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -96,6 +96,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqModel := parsedReq.Model reqStream := parsedReq.Stream + // 设置 Claude Code 客户端标识到 context(用于分组限制检查) + SetClaudeCodeClientContext(c, body) + // 验证 model 必填 if reqModel == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") @@ -229,7 +232,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleConcurrencyError(c, err, "account", streamStarted) return } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { + if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } } @@ -357,7 +360,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleConcurrencyError(c, err, "account", streamStarted) return } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { + if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } } @@ -683,6 +686,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } + // 设置 Claude Code 客户端标识到 context(用于分组限制检查) + SetClaudeCodeClientContext(c, body) + // 验证 model 必填 if parsedReq.Model == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 5de519c7..0393f954 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -2,6 +2,7 @@ package handler import ( "context" + "encoding/json" "fmt" "math/rand" "net/http" @@ -13,6 +14,26 @@ import ( "github.com/gin-gonic/gin" ) +// claudeCodeValidator is a singleton validator for Claude Code client detection +var claudeCodeValidator = service.NewClaudeCodeValidator() + +// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中 +// 返回更新后的 context +func SetClaudeCodeClientContext(c *gin.Context, body []byte) { + // 解析请求体为 map + var bodyMap map[string]any + if len(body) > 0 { + _ = json.Unmarshal(body, &bodyMap) + } + + // 验证是否为 Claude Code 客户端 + isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap) + + // 更新 request context + ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode) + c.Request = c.Request.WithContext(ctx) +} + // 并发槽位等待相关常量 // // 性能优化说明: diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index fc8c7cd6..0cbe44f2 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -203,6 +203,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 3) select account (sticky session based on request body) parsedReq, _ := service.ParseGatewayRequest(body) + + // 设置 Claude Code 客户端标识到 context(用于分组限制检查) + SetClaudeCodeClientContext(c, body) + sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) sessionKey := sessionHash if sessionHash != "" { @@ -262,7 +266,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { googleError(c, http.StatusTooManyRequests, err.Error()) return } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { + if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index f76a9851..70131417 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -206,7 +206,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleConcurrencyError(c, err, "account", streamStarted) return } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil { + if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } } diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 8920ea69..3add78de 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -7,4 +7,6 @@ type Key string const ( // ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置 ForcePlatform Key = "ctx_force_platform" + // IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置 + IsClaudeCodeClient Key = "ctx_is_claude_code_client" ) diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 4384bff5..f3b07616 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -325,6 +325,8 @@ func groupEntityToService(g *dbent.Group) *service.Group { ImagePrice2K: g.ImagePrice2k, ImagePrice4K: g.ImagePrice4k, DefaultValidityDays: g.DefaultValidityDays, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 4ed47e9b..40a9ad05 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -2,6 +2,7 @@ package repository import ( "context" + "fmt" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache { return &gatewayCache{rdb: rdb} } -func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) { - key := stickySessionPrefix + sessionHash +// buildSessionKey 构建 session key,包含 groupID 实现分组隔离 +// 格式: sticky_session:{groupID}:{sessionHash} +func buildSessionKey(groupID int64, sessionHash string) string { + return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash) +} + +func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { + key := buildSessionKey(groupID, sessionHash) return c.rdb.Get(ctx, key).Int64() } -func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error { - key := stickySessionPrefix + sessionHash +func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { + key := buildSessionKey(groupID, sessionHash) return c.rdb.Set(ctx, key, accountID, ttl).Err() } -func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error { - key := stickySessionPrefix + sessionHash +func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { + key := buildSessionKey(groupID, sessionHash) return c.rdb.Expire(ctx, key, ttl).Err() } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 729c1404..1fb4ae90 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -46,7 +46,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). - SetDefaultValidityDays(groupIn.DefaultValidityDays) + SetDefaultValidityDays(groupIn.DefaultValidityDays). + SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). + SetNillableFallbackGroupID(groupIn.FallbackGroupID) created, err := builder.Save(ctx) if err == nil { @@ -72,7 +74,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group } func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error { - updated, err := r.client.Group.UpdateOneID(groupIn.ID). + builder := r.client.Group.UpdateOneID(groupIn.ID). SetName(groupIn.Name). SetDescription(groupIn.Description). SetPlatform(groupIn.Platform). @@ -87,7 +89,16 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). SetDefaultValidityDays(groupIn.DefaultValidityDays). - Save(ctx) + SetClaudeCodeOnly(groupIn.ClaudeCodeOnly) + + // 处理 FallbackGroupID:nil 时清除,否则设置 + if groupIn.FallbackGroupID != nil { + builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID) + } else { + builder = builder.ClearFallbackGroupID() + } + + updated, err := builder.Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 0f2cf998..d6283fbe 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -103,6 +103,8 @@ type CreateGroupInput struct { ImagePrice1K *float64 ImagePrice2K *float64 ImagePrice4K *float64 + ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID } type UpdateGroupInput struct { @@ -120,6 +122,8 @@ type UpdateGroupInput struct { ImagePrice1K *float64 ImagePrice2K *float64 ImagePrice4K *float64 + ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID } type CreateAccountInput struct { @@ -516,6 +520,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice4K := normalizePrice(input.ImagePrice4K) + // 校验降级分组 + if input.FallbackGroupID != nil { + if err := s.validateFallbackGroup(ctx, 0, *input.FallbackGroupID); err != nil { + return nil, err + } + } + group := &Group{ Name: input.Name, Description: input.Description, @@ -530,6 +541,8 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ImagePrice1K: imagePrice1K, ImagePrice2K: imagePrice2K, ImagePrice4K: imagePrice4K, + ClaudeCodeOnly: input.ClaudeCodeOnly, + FallbackGroupID: input.FallbackGroupID, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -553,6 +566,29 @@ func normalizePrice(price *float64) *float64 { return price } +// validateFallbackGroup 校验降级分组的有效性 +// currentGroupID: 当前分组 ID(新建时为 0) +// fallbackGroupID: 降级分组 ID +func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGroupID, fallbackGroupID int64) error { + // 不能将自己设置为降级分组 + if currentGroupID > 0 && currentGroupID == fallbackGroupID { + return fmt.Errorf("cannot set self as fallback group") + } + + // 检查降级分组是否存在 + fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID) + if err != nil { + return fmt.Errorf("fallback group not found: %w", err) + } + + // 降级分组不能启用 claude_code_only,否则会造成死循环 + if fallbackGroup.ClaudeCodeOnly { + return fmt.Errorf("fallback group cannot have claude_code_only enabled") + } + + return nil +} + func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { @@ -603,6 +639,23 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.ImagePrice4K = normalizePrice(input.ImagePrice4K) } + // Claude Code 客户端限制 + if input.ClaudeCodeOnly != nil { + group.ClaudeCodeOnly = *input.ClaudeCodeOnly + } + if input.FallbackGroupID != nil { + // 校验降级分组 + if *input.FallbackGroupID > 0 { + if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil { + return nil, err + } + group.FallbackGroupID = input.FallbackGroupID + } else { + // 传入 0 或负数表示清除降级分组 + group.FallbackGroupID = nil + } + } + if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go new file mode 100644 index 00000000..ab86f1e8 --- /dev/null +++ b/backend/internal/service/claude_code_validator.go @@ -0,0 +1,265 @@ +package service + +import ( + "context" + "net/http" + "regexp" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" +) + +// ClaudeCodeValidator 验证请求是否来自 Claude Code 客户端 +// 完全学习自 claude-relay-service 项目的验证逻辑 +type ClaudeCodeValidator struct{} + +var ( + // User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感) + claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`) + + // metadata.user_id 格式: user_{64位hex}_account__session_{uuid} + userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`) + + // System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致) + systemPromptThreshold = 0.5 +) + +// Claude Code 官方 System Prompt 模板 +// 从 claude-relay-service/src/utils/contents.js 提取 +var claudeCodeSystemPrompts = []string{ + // claudeOtherSystemPrompt1 - Primary + "You are Claude Code, Anthropic's official CLI for Claude.", + + // claudeOtherSystemPrompt3 - Agent SDK + "You are a Claude agent, built on Anthropic's Claude Agent SDK.", + + // claudeOtherSystemPrompt4 - Compact Agent SDK + "You are Claude Code, Anthropic's official CLI for Claude, running within the Claude Agent SDK.", + + // exploreAgentSystemPrompt + "You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.", + + // claudeOtherSystemPromptCompact - Compact (用于对话摘要) + "You are a helpful AI assistant tasked with summarizing conversations.", + + // claudeOtherSystemPrompt2 - Secondary (长提示词的关键部分) + "You are an interactive CLI tool that helps users", +} + +// NewClaudeCodeValidator 创建验证器实例 +func NewClaudeCodeValidator() *ClaudeCodeValidator { + return &ClaudeCodeValidator{} +} + +// Validate 验证请求是否来自 Claude Code CLI +// 采用与 claude-relay-service 完全一致的验证策略: +// +// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x +// Step 2: 对于非 messages 路径,只要 UA 匹配就通过 +// Step 3: 对于 messages 路径,进行严格验证: +// - System prompt 相似度检查 +// - X-App header 检查 +// - anthropic-beta header 检查 +// - anthropic-version header 检查 +// - metadata.user_id 格式验证 +func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) bool { + // Step 1: User-Agent 检查 + ua := r.Header.Get("User-Agent") + if !claudeCodeUAPattern.MatchString(ua) { + return false + } + + // Step 2: 非 messages 路径,只要 UA 匹配就通过 + path := r.URL.Path + if !strings.Contains(path, "messages") { + return true + } + + // Step 3: messages 路径,进行严格验证 + + // 3.1 检查 system prompt 相似度 + if !v.hasClaudeCodeSystemPrompt(body) { + return false + } + + // 3.2 检查必需的 headers(值不为空即可) + xApp := r.Header.Get("X-App") + if xApp == "" { + return false + } + + anthropicBeta := r.Header.Get("anthropic-beta") + if anthropicBeta == "" { + return false + } + + anthropicVersion := r.Header.Get("anthropic-version") + if anthropicVersion == "" { + return false + } + + // 3.3 验证 metadata.user_id + if body == nil { + return false + } + + metadata, ok := body["metadata"].(map[string]any) + if !ok { + return false + } + + userID, ok := metadata["user_id"].(string) + if !ok || userID == "" { + return false + } + + if !userIDPattern.MatchString(userID) { + return false + } + + return true +} + +// hasClaudeCodeSystemPrompt 检查请求是否包含 Claude Code 系统提示词 +// 使用字符串相似度匹配(Dice coefficient) +func (v *ClaudeCodeValidator) hasClaudeCodeSystemPrompt(body map[string]any) bool { + if body == nil { + return false + } + + // 检查 model 字段 + if _, ok := body["model"].(string); !ok { + return false + } + + // 获取 system 字段 + systemEntries, ok := body["system"].([]any) + if !ok { + return false + } + + // 检查每个 system entry + for _, entry := range systemEntries { + entryMap, ok := entry.(map[string]any) + if !ok { + continue + } + + text, ok := entryMap["text"].(string) + if !ok || text == "" { + continue + } + + // 计算与所有模板的最佳相似度 + bestScore := v.bestSimilarityScore(text) + if bestScore >= systemPromptThreshold { + return true + } + } + + return false +} + +// bestSimilarityScore 计算文本与所有 Claude Code 模板的最佳相似度 +func (v *ClaudeCodeValidator) bestSimilarityScore(text string) float64 { + normalizedText := normalizePrompt(text) + bestScore := 0.0 + + for _, template := range claudeCodeSystemPrompts { + normalizedTemplate := normalizePrompt(template) + score := diceCoefficient(normalizedText, normalizedTemplate) + if score > bestScore { + bestScore = score + } + } + + return bestScore +} + +// normalizePrompt 标准化提示词文本(去除多余空白) +func normalizePrompt(text string) string { + // 将所有空白字符替换为单个空格,并去除首尾空白 + return strings.Join(strings.Fields(text), " ") +} + +// diceCoefficient 计算两个字符串的 Dice 系数(Sørensen–Dice coefficient) +// 这是 string-similarity 库使用的算法 +// 公式: 2 * |intersection| / (|bigrams(a)| + |bigrams(b)|) +func diceCoefficient(a, b string) float64 { + if a == b { + return 1.0 + } + + if len(a) < 2 || len(b) < 2 { + return 0.0 + } + + // 生成 bigrams + bigramsA := getBigrams(a) + bigramsB := getBigrams(b) + + if len(bigramsA) == 0 || len(bigramsB) == 0 { + return 0.0 + } + + // 计算交集大小 + intersection := 0 + for bigram, countA := range bigramsA { + if countB, exists := bigramsB[bigram]; exists { + if countA < countB { + intersection += countA + } else { + intersection += countB + } + } + } + + // 计算总 bigram 数量 + totalA := 0 + for _, count := range bigramsA { + totalA += count + } + totalB := 0 + for _, count := range bigramsB { + totalB += count + } + + return float64(2*intersection) / float64(totalA+totalB) +} + +// getBigrams 获取字符串的所有 bigrams(相邻字符对) +func getBigrams(s string) map[string]int { + bigrams := make(map[string]int) + runes := []rune(strings.ToLower(s)) + + for i := 0; i < len(runes)-1; i++ { + bigram := string(runes[i : i+2]) + bigrams[bigram]++ + } + + return bigrams +} + +// ValidateUserAgent 仅验证 User-Agent(用于不需要解析请求体的场景) +func (v *ClaudeCodeValidator) ValidateUserAgent(ua string) bool { + return claudeCodeUAPattern.MatchString(ua) +} + +// IncludesClaudeCodeSystemPrompt 检查请求体是否包含 Claude Code 系统提示词 +// 只要存在匹配的系统提示词就返回 true(用于宽松检测) +func (v *ClaudeCodeValidator) IncludesClaudeCodeSystemPrompt(body map[string]any) bool { + return v.hasClaudeCodeSystemPrompt(body) +} + +// IsClaudeCodeClient 从 context 中获取 Claude Code 客户端标识 +func IsClaudeCodeClient(ctx context.Context) bool { + if v, ok := ctx.Value(ctxkey.IsClaudeCodeClient).(bool); ok { + return v + } + return false +} + +// SetClaudeCodeClient 将 Claude Code 客户端标识设置到 context 中 +func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context { + return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 98c061d4..e73e9406 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -56,6 +56,9 @@ var ( } ) +// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 +var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") + // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, @@ -80,9 +83,17 @@ var allowedHeaders = map[string]bool{ // GatewayCache defines cache operations for gateway service type GatewayCache interface { - GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) - SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error - RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error + GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) + SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error + RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error +} + +// derefGroupID safely dereferences *int64 to int64, returning 0 if nil +func derefGroupID(groupID *int64) int64 { + if groupID == nil { + return 0 + } + return *groupID } type AccountWaitPlan struct { @@ -225,11 +236,11 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { } // BindStickySession sets session -> account binding with standard TTL. -func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { +func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { if sessionHash == "" || accountID <= 0 || s.cache == nil { return nil } - return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL) + return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL) } func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { @@ -356,6 +367,21 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context return nil, fmt.Errorf("get group failed: %w", err) } platform = group.Platform + + // 检查 Claude Code 客户端限制 + if group.ClaudeCodeOnly { + isClaudeCode := IsClaudeCodeClient(ctx) + if !isClaudeCode { + // 非 Claude Code 客户端,检查是否有降级分组 + if group.FallbackGroupID != nil { + // 使用降级分组重新调度 + fallbackGroupID := *group.FallbackGroupID + return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs) + } + // 无降级分组,拒绝访问 + return nil, ErrClaudeCodeOnly + } + } } else { // 无分组时只使用原生 anthropic 平台 platform = PlatformAnthropic @@ -377,10 +403,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro cfg := s.schedulingConfig() var stickyAccountID int64 if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { stickyAccountID = accountID } } + + // 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组) + groupID, err := s.checkClaudeCodeRestriction(ctx, groupID) + if err != nil { + return nil, err + } + if s.concurrencyService == nil || !cfg.LoadBatchEnabled { account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) if err != nil { @@ -443,7 +476,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // ============ Layer 1: 粘性会话优先 ============ if sessionHash != "" && s.cache != nil { - accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && accountID > 0 && !isExcluded(accountID) { account, err := s.accountRepo.GetByID(ctx, accountID) if err == nil && s.isAccountInGroup(account, groupID) && @@ -452,7 +485,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL) + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, @@ -506,7 +539,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { - if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok { + if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { return result, nil } } else { @@ -556,7 +589,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) } return &AccountSelectionResult{ Account: item.account, @@ -584,7 +617,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return nil, errors.New("no available accounts") } -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) @@ -592,7 +625,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) } return &AccountSelectionResult{ Account: acc, @@ -619,6 +652,42 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { } } +// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制 +// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端: +// - 有降级分组:返回降级分组的 ID +// - 无降级分组:返回 ErrClaudeCodeOnly 错误 +func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) { + if groupID == nil { + return groupID, nil + } + + // 强制平台模式不检查 Claude Code 限制 + if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform { + return groupID, nil + } + + group, err := s.groupRepo.GetByID(ctx, *groupID) + if err != nil { + return nil, fmt.Errorf("get group failed: %w", err) + } + + if !group.ClaudeCodeOnly { + return groupID, nil + } + + // 分组启用了 Claude Code 限制 + if IsClaudeCodeClient(ctx) { + return groupID, nil + } + + // 非 Claude Code 客户端,检查降级分组 + if group.FallbackGroupID != nil { + return group.FallbackGroupID, nil + } + + return nil, ErrClaudeCodeOnly +} + func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) { forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) if hasForcePlatform && forcePlatform != "" { @@ -738,13 +807,13 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, preferOAuth := platform == PlatformGemini // 1. 查询粘性会话 if sessionHash != "" && s.cache != nil { - accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { - if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { + if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) } return account, nil @@ -811,7 +880,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // 4. 建立粘性绑定 if sessionHash != "" && s.cache != nil { - if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } @@ -827,14 +896,14 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g // 1. 查询粘性会话 if sessionHash != "" && s.cache != nil { - accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { - if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { + if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) } return account, nil @@ -903,7 +972,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g // 4. 建立粘性绑定 if sessionHash != "" && s.cache != nil { - if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index fdf912d0..f2b5bafd 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -109,7 +109,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co cacheKey := "gemini:" + sessionHash if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey) + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey) if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) @@ -133,7 +133,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co } } if usable { - _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL) return account, nil } } @@ -217,7 +217,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co } if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, cacheKey, selected.ID, geminiStickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL) } return selected, nil diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 01b6b513..80d89074 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -22,6 +22,10 @@ type Group struct { ImagePrice2K *float64 ImagePrice4K *float64 + // Claude Code 客户端限制 + ClaudeCodeOnly bool + FallbackGroupID *int64 + CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index d744bfab..b0d34654 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -134,11 +134,11 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string { } // BindStickySession sets session -> account binding with standard TTL. -func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { +func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { if sessionHash == "" || accountID <= 0 { return nil } - return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL) + return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL) } // SelectAccount selects an OpenAI account with sticky session support @@ -155,13 +155,13 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { // 1. Check sticky session if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { // Refresh sticky session TTL - _ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL) + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL) return account, nil } } @@ -227,7 +227,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // 4. Set sticky session if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, selected.ID, openaiStickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL) } return selected, nil @@ -238,7 +238,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex cfg := s.schedulingConfig() var stickyAccountID int64 if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil { stickyAccountID = accountID } } @@ -298,14 +298,14 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex // ============ Layer 1: Sticky session ============ if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) if err == nil && accountID > 0 && !isExcluded(accountID) { account, err := s.accountRepo.GetByID(ctx, accountID) if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL) + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, @@ -362,7 +362,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: acc, @@ -415,7 +415,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: item.account, diff --git a/backend/migrations/029_add_group_claude_code_restriction.sql b/backend/migrations/029_add_group_claude_code_restriction.sql new file mode 100644 index 00000000..6185704d --- /dev/null +++ b/backend/migrations/029_add_group_claude_code_restriction.sql @@ -0,0 +1,21 @@ +-- 029_add_group_claude_code_restriction.sql +-- 添加分组级别的 Claude Code 客户端限制功能 + +-- 添加 claude_code_only 字段:是否仅允许 Claude Code 客户端 +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS claude_code_only BOOLEAN NOT NULL DEFAULT FALSE; + +-- 添加 fallback_group_id 字段:非 Claude Code 请求降级到的分组 +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS fallback_group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL; + +-- 添加索引优化查询 +CREATE INDEX IF NOT EXISTS idx_groups_claude_code_only +ON groups(claude_code_only) WHERE deleted_at IS NULL; + +CREATE INDEX IF NOT EXISTS idx_groups_fallback_group_id +ON groups(fallback_group_id) WHERE deleted_at IS NULL AND fallback_group_id IS NOT NULL; + +-- 添加字段注释 +COMMENT ON COLUMN groups.claude_code_only IS '是否仅允许 Claude Code 客户端访问此分组'; +COMMENT ON COLUMN groups.fallback_group_id IS '非 Claude Code 请求降级使用的分组 ID'; diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index f28048e2..c4cf6cc6 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -857,6 +857,15 @@ export default { imagePricing: { title: 'Image Generation Pricing', description: 'Configure pricing for gemini-3-pro-image model. Leave empty to use default prices.' + }, + claudeCode: { + title: 'Claude Code Client Restriction', + tooltip: 'When enabled, this group only allows official Claude Code clients. Non-Claude Code requests will be rejected or fallback to the specified group.', + enabled: 'Claude Code Only', + disabled: 'Allow All Clients', + fallbackGroup: 'Fallback Group', + fallbackHint: 'Non-Claude Code requests will use this group. Leave empty to reject directly.', + noFallback: 'No Fallback (Reject)' } }, diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index a042c1dc..79ddf6cc 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -934,6 +934,15 @@ export default { imagePricing: { title: '图片生成计费', description: '配置 gemini-3-pro-image 模型的图片生成价格,留空则使用默认价格' + }, + claudeCode: { + title: 'Claude Code 客户端限制', + tooltip: '启用后,此分组仅允许 Claude Code 官方客户端访问。非 Claude Code 请求将被拒绝或降级到指定分组。', + enabled: '仅限 Claude Code', + disabled: '允许所有客户端', + fallbackGroup: '降级分组', + fallbackHint: '非 Claude Code 请求将使用此分组,留空则直接拒绝', + noFallback: '不降级(直接拒绝)' } }, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 360d20c4..eaea24be 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -263,6 +263,9 @@ export interface Group { image_price_1k: number | null image_price_2k: number | null image_price_4k: number | null + // Claude Code 客户端限制 + claude_code_only: boolean + fallback_group_id: number | null account_count?: number created_at: string updated_at: string @@ -298,6 +301,15 @@ export interface CreateGroupRequest { platform?: GroupPlatform rate_multiplier?: number is_exclusive?: boolean + subscription_type?: SubscriptionType + daily_limit_usd?: number | null + weekly_limit_usd?: number | null + monthly_limit_usd?: number | null + image_price_1k?: number | null + image_price_2k?: number | null + image_price_4k?: number | null + claude_code_only?: boolean + fallback_group_id?: number | null } export interface UpdateGroupRequest { @@ -307,6 +319,15 @@ export interface UpdateGroupRequest { rate_multiplier?: number is_exclusive?: boolean status?: 'active' | 'inactive' + subscription_type?: SubscriptionType + daily_limit_usd?: number | null + weekly_limit_usd?: number | null + monthly_limit_usd?: number | null + image_price_1k?: number | null + image_price_2k?: number | null + image_price_4k?: number | null + claude_code_only?: boolean + fallback_group_id?: number | null } // ==================== Account & Proxy Types ==================== diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index b3664767..f7ef2339 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -403,6 +403,62 @@ + +
+
+ + +
+ +
+
+

+ {{ t('admin.groups.claudeCode.tooltip') }} +

+
+
+
+
+
+
+ + + {{ createForm.claude_code_only ? t('admin.groups.claudeCode.enabled') : t('admin.groups.claudeCode.disabled') }} + +
+ +
+ +