feat(groups): add Claude Code client restriction and session isolation
- Add claude_code_only field to restrict groups to Claude Code clients only - Add fallback_group_id for non-Claude Code requests to use alternate group - Implement ClaudeCodeValidator for User-Agent detection - Add group-level session binding isolation (groupID in Redis key) - Prevent cross-group sticky session pollution - Update frontend with Claude Code restriction controls
This commit is contained in:
@@ -51,6 +51,10 @@ type Group struct {
|
|||||||
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
|
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
|
||||||
// ImagePrice4k holds the value of the "image_price_4k" field.
|
// ImagePrice4k holds the value of the "image_price_4k" field.
|
||||||
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
|
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
|
||||||
|
// 是否仅允许 Claude Code 客户端
|
||||||
|
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
|
||||||
|
// 非 Claude Code 请求降级使用的分组 ID
|
||||||
|
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the GroupQuery when eager-loading is set.
|
// The values are being populated by the GroupQuery when eager-loading is set.
|
||||||
Edges GroupEdges `json:"edges"`
|
Edges GroupEdges `json:"edges"`
|
||||||
@@ -157,11 +161,11 @@ func (*Group) scanValues(columns []string) ([]any, error) {
|
|||||||
values := make([]any, len(columns))
|
values := make([]any, len(columns))
|
||||||
for i := range columns {
|
for i := range columns {
|
||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
case group.FieldIsExclusive:
|
case group.FieldIsExclusive, group.FieldClaudeCodeOnly:
|
||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case group.FieldID, group.FieldDefaultValidityDays:
|
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
@@ -298,6 +302,19 @@ func (_m *Group) assignValues(columns []string, values []any) error {
|
|||||||
_m.ImagePrice4k = new(float64)
|
_m.ImagePrice4k = new(float64)
|
||||||
*_m.ImagePrice4k = value.Float64
|
*_m.ImagePrice4k = value.Float64
|
||||||
}
|
}
|
||||||
|
case group.FieldClaudeCodeOnly:
|
||||||
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.ClaudeCodeOnly = value.Bool
|
||||||
|
}
|
||||||
|
case group.FieldFallbackGroupID:
|
||||||
|
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field fallback_group_id", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.FallbackGroupID = new(int64)
|
||||||
|
*_m.FallbackGroupID = value.Int64
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@@ -440,6 +457,14 @@ func (_m *Group) String() string {
|
|||||||
builder.WriteString("image_price_4k=")
|
builder.WriteString("image_price_4k=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
}
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("claude_code_only=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
if v := _m.FallbackGroupID; v != nil {
|
||||||
|
builder.WriteString("fallback_group_id=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
|
}
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,6 +49,10 @@ const (
|
|||||||
FieldImagePrice2k = "image_price_2k"
|
FieldImagePrice2k = "image_price_2k"
|
||||||
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
|
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
|
||||||
FieldImagePrice4k = "image_price_4k"
|
FieldImagePrice4k = "image_price_4k"
|
||||||
|
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
|
||||||
|
FieldClaudeCodeOnly = "claude_code_only"
|
||||||
|
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
|
||||||
|
FieldFallbackGroupID = "fallback_group_id"
|
||||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||||
EdgeAPIKeys = "api_keys"
|
EdgeAPIKeys = "api_keys"
|
||||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||||
@@ -141,6 +145,8 @@ var Columns = []string{
|
|||||||
FieldImagePrice1k,
|
FieldImagePrice1k,
|
||||||
FieldImagePrice2k,
|
FieldImagePrice2k,
|
||||||
FieldImagePrice4k,
|
FieldImagePrice4k,
|
||||||
|
FieldClaudeCodeOnly,
|
||||||
|
FieldFallbackGroupID,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -196,6 +202,8 @@ var (
|
|||||||
SubscriptionTypeValidator func(string) error
|
SubscriptionTypeValidator func(string) error
|
||||||
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
|
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
|
||||||
DefaultDefaultValidityDays int
|
DefaultDefaultValidityDays int
|
||||||
|
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
|
||||||
|
DefaultClaudeCodeOnly bool
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the Group queries.
|
// OrderOption defines the ordering options for the Group queries.
|
||||||
@@ -291,6 +299,16 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
|
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByClaudeCodeOnly orders the results by the claude_code_only field.
|
||||||
|
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByFallbackGroupID orders the results by the fallback_group_id field.
|
||||||
|
func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByAPIKeysCount orders the results by api_keys count.
|
// ByAPIKeysCount orders the results by api_keys count.
|
||||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -140,6 +140,16 @@ func ImagePrice4k(v float64) predicate.Group {
|
|||||||
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
|
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
|
||||||
|
func ClaudeCodeOnly(v bool) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupID applies equality check predicate on the "fallback_group_id" field. It's identical to FallbackGroupIDEQ.
|
||||||
|
func FallbackGroupID(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.Group {
|
func CreatedAtEQ(v time.Time) predicate.Group {
|
||||||
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@@ -995,6 +1005,66 @@ func ImagePrice4kNotNil() predicate.Group {
|
|||||||
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
|
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
|
||||||
|
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClaudeCodeOnlyNEQ applies the NEQ predicate on the "claude_code_only" field.
|
||||||
|
func ClaudeCodeOnlyNEQ(v bool) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNEQ(FieldClaudeCodeOnly, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDEQ applies the EQ predicate on the "fallback_group_id" field.
|
||||||
|
func FallbackGroupIDEQ(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDNEQ applies the NEQ predicate on the "fallback_group_id" field.
|
||||||
|
func FallbackGroupIDNEQ(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNEQ(FieldFallbackGroupID, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDIn applies the In predicate on the "fallback_group_id" field.
|
||||||
|
func FallbackGroupIDIn(vs ...int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldIn(FieldFallbackGroupID, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDNotIn applies the NotIn predicate on the "fallback_group_id" field.
|
||||||
|
func FallbackGroupIDNotIn(vs ...int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNotIn(FieldFallbackGroupID, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDGT applies the GT predicate on the "fallback_group_id" field.
|
||||||
|
func FallbackGroupIDGT(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldGT(FieldFallbackGroupID, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDGTE applies the GTE predicate on the "fallback_group_id" field.
|
||||||
|
func FallbackGroupIDGTE(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldGTE(FieldFallbackGroupID, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDLT applies the LT predicate on the "fallback_group_id" field.
|
||||||
|
func FallbackGroupIDLT(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldLT(FieldFallbackGroupID, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDLTE applies the LTE predicate on the "fallback_group_id" field.
|
||||||
|
func FallbackGroupIDLTE(v int64) predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldLTE(FieldFallbackGroupID, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDIsNil applies the IsNil predicate on the "fallback_group_id" field.
|
||||||
|
func FallbackGroupIDIsNil() predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldIsNull(FieldFallbackGroupID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDNotNil applies the NotNil predicate on the "fallback_group_id" field.
|
||||||
|
func FallbackGroupIDNotNil() predicate.Group {
|
||||||
|
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
|
||||||
|
}
|
||||||
|
|
||||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||||
func HasAPIKeys() predicate.Group {
|
func HasAPIKeys() predicate.Group {
|
||||||
return predicate.Group(func(s *sql.Selector) {
|
return predicate.Group(func(s *sql.Selector) {
|
||||||
|
|||||||
@@ -258,6 +258,34 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
|
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
|
||||||
|
_c.mutation.SetClaudeCodeOnly(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
|
||||||
|
func (_c *GroupCreate) SetNillableClaudeCodeOnly(v *bool) *GroupCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetClaudeCodeOnly(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||||
|
func (_c *GroupCreate) SetFallbackGroupID(v int64) *GroupCreate {
|
||||||
|
_c.mutation.SetFallbackGroupID(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
|
||||||
|
func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetFallbackGroupID(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
|
||||||
_c.mutation.AddAPIKeyIDs(ids...)
|
_c.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -423,6 +451,10 @@ func (_c *GroupCreate) defaults() error {
|
|||||||
v := group.DefaultDefaultValidityDays
|
v := group.DefaultDefaultValidityDays
|
||||||
_c.mutation.SetDefaultValidityDays(v)
|
_c.mutation.SetDefaultValidityDays(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||||
|
v := group.DefaultClaudeCodeOnly
|
||||||
|
_c.mutation.SetClaudeCodeOnly(v)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -475,6 +507,9 @@ func (_c *GroupCreate) check() error {
|
|||||||
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
|
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
|
||||||
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
|
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
|
||||||
|
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -570,6 +605,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
|
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
|
||||||
_node.ImagePrice4k = &value
|
_node.ImagePrice4k = &value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
|
||||||
|
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||||
|
_node.ClaudeCodeOnly = value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.FallbackGroupID(); ok {
|
||||||
|
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||||
|
_node.FallbackGroupID = &value
|
||||||
|
}
|
||||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1014,6 +1057,42 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
|
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
|
||||||
|
u.Set(group.FieldClaudeCodeOnly, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsert) UpdateClaudeCodeOnly() *GroupUpsert {
|
||||||
|
u.SetExcluded(group.FieldClaudeCodeOnly)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||||
|
func (u *GroupUpsert) SetFallbackGroupID(v int64) *GroupUpsert {
|
||||||
|
u.Set(group.FieldFallbackGroupID, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsert) UpdateFallbackGroupID() *GroupUpsert {
|
||||||
|
u.SetExcluded(group.FieldFallbackGroupID)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupID adds v to the "fallback_group_id" field.
|
||||||
|
func (u *GroupUpsert) AddFallbackGroupID(v int64) *GroupUpsert {
|
||||||
|
u.Add(group.FieldFallbackGroupID, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||||
|
func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
|
||||||
|
u.SetNull(group.FieldFallbackGroupID)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@@ -1395,6 +1474,48 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
|
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetClaudeCodeOnly(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertOne) UpdateClaudeCodeOnly() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateClaudeCodeOnly()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||||
|
func (u *GroupUpsertOne) SetFallbackGroupID(v int64) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetFallbackGroupID(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupID adds v to the "fallback_group_id" field.
|
||||||
|
func (u *GroupUpsertOne) AddFallbackGroupID(v int64) *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.AddFallbackGroupID(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertOne) UpdateFallbackGroupID() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateFallbackGroupID()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||||
|
func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.ClearFallbackGroupID()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@@ -1942,6 +2063,48 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
|
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetClaudeCodeOnly(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertBulk) UpdateClaudeCodeOnly() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateClaudeCodeOnly()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||||
|
func (u *GroupUpsertBulk) SetFallbackGroupID(v int64) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.SetFallbackGroupID(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupID adds v to the "fallback_group_id" field.
|
||||||
|
func (u *GroupUpsertBulk) AddFallbackGroupID(v int64) *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.AddFallbackGroupID(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create.
|
||||||
|
func (u *GroupUpsertBulk) UpdateFallbackGroupID() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.UpdateFallbackGroupID()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||||
|
func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
|
||||||
|
return u.Update(func(s *GroupUpsert) {
|
||||||
|
s.ClearFallbackGroupID()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@@ -354,6 +354,47 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
|
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
|
||||||
|
_u.mutation.SetClaudeCodeOnly(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdate) SetNillableClaudeCodeOnly(v *bool) *GroupUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetClaudeCodeOnly(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||||
|
func (_u *GroupUpdate) SetFallbackGroupID(v int64) *GroupUpdate {
|
||||||
|
_u.mutation.ResetFallbackGroupID()
|
||||||
|
_u.mutation.SetFallbackGroupID(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdate) SetNillableFallbackGroupID(v *int64) *GroupUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetFallbackGroupID(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupID adds value to the "fallback_group_id" field.
|
||||||
|
func (_u *GroupUpdate) AddFallbackGroupID(v int64) *GroupUpdate {
|
||||||
|
_u.mutation.AddFallbackGroupID(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||||
|
func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
|
||||||
|
_u.mutation.ClearFallbackGroupID()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -750,6 +791,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.ImagePrice4kCleared() {
|
if _u.mutation.ImagePrice4kCleared() {
|
||||||
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
||||||
|
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.FallbackGroupID(); ok {
|
||||||
|
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedFallbackGroupID(); ok {
|
||||||
|
_spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.FallbackGroupIDCleared() {
|
||||||
|
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@@ -1384,6 +1437,47 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
|
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
|
||||||
|
_u.mutation.SetClaudeCodeOnly(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdateOne) SetNillableClaudeCodeOnly(v *bool) *GroupUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetClaudeCodeOnly(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||||
|
func (_u *GroupUpdateOne) SetFallbackGroupID(v int64) *GroupUpdateOne {
|
||||||
|
_u.mutation.ResetFallbackGroupID()
|
||||||
|
_u.mutation.SetFallbackGroupID(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil.
|
||||||
|
func (_u *GroupUpdateOne) SetNillableFallbackGroupID(v *int64) *GroupUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetFallbackGroupID(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupID adds value to the "fallback_group_id" field.
|
||||||
|
func (_u *GroupUpdateOne) AddFallbackGroupID(v int64) *GroupUpdateOne {
|
||||||
|
_u.mutation.AddFallbackGroupID(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||||
|
func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
|
||||||
|
_u.mutation.ClearFallbackGroupID()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@@ -1810,6 +1904,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
|
|||||||
if _u.mutation.ImagePrice4kCleared() {
|
if _u.mutation.ImagePrice4kCleared() {
|
||||||
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
|
||||||
|
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.FallbackGroupID(); ok {
|
||||||
|
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedFallbackGroupID(); ok {
|
||||||
|
_spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.FallbackGroupIDCleared() {
|
||||||
|
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
|
|||||||
@@ -221,6 +221,8 @@ var (
|
|||||||
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
|
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
|
||||||
|
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
|
||||||
}
|
}
|
||||||
// GroupsTable holds the schema information for the "groups" table.
|
// GroupsTable holds the schema information for the "groups" table.
|
||||||
GroupsTable = &schema.Table{
|
GroupsTable = &schema.Table{
|
||||||
|
|||||||
@@ -3590,6 +3590,9 @@ type GroupMutation struct {
|
|||||||
addimage_price_2k *float64
|
addimage_price_2k *float64
|
||||||
image_price_4k *float64
|
image_price_4k *float64
|
||||||
addimage_price_4k *float64
|
addimage_price_4k *float64
|
||||||
|
claude_code_only *bool
|
||||||
|
fallback_group_id *int64
|
||||||
|
addfallback_group_id *int64
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
api_keys map[int64]struct{}
|
api_keys map[int64]struct{}
|
||||||
removedapi_keys map[int64]struct{}
|
removedapi_keys map[int64]struct{}
|
||||||
@@ -4594,6 +4597,112 @@ func (m *GroupMutation) ResetImagePrice4k() {
|
|||||||
delete(m.clearedFields, group.FieldImagePrice4k)
|
delete(m.clearedFields, group.FieldImagePrice4k)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetClaudeCodeOnly sets the "claude_code_only" field.
|
||||||
|
func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
|
||||||
|
m.claude_code_only = &b
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation.
|
||||||
|
func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) {
|
||||||
|
v := m.claude_code_only
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity.
|
||||||
|
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.ClaudeCodeOnly, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field.
|
||||||
|
func (m *GroupMutation) ResetClaudeCodeOnly() {
|
||||||
|
m.claude_code_only = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFallbackGroupID sets the "fallback_group_id" field.
|
||||||
|
func (m *GroupMutation) SetFallbackGroupID(i int64) {
|
||||||
|
m.fallback_group_id = &i
|
||||||
|
m.addfallback_group_id = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation.
|
||||||
|
func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) {
|
||||||
|
v := m.fallback_group_id
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity.
|
||||||
|
// If the Group object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldFallbackGroupID requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.FallbackGroupID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFallbackGroupID adds i to the "fallback_group_id" field.
|
||||||
|
func (m *GroupMutation) AddFallbackGroupID(i int64) {
|
||||||
|
if m.addfallback_group_id != nil {
|
||||||
|
*m.addfallback_group_id += i
|
||||||
|
} else {
|
||||||
|
m.addfallback_group_id = &i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation.
|
||||||
|
func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) {
|
||||||
|
v := m.addfallback_group_id
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearFallbackGroupID clears the value of the "fallback_group_id" field.
|
||||||
|
func (m *GroupMutation) ClearFallbackGroupID() {
|
||||||
|
m.fallback_group_id = nil
|
||||||
|
m.addfallback_group_id = nil
|
||||||
|
m.clearedFields[group.FieldFallbackGroupID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation.
|
||||||
|
func (m *GroupMutation) FallbackGroupIDCleared() bool {
|
||||||
|
_, ok := m.clearedFields[group.FieldFallbackGroupID]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetFallbackGroupID resets all changes to the "fallback_group_id" field.
|
||||||
|
func (m *GroupMutation) ResetFallbackGroupID() {
|
||||||
|
m.fallback_group_id = nil
|
||||||
|
m.addfallback_group_id = nil
|
||||||
|
delete(m.clearedFields, group.FieldFallbackGroupID)
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||||
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
|
||||||
if m.api_keys == nil {
|
if m.api_keys == nil {
|
||||||
@@ -4952,7 +5061,7 @@ func (m *GroupMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *GroupMutation) Fields() []string {
|
func (m *GroupMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 17)
|
fields := make([]string, 0, 19)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, group.FieldCreatedAt)
|
fields = append(fields, group.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@@ -5004,6 +5113,12 @@ func (m *GroupMutation) Fields() []string {
|
|||||||
if m.image_price_4k != nil {
|
if m.image_price_4k != nil {
|
||||||
fields = append(fields, group.FieldImagePrice4k)
|
fields = append(fields, group.FieldImagePrice4k)
|
||||||
}
|
}
|
||||||
|
if m.claude_code_only != nil {
|
||||||
|
fields = append(fields, group.FieldClaudeCodeOnly)
|
||||||
|
}
|
||||||
|
if m.fallback_group_id != nil {
|
||||||
|
fields = append(fields, group.FieldFallbackGroupID)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -5046,6 +5161,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.ImagePrice2k()
|
return m.ImagePrice2k()
|
||||||
case group.FieldImagePrice4k:
|
case group.FieldImagePrice4k:
|
||||||
return m.ImagePrice4k()
|
return m.ImagePrice4k()
|
||||||
|
case group.FieldClaudeCodeOnly:
|
||||||
|
return m.ClaudeCodeOnly()
|
||||||
|
case group.FieldFallbackGroupID:
|
||||||
|
return m.FallbackGroupID()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -5089,6 +5208,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
|
|||||||
return m.OldImagePrice2k(ctx)
|
return m.OldImagePrice2k(ctx)
|
||||||
case group.FieldImagePrice4k:
|
case group.FieldImagePrice4k:
|
||||||
return m.OldImagePrice4k(ctx)
|
return m.OldImagePrice4k(ctx)
|
||||||
|
case group.FieldClaudeCodeOnly:
|
||||||
|
return m.OldClaudeCodeOnly(ctx)
|
||||||
|
case group.FieldFallbackGroupID:
|
||||||
|
return m.OldFallbackGroupID(ctx)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown Group field %s", name)
|
return nil, fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -5217,6 +5340,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetImagePrice4k(v)
|
m.SetImagePrice4k(v)
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldClaudeCodeOnly:
|
||||||
|
v, ok := value.(bool)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetClaudeCodeOnly(v)
|
||||||
|
return nil
|
||||||
|
case group.FieldFallbackGroupID:
|
||||||
|
v, ok := value.(int64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetFallbackGroupID(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group field %s", name)
|
return fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
@@ -5249,6 +5386,9 @@ func (m *GroupMutation) AddedFields() []string {
|
|||||||
if m.addimage_price_4k != nil {
|
if m.addimage_price_4k != nil {
|
||||||
fields = append(fields, group.FieldImagePrice4k)
|
fields = append(fields, group.FieldImagePrice4k)
|
||||||
}
|
}
|
||||||
|
if m.addfallback_group_id != nil {
|
||||||
|
fields = append(fields, group.FieldFallbackGroupID)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -5273,6 +5413,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
|
|||||||
return m.AddedImagePrice2k()
|
return m.AddedImagePrice2k()
|
||||||
case group.FieldImagePrice4k:
|
case group.FieldImagePrice4k:
|
||||||
return m.AddedImagePrice4k()
|
return m.AddedImagePrice4k()
|
||||||
|
case group.FieldFallbackGroupID:
|
||||||
|
return m.AddedFallbackGroupID()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@@ -5338,6 +5480,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.AddImagePrice4k(v)
|
m.AddImagePrice4k(v)
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldFallbackGroupID:
|
||||||
|
v, ok := value.(int64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.AddFallbackGroupID(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group numeric field %s", name)
|
return fmt.Errorf("unknown Group numeric field %s", name)
|
||||||
}
|
}
|
||||||
@@ -5370,6 +5519,9 @@ func (m *GroupMutation) ClearedFields() []string {
|
|||||||
if m.FieldCleared(group.FieldImagePrice4k) {
|
if m.FieldCleared(group.FieldImagePrice4k) {
|
||||||
fields = append(fields, group.FieldImagePrice4k)
|
fields = append(fields, group.FieldImagePrice4k)
|
||||||
}
|
}
|
||||||
|
if m.FieldCleared(group.FieldFallbackGroupID) {
|
||||||
|
fields = append(fields, group.FieldFallbackGroupID)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -5408,6 +5560,9 @@ func (m *GroupMutation) ClearField(name string) error {
|
|||||||
case group.FieldImagePrice4k:
|
case group.FieldImagePrice4k:
|
||||||
m.ClearImagePrice4k()
|
m.ClearImagePrice4k()
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldFallbackGroupID:
|
||||||
|
m.ClearFallbackGroupID()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group nullable field %s", name)
|
return fmt.Errorf("unknown Group nullable field %s", name)
|
||||||
}
|
}
|
||||||
@@ -5467,6 +5622,12 @@ func (m *GroupMutation) ResetField(name string) error {
|
|||||||
case group.FieldImagePrice4k:
|
case group.FieldImagePrice4k:
|
||||||
m.ResetImagePrice4k()
|
m.ResetImagePrice4k()
|
||||||
return nil
|
return nil
|
||||||
|
case group.FieldClaudeCodeOnly:
|
||||||
|
m.ResetClaudeCodeOnly()
|
||||||
|
return nil
|
||||||
|
case group.FieldFallbackGroupID:
|
||||||
|
m.ResetFallbackGroupID()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown Group field %s", name)
|
return fmt.Errorf("unknown Group field %s", name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -270,6 +270,10 @@ func init() {
|
|||||||
groupDescDefaultValidityDays := groupFields[10].Descriptor()
|
groupDescDefaultValidityDays := groupFields[10].Descriptor()
|
||||||
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
|
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
|
||||||
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
|
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
|
||||||
|
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
|
||||||
|
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
|
||||||
|
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
|
||||||
|
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
|
||||||
proxyMixin := schema.Proxy{}.Mixin()
|
proxyMixin := schema.Proxy{}.Mixin()
|
||||||
proxyMixinHooks1 := proxyMixin[1].Hooks()
|
proxyMixinHooks1 := proxyMixin[1].Hooks()
|
||||||
proxy.Hooks[0] = proxyMixinHooks1[0]
|
proxy.Hooks[0] = proxyMixinHooks1[0]
|
||||||
|
|||||||
@@ -86,6 +86,15 @@ func (Group) Fields() []ent.Field {
|
|||||||
Optional().
|
Optional().
|
||||||
Nillable().
|
Nillable().
|
||||||
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
|
||||||
|
|
||||||
|
// Claude Code 客户端限制 (added by migration 029)
|
||||||
|
field.Bool("claude_code_only").
|
||||||
|
Default(false).
|
||||||
|
Comment("是否仅允许 Claude Code 客户端"),
|
||||||
|
field.Int64("fallback_group_id").
|
||||||
|
Optional().
|
||||||
|
Nillable().
|
||||||
|
Comment("非 Claude Code 请求降级使用的分组 ID"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,6 +110,8 @@ func (Group) Edges() []ent.Edge {
|
|||||||
edge.From("allowed_users", User.Type).
|
edge.From("allowed_users", User.Type).
|
||||||
Ref("allowed_groups").
|
Ref("allowed_groups").
|
||||||
Through("user_allowed_groups", UserAllowedGroup.Type),
|
Through("user_allowed_groups", UserAllowedGroup.Type),
|
||||||
|
// 注意:fallback_group_id 直接作为字段使用,不定义 edge
|
||||||
|
// 这样允许多个分组指向同一个降级分组(M2O 关系)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,8 @@ type CreateGroupRequest struct {
|
|||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateGroupRequest represents update group request
|
// UpdateGroupRequest represents update group request
|
||||||
@@ -55,6 +57,8 @@ type UpdateGroupRequest struct {
|
|||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
|
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||||
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// List handles listing all groups with pagination
|
// List handles listing all groups with pagination
|
||||||
@@ -150,6 +154,8 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
|||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
|
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||||
|
FallbackGroupID: req.FallbackGroupID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
@@ -188,6 +194,8 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
|||||||
ImagePrice1K: req.ImagePrice1K,
|
ImagePrice1K: req.ImagePrice1K,
|
||||||
ImagePrice2K: req.ImagePrice2K,
|
ImagePrice2K: req.ImagePrice2K,
|
||||||
ImagePrice4K: req.ImagePrice4K,
|
ImagePrice4K: req.ImagePrice4K,
|
||||||
|
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||||
|
FallbackGroupID: req.FallbackGroupID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
|
|||||||
@@ -85,6 +85,8 @@ func GroupFromServiceShallow(g *service.Group) *Group {
|
|||||||
ImagePrice1K: g.ImagePrice1K,
|
ImagePrice1K: g.ImagePrice1K,
|
||||||
ImagePrice2K: g.ImagePrice2K,
|
ImagePrice2K: g.ImagePrice2K,
|
||||||
ImagePrice4K: g.ImagePrice4K,
|
ImagePrice4K: g.ImagePrice4K,
|
||||||
|
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||||
|
FallbackGroupID: g.FallbackGroupID,
|
||||||
CreatedAt: g.CreatedAt,
|
CreatedAt: g.CreatedAt,
|
||||||
UpdatedAt: g.UpdatedAt,
|
UpdatedAt: g.UpdatedAt,
|
||||||
AccountCount: g.AccountCount,
|
AccountCount: g.AccountCount,
|
||||||
|
|||||||
@@ -52,6 +52,10 @@ type Group struct {
|
|||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
|
|
||||||
|
// Claude Code 客户端限制
|
||||||
|
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||||
|
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||||
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
|
||||||
|
|||||||
@@ -96,6 +96,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
reqModel := parsedReq.Model
|
reqModel := parsedReq.Model
|
||||||
reqStream := parsedReq.Stream
|
reqStream := parsedReq.Stream
|
||||||
|
|
||||||
|
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
||||||
|
SetClaudeCodeClientContext(c, body)
|
||||||
|
|
||||||
// 验证 model 必填
|
// 验证 model 必填
|
||||||
if reqModel == "" {
|
if reqModel == "" {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||||
@@ -229,7 +232,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||||
log.Printf("Bind sticky session failed: %v", err)
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -357,7 +360,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||||
log.Printf("Bind sticky session failed: %v", err)
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -683,6 +686,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
||||||
|
SetClaudeCodeClientContext(c, body)
|
||||||
|
|
||||||
// 验证 model 必填
|
// 验证 model 必填
|
||||||
if parsedReq.Model == "" {
|
if parsedReq.Model == "" {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -13,6 +14,26 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// claudeCodeValidator is a singleton validator for Claude Code client detection
|
||||||
|
var claudeCodeValidator = service.NewClaudeCodeValidator()
|
||||||
|
|
||||||
|
// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中
|
||||||
|
// 返回更新后的 context
|
||||||
|
func SetClaudeCodeClientContext(c *gin.Context, body []byte) {
|
||||||
|
// 解析请求体为 map
|
||||||
|
var bodyMap map[string]any
|
||||||
|
if len(body) > 0 {
|
||||||
|
_ = json.Unmarshal(body, &bodyMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证是否为 Claude Code 客户端
|
||||||
|
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap)
|
||||||
|
|
||||||
|
// 更新 request context
|
||||||
|
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
|
||||||
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
// 并发槽位等待相关常量
|
// 并发槽位等待相关常量
|
||||||
//
|
//
|
||||||
// 性能优化说明:
|
// 性能优化说明:
|
||||||
|
|||||||
@@ -203,6 +203,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
|
|
||||||
// 3) select account (sticky session based on request body)
|
// 3) select account (sticky session based on request body)
|
||||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||||
|
|
||||||
|
// 设置 Claude Code 客户端标识到 context(用于分组限制检查)
|
||||||
|
SetClaudeCodeClientContext(c, body)
|
||||||
|
|
||||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||||
sessionKey := sessionHash
|
sessionKey := sessionHash
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
@@ -262,7 +266,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
googleError(c, http.StatusTooManyRequests, err.Error())
|
googleError(c, http.StatusTooManyRequests, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil {
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||||
log.Printf("Bind sticky session failed: %v", err)
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -206,7 +206,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil {
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||||
log.Printf("Bind sticky session failed: %v", err)
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,4 +7,6 @@ type Key string
|
|||||||
const (
|
const (
|
||||||
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
|
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
|
||||||
ForcePlatform Key = "ctx_force_platform"
|
ForcePlatform Key = "ctx_force_platform"
|
||||||
|
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
|
||||||
|
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -325,6 +325,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
|||||||
ImagePrice2K: g.ImagePrice2k,
|
ImagePrice2K: g.ImagePrice2k,
|
||||||
ImagePrice4K: g.ImagePrice4k,
|
ImagePrice4K: g.ImagePrice4k,
|
||||||
DefaultValidityDays: g.DefaultValidityDays,
|
DefaultValidityDays: g.DefaultValidityDays,
|
||||||
|
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||||
|
FallbackGroupID: g.FallbackGroupID,
|
||||||
CreatedAt: g.CreatedAt,
|
CreatedAt: g.CreatedAt,
|
||||||
UpdatedAt: g.UpdatedAt,
|
UpdatedAt: g.UpdatedAt,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
|
|||||||
return &gatewayCache{rdb: rdb}
|
return &gatewayCache{rdb: rdb}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
|
// buildSessionKey 构建 session key,包含 groupID 实现分组隔离
|
||||||
key := stickySessionPrefix + sessionHash
|
// 格式: sticky_session:{groupID}:{sessionHash}
|
||||||
|
func buildSessionKey(groupID int64, sessionHash string) string {
|
||||||
|
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||||
|
key := buildSessionKey(groupID, sessionHash)
|
||||||
return c.rdb.Get(ctx, key).Int64()
|
return c.rdb.Get(ctx, key).Int64()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
|
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||||
key := stickySessionPrefix + sessionHash
|
key := buildSessionKey(groupID, sessionHash)
|
||||||
return c.rdb.Set(ctx, key, accountID, ttl).Err()
|
return c.rdb.Set(ctx, key, accountID, ttl).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
|
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||||
key := stickySessionPrefix + sessionHash
|
key := buildSessionKey(groupID, sessionHash)
|
||||||
return c.rdb.Expire(ctx, key, ttl).Err()
|
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,7 +46,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||||
SetDefaultValidityDays(groupIn.DefaultValidityDays)
|
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||||
|
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||||
|
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
|
||||||
|
|
||||||
created, err := builder.Save(ctx)
|
created, err := builder.Save(ctx)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -72,7 +74,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
|
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
|
||||||
updated, err := r.client.Group.UpdateOneID(groupIn.ID).
|
builder := r.client.Group.UpdateOneID(groupIn.ID).
|
||||||
SetName(groupIn.Name).
|
SetName(groupIn.Name).
|
||||||
SetDescription(groupIn.Description).
|
SetDescription(groupIn.Description).
|
||||||
SetPlatform(groupIn.Platform).
|
SetPlatform(groupIn.Platform).
|
||||||
@@ -87,7 +89,16 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||||
Save(ctx)
|
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
|
||||||
|
|
||||||
|
// 处理 FallbackGroupID:nil 时清除,否则设置
|
||||||
|
if groupIn.FallbackGroupID != nil {
|
||||||
|
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
|
||||||
|
} else {
|
||||||
|
builder = builder.ClearFallbackGroupID()
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := builder.Save(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -103,6 +103,8 @@ type CreateGroupInput struct {
|
|||||||
ImagePrice1K *float64
|
ImagePrice1K *float64
|
||||||
ImagePrice2K *float64
|
ImagePrice2K *float64
|
||||||
ImagePrice4K *float64
|
ImagePrice4K *float64
|
||||||
|
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||||
|
FallbackGroupID *int64 // 降级分组 ID
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateGroupInput struct {
|
type UpdateGroupInput struct {
|
||||||
@@ -120,6 +122,8 @@ type UpdateGroupInput struct {
|
|||||||
ImagePrice1K *float64
|
ImagePrice1K *float64
|
||||||
ImagePrice2K *float64
|
ImagePrice2K *float64
|
||||||
ImagePrice4K *float64
|
ImagePrice4K *float64
|
||||||
|
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||||
|
FallbackGroupID *int64 // 降级分组 ID
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateAccountInput struct {
|
type CreateAccountInput struct {
|
||||||
@@ -516,6 +520,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
imagePrice2K := normalizePrice(input.ImagePrice2K)
|
imagePrice2K := normalizePrice(input.ImagePrice2K)
|
||||||
imagePrice4K := normalizePrice(input.ImagePrice4K)
|
imagePrice4K := normalizePrice(input.ImagePrice4K)
|
||||||
|
|
||||||
|
// 校验降级分组
|
||||||
|
if input.FallbackGroupID != nil {
|
||||||
|
if err := s.validateFallbackGroup(ctx, 0, *input.FallbackGroupID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
group := &Group{
|
group := &Group{
|
||||||
Name: input.Name,
|
Name: input.Name,
|
||||||
Description: input.Description,
|
Description: input.Description,
|
||||||
@@ -530,6 +541,8 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
ImagePrice1K: imagePrice1K,
|
ImagePrice1K: imagePrice1K,
|
||||||
ImagePrice2K: imagePrice2K,
|
ImagePrice2K: imagePrice2K,
|
||||||
ImagePrice4K: imagePrice4K,
|
ImagePrice4K: imagePrice4K,
|
||||||
|
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||||
|
FallbackGroupID: input.FallbackGroupID,
|
||||||
}
|
}
|
||||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -553,6 +566,29 @@ func normalizePrice(price *float64) *float64 {
|
|||||||
return price
|
return price
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateFallbackGroup 校验降级分组的有效性
|
||||||
|
// currentGroupID: 当前分组 ID(新建时为 0)
|
||||||
|
// fallbackGroupID: 降级分组 ID
|
||||||
|
func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGroupID, fallbackGroupID int64) error {
|
||||||
|
// 不能将自己设置为降级分组
|
||||||
|
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
|
||||||
|
return fmt.Errorf("cannot set self as fallback group")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查降级分组是否存在
|
||||||
|
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fallback group not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 降级分组不能启用 claude_code_only,否则会造成死循环
|
||||||
|
if fallbackGroup.ClaudeCodeOnly {
|
||||||
|
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||||
group, err := s.groupRepo.GetByID(ctx, id)
|
group, err := s.groupRepo.GetByID(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -603,6 +639,23 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
|
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Claude Code 客户端限制
|
||||||
|
if input.ClaudeCodeOnly != nil {
|
||||||
|
group.ClaudeCodeOnly = *input.ClaudeCodeOnly
|
||||||
|
}
|
||||||
|
if input.FallbackGroupID != nil {
|
||||||
|
// 校验降级分组
|
||||||
|
if *input.FallbackGroupID > 0 {
|
||||||
|
if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
group.FallbackGroupID = input.FallbackGroupID
|
||||||
|
} else {
|
||||||
|
// 传入 0 或负数表示清除降级分组
|
||||||
|
group.FallbackGroupID = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
265
backend/internal/service/claude_code_validator.go
Normal file
265
backend/internal/service/claude_code_validator.go
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClaudeCodeValidator 验证请求是否来自 Claude Code 客户端
|
||||||
|
// 完全学习自 claude-relay-service 项目的验证逻辑
|
||||||
|
type ClaudeCodeValidator struct{}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感)
|
||||||
|
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||||
|
|
||||||
|
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||||
|
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
|
||||||
|
|
||||||
|
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
|
||||||
|
systemPromptThreshold = 0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
// Claude Code 官方 System Prompt 模板
|
||||||
|
// 从 claude-relay-service/src/utils/contents.js 提取
|
||||||
|
var claudeCodeSystemPrompts = []string{
|
||||||
|
// claudeOtherSystemPrompt1 - Primary
|
||||||
|
"You are Claude Code, Anthropic's official CLI for Claude.",
|
||||||
|
|
||||||
|
// claudeOtherSystemPrompt3 - Agent SDK
|
||||||
|
"You are a Claude agent, built on Anthropic's Claude Agent SDK.",
|
||||||
|
|
||||||
|
// claudeOtherSystemPrompt4 - Compact Agent SDK
|
||||||
|
"You are Claude Code, Anthropic's official CLI for Claude, running within the Claude Agent SDK.",
|
||||||
|
|
||||||
|
// exploreAgentSystemPrompt
|
||||||
|
"You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.",
|
||||||
|
|
||||||
|
// claudeOtherSystemPromptCompact - Compact (用于对话摘要)
|
||||||
|
"You are a helpful AI assistant tasked with summarizing conversations.",
|
||||||
|
|
||||||
|
// claudeOtherSystemPrompt2 - Secondary (长提示词的关键部分)
|
||||||
|
"You are an interactive CLI tool that helps users",
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClaudeCodeValidator 创建验证器实例
|
||||||
|
func NewClaudeCodeValidator() *ClaudeCodeValidator {
|
||||||
|
return &ClaudeCodeValidator{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate 验证请求是否来自 Claude Code CLI
|
||||||
|
// 采用与 claude-relay-service 完全一致的验证策略:
|
||||||
|
//
|
||||||
|
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
|
||||||
|
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
|
||||||
|
// Step 3: 对于 messages 路径,进行严格验证:
|
||||||
|
// - System prompt 相似度检查
|
||||||
|
// - X-App header 检查
|
||||||
|
// - anthropic-beta header 检查
|
||||||
|
// - anthropic-version header 检查
|
||||||
|
// - metadata.user_id 格式验证
|
||||||
|
func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) bool {
|
||||||
|
// Step 1: User-Agent 检查
|
||||||
|
ua := r.Header.Get("User-Agent")
|
||||||
|
if !claudeCodeUAPattern.MatchString(ua) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: 非 messages 路径,只要 UA 匹配就通过
|
||||||
|
path := r.URL.Path
|
||||||
|
if !strings.Contains(path, "messages") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: messages 路径,进行严格验证
|
||||||
|
|
||||||
|
// 3.1 检查 system prompt 相似度
|
||||||
|
if !v.hasClaudeCodeSystemPrompt(body) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3.2 检查必需的 headers(值不为空即可)
|
||||||
|
xApp := r.Header.Get("X-App")
|
||||||
|
if xApp == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
anthropicBeta := r.Header.Get("anthropic-beta")
|
||||||
|
if anthropicBeta == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
anthropicVersion := r.Header.Get("anthropic-version")
|
||||||
|
if anthropicVersion == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3.3 验证 metadata.user_id
|
||||||
|
if body == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata, ok := body["metadata"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, ok := metadata["user_id"].(string)
|
||||||
|
if !ok || userID == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !userIDPattern.MatchString(userID) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasClaudeCodeSystemPrompt 检查请求是否包含 Claude Code 系统提示词
|
||||||
|
// 使用字符串相似度匹配(Dice coefficient)
|
||||||
|
func (v *ClaudeCodeValidator) hasClaudeCodeSystemPrompt(body map[string]any) bool {
|
||||||
|
if body == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 model 字段
|
||||||
|
if _, ok := body["model"].(string); !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取 system 字段
|
||||||
|
systemEntries, ok := body["system"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查每个 system entry
|
||||||
|
for _, entry := range systemEntries {
|
||||||
|
entryMap, ok := entry.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
text, ok := entryMap["text"].(string)
|
||||||
|
if !ok || text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算与所有模板的最佳相似度
|
||||||
|
bestScore := v.bestSimilarityScore(text)
|
||||||
|
if bestScore >= systemPromptThreshold {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// bestSimilarityScore 计算文本与所有 Claude Code 模板的最佳相似度
|
||||||
|
func (v *ClaudeCodeValidator) bestSimilarityScore(text string) float64 {
|
||||||
|
normalizedText := normalizePrompt(text)
|
||||||
|
bestScore := 0.0
|
||||||
|
|
||||||
|
for _, template := range claudeCodeSystemPrompts {
|
||||||
|
normalizedTemplate := normalizePrompt(template)
|
||||||
|
score := diceCoefficient(normalizedText, normalizedTemplate)
|
||||||
|
if score > bestScore {
|
||||||
|
bestScore = score
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bestScore
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizePrompt 标准化提示词文本(去除多余空白)
|
||||||
|
func normalizePrompt(text string) string {
|
||||||
|
// 将所有空白字符替换为单个空格,并去除首尾空白
|
||||||
|
return strings.Join(strings.Fields(text), " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// diceCoefficient 计算两个字符串的 Dice 系数(Sørensen–Dice coefficient)
|
||||||
|
// 这是 string-similarity 库使用的算法
|
||||||
|
// 公式: 2 * |intersection| / (|bigrams(a)| + |bigrams(b)|)
|
||||||
|
func diceCoefficient(a, b string) float64 {
|
||||||
|
if a == b {
|
||||||
|
return 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(a) < 2 || len(b) < 2 {
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成 bigrams
|
||||||
|
bigramsA := getBigrams(a)
|
||||||
|
bigramsB := getBigrams(b)
|
||||||
|
|
||||||
|
if len(bigramsA) == 0 || len(bigramsB) == 0 {
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算交集大小
|
||||||
|
intersection := 0
|
||||||
|
for bigram, countA := range bigramsA {
|
||||||
|
if countB, exists := bigramsB[bigram]; exists {
|
||||||
|
if countA < countB {
|
||||||
|
intersection += countA
|
||||||
|
} else {
|
||||||
|
intersection += countB
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算总 bigram 数量
|
||||||
|
totalA := 0
|
||||||
|
for _, count := range bigramsA {
|
||||||
|
totalA += count
|
||||||
|
}
|
||||||
|
totalB := 0
|
||||||
|
for _, count := range bigramsB {
|
||||||
|
totalB += count
|
||||||
|
}
|
||||||
|
|
||||||
|
return float64(2*intersection) / float64(totalA+totalB)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getBigrams 获取字符串的所有 bigrams(相邻字符对)
|
||||||
|
func getBigrams(s string) map[string]int {
|
||||||
|
bigrams := make(map[string]int)
|
||||||
|
runes := []rune(strings.ToLower(s))
|
||||||
|
|
||||||
|
for i := 0; i < len(runes)-1; i++ {
|
||||||
|
bigram := string(runes[i : i+2])
|
||||||
|
bigrams[bigram]++
|
||||||
|
}
|
||||||
|
|
||||||
|
return bigrams
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateUserAgent 仅验证 User-Agent(用于不需要解析请求体的场景)
|
||||||
|
func (v *ClaudeCodeValidator) ValidateUserAgent(ua string) bool {
|
||||||
|
return claudeCodeUAPattern.MatchString(ua)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncludesClaudeCodeSystemPrompt 检查请求体是否包含 Claude Code 系统提示词
|
||||||
|
// 只要存在匹配的系统提示词就返回 true(用于宽松检测)
|
||||||
|
func (v *ClaudeCodeValidator) IncludesClaudeCodeSystemPrompt(body map[string]any) bool {
|
||||||
|
return v.hasClaudeCodeSystemPrompt(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClaudeCodeClient 从 context 中获取 Claude Code 客户端标识
|
||||||
|
func IsClaudeCodeClient(ctx context.Context) bool {
|
||||||
|
if v, ok := ctx.Value(ctxkey.IsClaudeCodeClient).(bool); ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetClaudeCodeClient 将 Claude Code 客户端标识设置到 context 中
|
||||||
|
func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context {
|
||||||
|
return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode)
|
||||||
|
}
|
||||||
@@ -56,6 +56,9 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
|
||||||
|
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
|
||||||
|
|
||||||
// allowedHeaders 白名单headers(参考CRS项目)
|
// allowedHeaders 白名单headers(参考CRS项目)
|
||||||
var allowedHeaders = map[string]bool{
|
var allowedHeaders = map[string]bool{
|
||||||
"accept": true,
|
"accept": true,
|
||||||
@@ -80,9 +83,17 @@ var allowedHeaders = map[string]bool{
|
|||||||
|
|
||||||
// GatewayCache defines cache operations for gateway service
|
// GatewayCache defines cache operations for gateway service
|
||||||
type GatewayCache interface {
|
type GatewayCache interface {
|
||||||
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
|
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
|
||||||
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
|
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
|
||||||
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
|
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
|
||||||
|
func derefGroupID(groupID *int64) int64 {
|
||||||
|
if groupID == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *groupID
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountWaitPlan struct {
|
type AccountWaitPlan struct {
|
||||||
@@ -225,11 +236,11 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// BindStickySession sets session -> account binding with standard TTL.
|
// BindStickySession sets session -> account binding with standard TTL.
|
||||||
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
|
||||||
if sessionHash == "" || accountID <= 0 || s.cache == nil {
|
if sessionHash == "" || accountID <= 0 || s.cache == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
|
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||||
@@ -356,6 +367,21 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
|||||||
return nil, fmt.Errorf("get group failed: %w", err)
|
return nil, fmt.Errorf("get group failed: %w", err)
|
||||||
}
|
}
|
||||||
platform = group.Platform
|
platform = group.Platform
|
||||||
|
|
||||||
|
// 检查 Claude Code 客户端限制
|
||||||
|
if group.ClaudeCodeOnly {
|
||||||
|
isClaudeCode := IsClaudeCodeClient(ctx)
|
||||||
|
if !isClaudeCode {
|
||||||
|
// 非 Claude Code 客户端,检查是否有降级分组
|
||||||
|
if group.FallbackGroupID != nil {
|
||||||
|
// 使用降级分组重新调度
|
||||||
|
fallbackGroupID := *group.FallbackGroupID
|
||||||
|
return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs)
|
||||||
|
}
|
||||||
|
// 无降级分组,拒绝访问
|
||||||
|
return nil, ErrClaudeCodeOnly
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// 无分组时只使用原生 anthropic 平台
|
// 无分组时只使用原生 anthropic 平台
|
||||||
platform = PlatformAnthropic
|
platform = PlatformAnthropic
|
||||||
@@ -377,10 +403,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
cfg := s.schedulingConfig()
|
cfg := s.schedulingConfig()
|
||||||
var stickyAccountID int64
|
var stickyAccountID int64
|
||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
|
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
|
||||||
stickyAccountID = accountID
|
stickyAccountID = accountID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
|
||||||
|
groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -443,7 +476,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
|
|
||||||
// ============ Layer 1: 粘性会话优先 ============
|
// ============ Layer 1: 粘性会话优先 ============
|
||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) &&
|
if err == nil && s.isAccountInGroup(account, groupID) &&
|
||||||
@@ -452,7 +485,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
Account: account,
|
Account: account,
|
||||||
Acquired: true,
|
Acquired: true,
|
||||||
@@ -506,7 +539,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
|
|
||||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
|
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -556,7 +589,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
Account: item.account,
|
Account: item.account,
|
||||||
@@ -584,7 +617,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
return nil, errors.New("no available accounts")
|
return nil, errors.New("no available accounts")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||||
ordered := append([]*Account(nil), candidates...)
|
ordered := append([]*Account(nil), candidates...)
|
||||||
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||||
|
|
||||||
@@ -592,7 +625,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
|
|||||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
Account: acc,
|
Account: acc,
|
||||||
@@ -619,6 +652,42 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
|
||||||
|
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
|
||||||
|
// - 有降级分组:返回降级分组的 ID
|
||||||
|
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
|
||||||
|
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) {
|
||||||
|
if groupID == nil {
|
||||||
|
return groupID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 强制平台模式不检查 Claude Code 限制
|
||||||
|
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
|
||||||
|
return groupID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get group failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !group.ClaudeCodeOnly {
|
||||||
|
return groupID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 分组启用了 Claude Code 限制
|
||||||
|
if IsClaudeCodeClient(ctx) {
|
||||||
|
return groupID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 非 Claude Code 客户端,检查降级分组
|
||||||
|
if group.FallbackGroupID != nil {
|
||||||
|
return group.FallbackGroupID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ErrClaudeCodeOnly
|
||||||
|
}
|
||||||
|
|
||||||
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
|
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
|
||||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||||
if hasForcePlatform && forcePlatform != "" {
|
if hasForcePlatform && forcePlatform != "" {
|
||||||
@@ -738,13 +807,13 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
preferOAuth := platform == PlatformGemini
|
preferOAuth := platform == PlatformGemini
|
||||||
// 1. 查询粘性会话
|
// 1. 查询粘性会话
|
||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||||
}
|
}
|
||||||
return account, nil
|
return account, nil
|
||||||
@@ -811,7 +880,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
|
|
||||||
// 4. 建立粘性绑定
|
// 4. 建立粘性绑定
|
||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -827,14 +896,14 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
|
|
||||||
// 1. 查询粘性会话
|
// 1. 查询粘性会话
|
||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||||
}
|
}
|
||||||
return account, nil
|
return account, nil
|
||||||
@@ -903,7 +972,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
|
|
||||||
// 4. 建立粘性绑定
|
// 4. 建立粘性绑定
|
||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
cacheKey := "gemini:" + sessionHash
|
cacheKey := "gemini:" + sessionHash
|
||||||
|
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
@@ -133,7 +133,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if usable {
|
if usable {
|
||||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -217,7 +217,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
|||||||
}
|
}
|
||||||
|
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, cacheKey, selected.ID, geminiStickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
|
||||||
}
|
}
|
||||||
|
|
||||||
return selected, nil
|
return selected, nil
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ type Group struct {
|
|||||||
ImagePrice2K *float64
|
ImagePrice2K *float64
|
||||||
ImagePrice4K *float64
|
ImagePrice4K *float64
|
||||||
|
|
||||||
|
// Claude Code 客户端限制
|
||||||
|
ClaudeCodeOnly bool
|
||||||
|
FallbackGroupID *int64
|
||||||
|
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
|
|
||||||
|
|||||||
@@ -134,11 +134,11 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// BindStickySession sets session -> account binding with standard TTL.
|
// BindStickySession sets session -> account binding with standard TTL.
|
||||||
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
|
||||||
if sessionHash == "" || accountID <= 0 {
|
if sessionHash == "" || accountID <= 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL)
|
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectAccount selects an OpenAI account with sticky session support
|
// SelectAccount selects an OpenAI account with sticky session support
|
||||||
@@ -155,13 +155,13 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
|
|||||||
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||||
// 1. Check sticky session
|
// 1. Check sticky session
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
// Refresh sticky session TTL
|
// Refresh sticky session TTL
|
||||||
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -227,7 +227,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
|||||||
|
|
||||||
// 4. Set sticky session
|
// 4. Set sticky session
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
|
||||||
}
|
}
|
||||||
|
|
||||||
return selected, nil
|
return selected, nil
|
||||||
@@ -238,7 +238,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
cfg := s.schedulingConfig()
|
cfg := s.schedulingConfig()
|
||||||
var stickyAccountID int64
|
var stickyAccountID int64
|
||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil {
|
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil {
|
||||||
stickyAccountID = accountID
|
stickyAccountID = accountID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -298,14 +298,14 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
|
|
||||||
// ============ Layer 1: Sticky session ============
|
// ============ Layer 1: Sticky session ============
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
|
||||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
||||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
Account: account,
|
Account: account,
|
||||||
Acquired: true,
|
Acquired: true,
|
||||||
@@ -362,7 +362,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
Account: acc,
|
Account: acc,
|
||||||
@@ -415,7 +415,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
|
|||||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
Account: item.account,
|
Account: item.account,
|
||||||
|
|||||||
21
backend/migrations/029_add_group_claude_code_restriction.sql
Normal file
21
backend/migrations/029_add_group_claude_code_restriction.sql
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
-- 029_add_group_claude_code_restriction.sql
|
||||||
|
-- 添加分组级别的 Claude Code 客户端限制功能
|
||||||
|
|
||||||
|
-- 添加 claude_code_only 字段:是否仅允许 Claude Code 客户端
|
||||||
|
ALTER TABLE groups
|
||||||
|
ADD COLUMN IF NOT EXISTS claude_code_only BOOLEAN NOT NULL DEFAULT FALSE;
|
||||||
|
|
||||||
|
-- 添加 fallback_group_id 字段:非 Claude Code 请求降级到的分组
|
||||||
|
ALTER TABLE groups
|
||||||
|
ADD COLUMN IF NOT EXISTS fallback_group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL;
|
||||||
|
|
||||||
|
-- 添加索引优化查询
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_groups_claude_code_only
|
||||||
|
ON groups(claude_code_only) WHERE deleted_at IS NULL;
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_groups_fallback_group_id
|
||||||
|
ON groups(fallback_group_id) WHERE deleted_at IS NULL AND fallback_group_id IS NOT NULL;
|
||||||
|
|
||||||
|
-- 添加字段注释
|
||||||
|
COMMENT ON COLUMN groups.claude_code_only IS '是否仅允许 Claude Code 客户端访问此分组';
|
||||||
|
COMMENT ON COLUMN groups.fallback_group_id IS '非 Claude Code 请求降级使用的分组 ID';
|
||||||
@@ -857,6 +857,15 @@ export default {
|
|||||||
imagePricing: {
|
imagePricing: {
|
||||||
title: 'Image Generation Pricing',
|
title: 'Image Generation Pricing',
|
||||||
description: 'Configure pricing for gemini-3-pro-image model. Leave empty to use default prices.'
|
description: 'Configure pricing for gemini-3-pro-image model. Leave empty to use default prices.'
|
||||||
|
},
|
||||||
|
claudeCode: {
|
||||||
|
title: 'Claude Code Client Restriction',
|
||||||
|
tooltip: 'When enabled, this group only allows official Claude Code clients. Non-Claude Code requests will be rejected or fallback to the specified group.',
|
||||||
|
enabled: 'Claude Code Only',
|
||||||
|
disabled: 'Allow All Clients',
|
||||||
|
fallbackGroup: 'Fallback Group',
|
||||||
|
fallbackHint: 'Non-Claude Code requests will use this group. Leave empty to reject directly.',
|
||||||
|
noFallback: 'No Fallback (Reject)'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|||||||
@@ -934,6 +934,15 @@ export default {
|
|||||||
imagePricing: {
|
imagePricing: {
|
||||||
title: '图片生成计费',
|
title: '图片生成计费',
|
||||||
description: '配置 gemini-3-pro-image 模型的图片生成价格,留空则使用默认价格'
|
description: '配置 gemini-3-pro-image 模型的图片生成价格,留空则使用默认价格'
|
||||||
|
},
|
||||||
|
claudeCode: {
|
||||||
|
title: 'Claude Code 客户端限制',
|
||||||
|
tooltip: '启用后,此分组仅允许 Claude Code 官方客户端访问。非 Claude Code 请求将被拒绝或降级到指定分组。',
|
||||||
|
enabled: '仅限 Claude Code',
|
||||||
|
disabled: '允许所有客户端',
|
||||||
|
fallbackGroup: '降级分组',
|
||||||
|
fallbackHint: '非 Claude Code 请求将使用此分组,留空则直接拒绝',
|
||||||
|
noFallback: '不降级(直接拒绝)'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
|||||||
@@ -263,6 +263,9 @@ export interface Group {
|
|||||||
image_price_1k: number | null
|
image_price_1k: number | null
|
||||||
image_price_2k: number | null
|
image_price_2k: number | null
|
||||||
image_price_4k: number | null
|
image_price_4k: number | null
|
||||||
|
// Claude Code 客户端限制
|
||||||
|
claude_code_only: boolean
|
||||||
|
fallback_group_id: number | null
|
||||||
account_count?: number
|
account_count?: number
|
||||||
created_at: string
|
created_at: string
|
||||||
updated_at: string
|
updated_at: string
|
||||||
@@ -298,6 +301,15 @@ export interface CreateGroupRequest {
|
|||||||
platform?: GroupPlatform
|
platform?: GroupPlatform
|
||||||
rate_multiplier?: number
|
rate_multiplier?: number
|
||||||
is_exclusive?: boolean
|
is_exclusive?: boolean
|
||||||
|
subscription_type?: SubscriptionType
|
||||||
|
daily_limit_usd?: number | null
|
||||||
|
weekly_limit_usd?: number | null
|
||||||
|
monthly_limit_usd?: number | null
|
||||||
|
image_price_1k?: number | null
|
||||||
|
image_price_2k?: number | null
|
||||||
|
image_price_4k?: number | null
|
||||||
|
claude_code_only?: boolean
|
||||||
|
fallback_group_id?: number | null
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface UpdateGroupRequest {
|
export interface UpdateGroupRequest {
|
||||||
@@ -307,6 +319,15 @@ export interface UpdateGroupRequest {
|
|||||||
rate_multiplier?: number
|
rate_multiplier?: number
|
||||||
is_exclusive?: boolean
|
is_exclusive?: boolean
|
||||||
status?: 'active' | 'inactive'
|
status?: 'active' | 'inactive'
|
||||||
|
subscription_type?: SubscriptionType
|
||||||
|
daily_limit_usd?: number | null
|
||||||
|
weekly_limit_usd?: number | null
|
||||||
|
monthly_limit_usd?: number | null
|
||||||
|
image_price_1k?: number | null
|
||||||
|
image_price_2k?: number | null
|
||||||
|
image_price_4k?: number | null
|
||||||
|
claude_code_only?: boolean
|
||||||
|
fallback_group_id?: number | null
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================== Account & Proxy Types ====================
|
// ==================== Account & Proxy Types ====================
|
||||||
|
|||||||
@@ -403,6 +403,62 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Claude Code 客户端限制(仅 anthropic 平台) -->
|
||||||
|
<div v-if="createForm.platform === 'anthropic'" class="border-t pt-4">
|
||||||
|
<div class="mb-1.5 flex items-center gap-1">
|
||||||
|
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
|
{{ t('admin.groups.claudeCode.title') }}
|
||||||
|
</label>
|
||||||
|
<!-- Help Tooltip -->
|
||||||
|
<div class="group relative inline-flex">
|
||||||
|
<Icon
|
||||||
|
name="questionCircle"
|
||||||
|
size="sm"
|
||||||
|
:stroke-width="2"
|
||||||
|
class="cursor-help text-gray-400 transition-colors hover:text-primary-500 dark:text-gray-500 dark:hover:text-primary-400"
|
||||||
|
/>
|
||||||
|
<div class="pointer-events-none absolute bottom-full left-0 z-50 mb-2 w-72 opacity-0 transition-all duration-200 group-hover:pointer-events-auto group-hover:opacity-100">
|
||||||
|
<div class="rounded-lg bg-gray-900 p-3 text-white shadow-lg dark:bg-gray-800">
|
||||||
|
<p class="text-xs leading-relaxed text-gray-300">
|
||||||
|
{{ t('admin.groups.claudeCode.tooltip') }}
|
||||||
|
</p>
|
||||||
|
<div class="absolute -bottom-1.5 left-3 h-3 w-3 rotate-45 bg-gray-900 dark:bg-gray-800"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="flex items-center gap-3">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="createForm.claude_code_only = !createForm.claude_code_only"
|
||||||
|
:class="[
|
||||||
|
'relative inline-flex h-6 w-11 items-center rounded-full transition-colors',
|
||||||
|
createForm.claude_code_only ? 'bg-primary-500' : 'bg-gray-300 dark:bg-dark-600'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
:class="[
|
||||||
|
'inline-block h-4 w-4 transform rounded-full bg-white shadow transition-transform',
|
||||||
|
createForm.claude_code_only ? 'translate-x-6' : 'translate-x-1'
|
||||||
|
]"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
<span class="text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
{{ createForm.claude_code_only ? t('admin.groups.claudeCode.enabled') : t('admin.groups.claudeCode.disabled') }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<!-- 降级分组选择(仅当启用 claude_code_only 时显示) -->
|
||||||
|
<div v-if="createForm.claude_code_only" class="mt-3">
|
||||||
|
<label class="input-label">{{ t('admin.groups.claudeCode.fallbackGroup') }}</label>
|
||||||
|
<Select
|
||||||
|
v-model="createForm.fallback_group_id"
|
||||||
|
:options="fallbackGroupOptions"
|
||||||
|
:placeholder="t('admin.groups.claudeCode.noFallback')"
|
||||||
|
/>
|
||||||
|
<p class="input-hint">{{ t('admin.groups.claudeCode.fallbackHint') }}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
</form>
|
</form>
|
||||||
|
|
||||||
<template #footer>
|
<template #footer>
|
||||||
@@ -648,6 +704,62 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- Claude Code 客户端限制(仅 anthropic 平台) -->
|
||||||
|
<div v-if="editForm.platform === 'anthropic'" class="border-t pt-4">
|
||||||
|
<div class="mb-1.5 flex items-center gap-1">
|
||||||
|
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||||
|
{{ t('admin.groups.claudeCode.title') }}
|
||||||
|
</label>
|
||||||
|
<!-- Help Tooltip -->
|
||||||
|
<div class="group relative inline-flex">
|
||||||
|
<Icon
|
||||||
|
name="questionCircle"
|
||||||
|
size="sm"
|
||||||
|
:stroke-width="2"
|
||||||
|
class="cursor-help text-gray-400 transition-colors hover:text-primary-500 dark:text-gray-500 dark:hover:text-primary-400"
|
||||||
|
/>
|
||||||
|
<div class="pointer-events-none absolute bottom-full left-0 z-50 mb-2 w-72 opacity-0 transition-all duration-200 group-hover:pointer-events-auto group-hover:opacity-100">
|
||||||
|
<div class="rounded-lg bg-gray-900 p-3 text-white shadow-lg dark:bg-gray-800">
|
||||||
|
<p class="text-xs leading-relaxed text-gray-300">
|
||||||
|
{{ t('admin.groups.claudeCode.tooltip') }}
|
||||||
|
</p>
|
||||||
|
<div class="absolute -bottom-1.5 left-3 h-3 w-3 rotate-45 bg-gray-900 dark:bg-gray-800"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="flex items-center gap-3">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
@click="editForm.claude_code_only = !editForm.claude_code_only"
|
||||||
|
:class="[
|
||||||
|
'relative inline-flex h-6 w-11 items-center rounded-full transition-colors',
|
||||||
|
editForm.claude_code_only ? 'bg-primary-500' : 'bg-gray-300 dark:bg-dark-600'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
:class="[
|
||||||
|
'inline-block h-4 w-4 transform rounded-full bg-white shadow transition-transform',
|
||||||
|
editForm.claude_code_only ? 'translate-x-6' : 'translate-x-1'
|
||||||
|
]"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
<span class="text-sm text-gray-500 dark:text-gray-400">
|
||||||
|
{{ editForm.claude_code_only ? t('admin.groups.claudeCode.enabled') : t('admin.groups.claudeCode.disabled') }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<!-- 降级分组选择(仅当启用 claude_code_only 时显示) -->
|
||||||
|
<div v-if="editForm.claude_code_only" class="mt-3">
|
||||||
|
<label class="input-label">{{ t('admin.groups.claudeCode.fallbackGroup') }}</label>
|
||||||
|
<Select
|
||||||
|
v-model="editForm.fallback_group_id"
|
||||||
|
:options="fallbackGroupOptionsForEdit"
|
||||||
|
:placeholder="t('admin.groups.claudeCode.noFallback')"
|
||||||
|
/>
|
||||||
|
<p class="input-hint">{{ t('admin.groups.claudeCode.fallbackHint') }}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
</form>
|
</form>
|
||||||
|
|
||||||
<template #footer>
|
<template #footer>
|
||||||
@@ -774,6 +886,35 @@ const subscriptionTypeOptions = computed(() => [
|
|||||||
{ value: 'subscription', label: t('admin.groups.subscription.subscription') }
|
{ value: 'subscription', label: t('admin.groups.subscription.subscription') }
|
||||||
])
|
])
|
||||||
|
|
||||||
|
// 降级分组选项(创建时)- 仅包含 anthropic 平台且未启用 claude_code_only 的分组
|
||||||
|
const fallbackGroupOptions = computed(() => {
|
||||||
|
const options: { value: number | null; label: string }[] = [
|
||||||
|
{ value: null, label: t('admin.groups.claudeCode.noFallback') }
|
||||||
|
]
|
||||||
|
const eligibleGroups = groups.value.filter(
|
||||||
|
(g) => g.platform === 'anthropic' && !g.claude_code_only && g.status === 'active'
|
||||||
|
)
|
||||||
|
eligibleGroups.forEach((g) => {
|
||||||
|
options.push({ value: g.id, label: g.name })
|
||||||
|
})
|
||||||
|
return options
|
||||||
|
})
|
||||||
|
|
||||||
|
// 降级分组选项(编辑时)- 排除自身
|
||||||
|
const fallbackGroupOptionsForEdit = computed(() => {
|
||||||
|
const options: { value: number | null; label: string }[] = [
|
||||||
|
{ value: null, label: t('admin.groups.claudeCode.noFallback') }
|
||||||
|
]
|
||||||
|
const currentId = editingGroup.value?.id
|
||||||
|
const eligibleGroups = groups.value.filter(
|
||||||
|
(g) => g.platform === 'anthropic' && !g.claude_code_only && g.status === 'active' && g.id !== currentId
|
||||||
|
)
|
||||||
|
eligibleGroups.forEach((g) => {
|
||||||
|
options.push({ value: g.id, label: g.name })
|
||||||
|
})
|
||||||
|
return options
|
||||||
|
})
|
||||||
|
|
||||||
const groups = ref<Group[]>([])
|
const groups = ref<Group[]>([])
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
const searchQuery = ref('')
|
const searchQuery = ref('')
|
||||||
@@ -821,7 +962,10 @@ const createForm = reactive({
|
|||||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||||
image_price_1k: null as number | null,
|
image_price_1k: null as number | null,
|
||||||
image_price_2k: null as number | null,
|
image_price_2k: null as number | null,
|
||||||
image_price_4k: null as number | null
|
image_price_4k: null as number | null,
|
||||||
|
// Claude Code 客户端限制(仅 anthropic 平台使用)
|
||||||
|
claude_code_only: false,
|
||||||
|
fallback_group_id: null as number | null
|
||||||
})
|
})
|
||||||
|
|
||||||
const editForm = reactive({
|
const editForm = reactive({
|
||||||
@@ -838,7 +982,10 @@ const editForm = reactive({
|
|||||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||||
image_price_1k: null as number | null,
|
image_price_1k: null as number | null,
|
||||||
image_price_2k: null as number | null,
|
image_price_2k: null as number | null,
|
||||||
image_price_4k: null as number | null
|
image_price_4k: null as number | null,
|
||||||
|
// Claude Code 客户端限制(仅 anthropic 平台使用)
|
||||||
|
claude_code_only: false,
|
||||||
|
fallback_group_id: null as number | null
|
||||||
})
|
})
|
||||||
|
|
||||||
// 根据分组类型返回不同的删除确认消息
|
// 根据分组类型返回不同的删除确认消息
|
||||||
@@ -908,6 +1055,8 @@ const closeCreateModal = () => {
|
|||||||
createForm.image_price_1k = null
|
createForm.image_price_1k = null
|
||||||
createForm.image_price_2k = null
|
createForm.image_price_2k = null
|
||||||
createForm.image_price_4k = null
|
createForm.image_price_4k = null
|
||||||
|
createForm.claude_code_only = false
|
||||||
|
createForm.fallback_group_id = null
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleCreateGroup = async () => {
|
const handleCreateGroup = async () => {
|
||||||
@@ -949,6 +1098,8 @@ const handleEdit = (group: Group) => {
|
|||||||
editForm.image_price_1k = group.image_price_1k
|
editForm.image_price_1k = group.image_price_1k
|
||||||
editForm.image_price_2k = group.image_price_2k
|
editForm.image_price_2k = group.image_price_2k
|
||||||
editForm.image_price_4k = group.image_price_4k
|
editForm.image_price_4k = group.image_price_4k
|
||||||
|
editForm.claude_code_only = group.claude_code_only || false
|
||||||
|
editForm.fallback_group_id = group.fallback_group_id
|
||||||
showEditModal.value = true
|
showEditModal.value = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -966,7 +1117,12 @@ const handleUpdateGroup = async () => {
|
|||||||
|
|
||||||
submitting.value = true
|
submitting.value = true
|
||||||
try {
|
try {
|
||||||
await adminAPI.groups.update(editingGroup.value.id, editForm)
|
// 转换 fallback_group_id: null -> 0 (后端使用 0 表示清除)
|
||||||
|
const payload = {
|
||||||
|
...editForm,
|
||||||
|
fallback_group_id: editForm.fallback_group_id === null ? 0 : editForm.fallback_group_id
|
||||||
|
}
|
||||||
|
await adminAPI.groups.update(editingGroup.value.id, payload)
|
||||||
appStore.showSuccess(t('admin.groups.groupUpdated'))
|
appStore.showSuccess(t('admin.groups.groupUpdated'))
|
||||||
closeEditModal()
|
closeEditModal()
|
||||||
loadGroups()
|
loadGroups()
|
||||||
|
|||||||
Reference in New Issue
Block a user