diff --git a/backend/ent/group.go b/backend/ent/group.go index 4a31442a..0d0c0538 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -3,6 +3,7 @@ package ent import ( + "encoding/json" "fmt" "strings" "time" @@ -55,6 +56,10 @@ type Group struct { ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + // 模型路由配置:模型模式 -> 优先账号ID列表 + ModelRouting map[string][]int64 `json:"model_routing,omitempty"` + // 是否启用模型路由配置 + ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -161,7 +166,9 @@ func (*Group) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case group.FieldIsExclusive, group.FieldClaudeCodeOnly: + case group.FieldModelRouting: + values[i] = new([]byte) + case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled: values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) @@ -315,6 +322,20 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.FallbackGroupID = new(int64) *_m.FallbackGroupID = value.Int64 } + case group.FieldModelRouting: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field model_routing", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.ModelRouting); err != nil { + return fmt.Errorf("unmarshal field model_routing: %w", err) + } + } + case group.FieldModelRoutingEnabled: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field model_routing_enabled", values[i]) + } else if value.Valid { + _m.ModelRoutingEnabled = value.Bool + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -465,6 +486,12 @@ func (_m *Group) String() string { builder.WriteString("fallback_group_id=") builder.WriteString(fmt.Sprintf("%v", *v)) } + builder.WriteString(", ") + builder.WriteString("model_routing=") + builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting)) + builder.WriteString(", ") + builder.WriteString("model_routing_enabled=") + builder.WriteString(fmt.Sprintf("%v", _m.ModelRoutingEnabled)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index c4317f00..d66d3edc 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -53,6 +53,10 @@ const ( FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. FieldFallbackGroupID = "fallback_group_id" + // FieldModelRouting holds the string denoting the model_routing field in the database. + FieldModelRouting = "model_routing" + // FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database. + FieldModelRoutingEnabled = "model_routing_enabled" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -147,6 +151,8 @@ var Columns = []string{ FieldImagePrice4k, FieldClaudeCodeOnly, FieldFallbackGroupID, + FieldModelRouting, + FieldModelRoutingEnabled, } var ( @@ -204,6 +210,8 @@ var ( DefaultDefaultValidityDays int // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. DefaultClaudeCodeOnly bool + // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. + DefaultModelRoutingEnabled bool ) // OrderOption defines the ordering options for the Group queries. @@ -309,6 +317,11 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc() } +// ByModelRoutingEnabled orders the results by the model_routing_enabled field. +func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index fb2f942f..6ce9e4c6 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -150,6 +150,11 @@ func FallbackGroupID(v int64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v)) } +// ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ. +func ModelRoutingEnabled(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -1065,6 +1070,26 @@ func FallbackGroupIDNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID)) } +// ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field. +func ModelRoutingIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldModelRouting)) +} + +// ModelRoutingNotNil applies the NotNil predicate on the "model_routing" field. +func ModelRoutingNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldModelRouting)) +} + +// ModelRoutingEnabledEQ applies the EQ predicate on the "model_routing_enabled" field. +func ModelRoutingEnabledEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v)) +} + +// ModelRoutingEnabledNEQ applies the NEQ predicate on the "model_routing_enabled" field. +func ModelRoutingEnabledNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldModelRoutingEnabled, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 59229402..0f251e0b 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -286,6 +286,26 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate { return _c } +// SetModelRouting sets the "model_routing" field. +func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate { + _c.mutation.SetModelRouting(v) + return _c +} + +// SetModelRoutingEnabled sets the "model_routing_enabled" field. +func (_c *GroupCreate) SetModelRoutingEnabled(v bool) *GroupCreate { + _c.mutation.SetModelRoutingEnabled(v) + return _c +} + +// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil. +func (_c *GroupCreate) SetNillableModelRoutingEnabled(v *bool) *GroupCreate { + if v != nil { + _c.SetModelRoutingEnabled(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -455,6 +475,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultClaudeCodeOnly _c.mutation.SetClaudeCodeOnly(v) } + if _, ok := _c.mutation.ModelRoutingEnabled(); !ok { + v := group.DefaultModelRoutingEnabled + _c.mutation.SetModelRoutingEnabled(v) + } return nil } @@ -510,6 +534,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)} } + if _, ok := _c.mutation.ModelRoutingEnabled(); !ok { + return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)} + } return nil } @@ -613,6 +640,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value) _node.FallbackGroupID = &value } + if value, ok := _c.mutation.ModelRouting(); ok { + _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) + _node.ModelRouting = value + } + if value, ok := _c.mutation.ModelRoutingEnabled(); ok { + _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value) + _node.ModelRoutingEnabled = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1093,6 +1128,36 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert { return u } +// SetModelRouting sets the "model_routing" field. +func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert { + u.Set(group.FieldModelRouting, v) + return u +} + +// UpdateModelRouting sets the "model_routing" field to the value that was provided on create. +func (u *GroupUpsert) UpdateModelRouting() *GroupUpsert { + u.SetExcluded(group.FieldModelRouting) + return u +} + +// ClearModelRouting clears the value of the "model_routing" field. +func (u *GroupUpsert) ClearModelRouting() *GroupUpsert { + u.SetNull(group.FieldModelRouting) + return u +} + +// SetModelRoutingEnabled sets the "model_routing_enabled" field. +func (u *GroupUpsert) SetModelRoutingEnabled(v bool) *GroupUpsert { + u.Set(group.FieldModelRoutingEnabled, v) + return u +} + +// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create. +func (u *GroupUpsert) UpdateModelRoutingEnabled() *GroupUpsert { + u.SetExcluded(group.FieldModelRoutingEnabled) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1516,6 +1581,41 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne { }) } +// SetModelRouting sets the "model_routing" field. +func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetModelRouting(v) + }) +} + +// UpdateModelRouting sets the "model_routing" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateModelRouting() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateModelRouting() + }) +} + +// ClearModelRouting clears the value of the "model_routing" field. +func (u *GroupUpsertOne) ClearModelRouting() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearModelRouting() + }) +} + +// SetModelRoutingEnabled sets the "model_routing_enabled" field. +func (u *GroupUpsertOne) SetModelRoutingEnabled(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetModelRoutingEnabled(v) + }) +} + +// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateModelRoutingEnabled() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateModelRoutingEnabled() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2105,6 +2205,41 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk { }) } +// SetModelRouting sets the "model_routing" field. +func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetModelRouting(v) + }) +} + +// UpdateModelRouting sets the "model_routing" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateModelRouting() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateModelRouting() + }) +} + +// ClearModelRouting clears the value of the "model_routing" field. +func (u *GroupUpsertBulk) ClearModelRouting() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearModelRouting() + }) +} + +// SetModelRoutingEnabled sets the "model_routing_enabled" field. +func (u *GroupUpsertBulk) SetModelRoutingEnabled(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetModelRoutingEnabled(v) + }) +} + +// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateModelRoutingEnabled() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateModelRoutingEnabled() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 1a6f15ec..c3cc2708 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -395,6 +395,32 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate { return _u } +// SetModelRouting sets the "model_routing" field. +func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate { + _u.mutation.SetModelRouting(v) + return _u +} + +// ClearModelRouting clears the value of the "model_routing" field. +func (_u *GroupUpdate) ClearModelRouting() *GroupUpdate { + _u.mutation.ClearModelRouting() + return _u +} + +// SetModelRoutingEnabled sets the "model_routing_enabled" field. +func (_u *GroupUpdate) SetModelRoutingEnabled(v bool) *GroupUpdate { + _u.mutation.SetModelRoutingEnabled(v) + return _u +} + +// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableModelRoutingEnabled(v *bool) *GroupUpdate { + if v != nil { + _u.SetModelRoutingEnabled(*v) + } + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -803,6 +829,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.FallbackGroupIDCleared() { _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) } + if value, ok := _u.mutation.ModelRouting(); ok { + _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) + } + if _u.mutation.ModelRoutingCleared() { + _spec.ClearField(group.FieldModelRouting, field.TypeJSON) + } + if value, ok := _u.mutation.ModelRoutingEnabled(); ok { + _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1478,6 +1513,32 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne { return _u } +// SetModelRouting sets the "model_routing" field. +func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne { + _u.mutation.SetModelRouting(v) + return _u +} + +// ClearModelRouting clears the value of the "model_routing" field. +func (_u *GroupUpdateOne) ClearModelRouting() *GroupUpdateOne { + _u.mutation.ClearModelRouting() + return _u +} + +// SetModelRoutingEnabled sets the "model_routing_enabled" field. +func (_u *GroupUpdateOne) SetModelRoutingEnabled(v bool) *GroupUpdateOne { + _u.mutation.SetModelRoutingEnabled(v) + return _u +} + +// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableModelRoutingEnabled(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetModelRoutingEnabled(*v) + } + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1916,6 +1977,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.FallbackGroupIDCleared() { _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) } + if value, ok := _u.mutation.ModelRouting(); ok { + _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) + } + if _u.mutation.ModelRoutingCleared() { + _spec.ClearField(group.FieldModelRouting, field.TypeJSON) + } + if value, ok := _u.mutation.ModelRoutingEnabled(); ok { + _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index d769f611..b377804f 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -226,6 +226,8 @@ var ( {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, + {Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "model_routing_enabled", Type: field.TypeBool, Default: false}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 3509efed..cd2fe8e0 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -3864,6 +3864,8 @@ type GroupMutation struct { claude_code_only *bool fallback_group_id *int64 addfallback_group_id *int64 + model_routing *map[string][]int64 + model_routing_enabled *bool clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -4974,6 +4976,91 @@ func (m *GroupMutation) ResetFallbackGroupID() { delete(m.clearedFields, group.FieldFallbackGroupID) } +// SetModelRouting sets the "model_routing" field. +func (m *GroupMutation) SetModelRouting(value map[string][]int64) { + m.model_routing = &value +} + +// ModelRouting returns the value of the "model_routing" field in the mutation. +func (m *GroupMutation) ModelRouting() (r map[string][]int64, exists bool) { + v := m.model_routing + if v == nil { + return + } + return *v, true +} + +// OldModelRouting returns the old "model_routing" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldModelRouting(ctx context.Context) (v map[string][]int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldModelRouting is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldModelRouting requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldModelRouting: %w", err) + } + return oldValue.ModelRouting, nil +} + +// ClearModelRouting clears the value of the "model_routing" field. +func (m *GroupMutation) ClearModelRouting() { + m.model_routing = nil + m.clearedFields[group.FieldModelRouting] = struct{}{} +} + +// ModelRoutingCleared returns if the "model_routing" field was cleared in this mutation. +func (m *GroupMutation) ModelRoutingCleared() bool { + _, ok := m.clearedFields[group.FieldModelRouting] + return ok +} + +// ResetModelRouting resets all changes to the "model_routing" field. +func (m *GroupMutation) ResetModelRouting() { + m.model_routing = nil + delete(m.clearedFields, group.FieldModelRouting) +} + +// SetModelRoutingEnabled sets the "model_routing_enabled" field. +func (m *GroupMutation) SetModelRoutingEnabled(b bool) { + m.model_routing_enabled = &b +} + +// ModelRoutingEnabled returns the value of the "model_routing_enabled" field in the mutation. +func (m *GroupMutation) ModelRoutingEnabled() (r bool, exists bool) { + v := m.model_routing_enabled + if v == nil { + return + } + return *v, true +} + +// OldModelRoutingEnabled returns the old "model_routing_enabled" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldModelRoutingEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldModelRoutingEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldModelRoutingEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldModelRoutingEnabled: %w", err) + } + return oldValue.ModelRoutingEnabled, nil +} + +// ResetModelRoutingEnabled resets all changes to the "model_routing_enabled" field. +func (m *GroupMutation) ResetModelRoutingEnabled() { + m.model_routing_enabled = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -5332,7 +5419,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 19) + fields := make([]string, 0, 21) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -5390,6 +5477,12 @@ func (m *GroupMutation) Fields() []string { if m.fallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } + if m.model_routing != nil { + fields = append(fields, group.FieldModelRouting) + } + if m.model_routing_enabled != nil { + fields = append(fields, group.FieldModelRoutingEnabled) + } return fields } @@ -5436,6 +5529,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: return m.FallbackGroupID() + case group.FieldModelRouting: + return m.ModelRouting() + case group.FieldModelRoutingEnabled: + return m.ModelRoutingEnabled() } return nil, false } @@ -5483,6 +5580,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldClaudeCodeOnly(ctx) case group.FieldFallbackGroupID: return m.OldFallbackGroupID(ctx) + case group.FieldModelRouting: + return m.OldModelRouting(ctx) + case group.FieldModelRoutingEnabled: + return m.OldModelRoutingEnabled(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -5625,6 +5726,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetFallbackGroupID(v) return nil + case group.FieldModelRouting: + v, ok := value.(map[string][]int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModelRouting(v) + return nil + case group.FieldModelRoutingEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModelRoutingEnabled(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -5793,6 +5908,9 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldFallbackGroupID) { fields = append(fields, group.FieldFallbackGroupID) } + if m.FieldCleared(group.FieldModelRouting) { + fields = append(fields, group.FieldModelRouting) + } return fields } @@ -5834,6 +5952,9 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldFallbackGroupID: m.ClearFallbackGroupID() return nil + case group.FieldModelRouting: + m.ClearModelRouting() + return nil } return fmt.Errorf("unknown Group nullable field %s", name) } @@ -5899,6 +6020,12 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldFallbackGroupID: m.ResetFallbackGroupID() return nil + case group.FieldModelRouting: + m.ResetModelRouting() + return nil + case group.FieldModelRoutingEnabled: + m.ResetModelRoutingEnabled() + return nil } return fmt.Errorf("unknown Group field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index ed13c852..0cb10775 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -280,6 +280,10 @@ func init() { groupDescClaudeCodeOnly := groupFields[14].Descriptor() // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) + // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. + groupDescModelRoutingEnabled := groupFields[17].Descriptor() + // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. + group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index d38925b1..5d0a1e9a 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -95,6 +95,17 @@ func (Group) Fields() []ent.Field { Optional(). Nillable(). Comment("非 Claude Code 请求降级使用的分组 ID"), + + // 模型路由配置 (added by migration 040) + field.JSON("model_routing", map[string][]int64{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}). + Comment("模型路由配置:模型模式 -> 优先账号ID列表"), + + // 模型路由开关 (added by migration 041) + field.Bool("model_routing_enabled"). + Default(false). + Comment("是否启用模型路由配置"), } } diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index a8bae35e..f6780dee 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -40,6 +40,9 @@ type CreateGroupRequest struct { ImagePrice4K *float64 `json:"image_price_4k"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` + // 模型路由配置(仅 anthropic 平台使用) + ModelRouting map[string][]int64 `json:"model_routing"` + ModelRoutingEnabled bool `json:"model_routing_enabled"` } // UpdateGroupRequest represents update group request @@ -60,6 +63,9 @@ type UpdateGroupRequest struct { ImagePrice4K *float64 `json:"image_price_4k"` ClaudeCodeOnly *bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` + // 模型路由配置(仅 anthropic 平台使用) + ModelRouting map[string][]int64 `json:"model_routing"` + ModelRoutingEnabled *bool `json:"model_routing_enabled"` } // List handles listing all groups with pagination @@ -149,20 +155,22 @@ func (h *GroupHandler) Create(c *gin.Context) { } group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, }) if err != nil { response.ErrorFrom(c, err) @@ -188,21 +196,23 @@ func (h *GroupHandler) Update(c *gin.Context) { } group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - Status: req.Status, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + Status: req.Status, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 371f4f52..df6fda0f 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -87,9 +87,11 @@ func GroupFromServiceShallow(g *service.Group) *Group { ImagePrice1K: g.ImagePrice1K, ImagePrice2K: g.ImagePrice2K, ImagePrice4K: g.ImagePrice4K, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - CreatedAt: g.CreatedAt, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, AccountCount: g.AccountCount, } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 0cbc809b..914f2b23 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -58,6 +58,10 @@ type Group struct { ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` + // 模型路由配置(仅 anthropic 平台使用) + ModelRouting map[string][]int64 `json:"model_routing"` + ModelRoutingEnabled bool `json:"model_routing_enabled"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 77a3f233..ab890844 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -136,6 +136,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldImagePrice4k, group.FieldClaudeCodeOnly, group.FieldFallbackGroupID, + group.FieldModelRoutingEnabled, + group.FieldModelRouting, ) }). Only(ctx) @@ -422,6 +424,8 @@ func groupEntityToService(g *dbent.Group) *service.Group { DefaultValidityDays: g.DefaultValidityDays, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 9f3c1a57..5c4d6cf4 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -49,7 +49,13 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice4k(groupIn.ImagePrice4K). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). - SetNillableFallbackGroupID(groupIn.FallbackGroupID) + SetNillableFallbackGroupID(groupIn.FallbackGroupID). + SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) + + // 设置模型路由配置 + if groupIn.ModelRouting != nil { + builder = builder.SetModelRouting(groupIn.ModelRouting) + } created, err := builder.Save(ctx) if err == nil { @@ -101,7 +107,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). SetDefaultValidityDays(groupIn.DefaultValidityDays). - SetClaudeCodeOnly(groupIn.ClaudeCodeOnly) + SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). + SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) // 处理 FallbackGroupID:nil 时清除,否则设置 if groupIn.FallbackGroupID != nil { @@ -110,6 +117,13 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er builder = builder.ClearFallbackGroupID() } + // 处理 ModelRouting:nil 时清除,否则设置 + if groupIn.ModelRouting != nil { + builder = builder.SetModelRouting(groupIn.ModelRouting) + } else { + builder = builder.ClearModelRouting() + } + updated, err := builder.Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 1e32699c..c0694e4e 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -106,6 +106,9 @@ type CreateGroupInput struct { ImagePrice4K *float64 ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID + // 模型路由配置(仅 anthropic 平台使用) + ModelRouting map[string][]int64 + ModelRoutingEnabled bool // 是否启用模型路由 } type UpdateGroupInput struct { @@ -125,6 +128,9 @@ type UpdateGroupInput struct { ImagePrice4K *float64 ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID + // 模型路由配置(仅 anthropic 平台使用) + ModelRouting map[string][]int64 + ModelRoutingEnabled *bool // 是否启用模型路由 } type CreateAccountInput struct { @@ -581,6 +587,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ImagePrice4K: imagePrice4K, ClaudeCodeOnly: input.ClaudeCodeOnly, FallbackGroupID: input.FallbackGroupID, + ModelRouting: input.ModelRouting, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -709,6 +716,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd } } + // 模型路由配置 + if input.ModelRouting != nil { + group.ModelRouting = input.ModelRouting + } + if input.ModelRoutingEnabled != nil { + group.ModelRoutingEnabled = *input.ModelRoutingEnabled + } + if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 7ce9a8a2..5b476dbc 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -37,6 +37,11 @@ type APIKeyAuthGroupSnapshot struct { ImagePrice4K *float64 `json:"image_price_4k,omitempty"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + + // Model routing is used by gateway account selection, so it must be part of auth cache snapshot. + // Only anthropic groups use these fields; others may leave them empty. + ModelRouting map[string][]int64 `json:"model_routing,omitempty"` + ModelRoutingEnabled bool `json:"model_routing_enabled"` } // APIKeyAuthCacheEntry 缓存条目,支持负缓存 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index dfc55eeb..cf0bf586 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -221,6 +221,8 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ImagePrice4K: apiKey.Group.ImagePrice4K, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, FallbackGroupID: apiKey.Group.FallbackGroupID, + ModelRouting: apiKey.Group.ModelRouting, + ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, } } return snapshot @@ -263,6 +265,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ImagePrice4K: snapshot.Group.ImagePrice4K, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, FallbackGroupID: snapshot.Group.FallbackGroupID, + ModelRouting: snapshot.Group.ModelRouting, + ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, } } return apiKey diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 3314ca8d..9861264e 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -178,6 +178,10 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { Status: StatusActive, SubscriptionType: SubscriptionTypeStandard, RateMultiplier: 1, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + "claude-opus-*": {1, 2}, + }, }, }, } @@ -190,6 +194,8 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { require.Equal(t, int64(1), apiKey.ID) require.Equal(t, int64(2), apiKey.User.ID) require.Equal(t, groupID, apiKey.Group.ID) + require.True(t, apiKey.Group.ModelRoutingEnabled) + require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting) } func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index c2dbf7c9..7673f5ef 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -1053,6 +1053,60 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号") }) + t.Run("模型路由-无ConcurrencyService也生效", func(t *testing.T) { + groupID := int64(1) + sessionHash := "sticky" + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}}, + {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{ + sessionBindings: map[string]int64{sessionHash: 1}, + } + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{ + groupID: { + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + Hydrated: true, + ModelRoutingEnabled: true, + ModelRouting: map[string][]int64{ + "claude-a": {1}, + "claude-b": {2}, + }, + }, + }, + } + + cfg := testConfig() + cfg.Gateway.Scheduling.LoadBatchEnabled = true + + svc := &GatewayService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + cfg: cfg, + concurrencyService: nil, // legacy path + } + + result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, int64(2), result.Account.ID, "切换到 claude-b 时应按模型路由切换账号") + require.Equal(t, int64(2), cache.sessionBindings[sessionHash], "粘性绑定应更新为路由选择的账号") + }) + t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) { repo := &mockAccountRepoForPlatform{ accounts: []Account{ @@ -1341,6 +1395,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) ID: groupID, Platform: PlatformAnthropic, Status: StatusActive, + Hydrated: true, } groupRepo := &mockGroupRepoForGateway{ groups: map[int64]*Group{groupID: group}, @@ -1398,6 +1453,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) { ID: fallbackID, Platform: PlatformAnthropic, Status: StatusActive, + Hydrated: true, } ctx = context.WithValue(ctx, ctxkey.Group, group) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 39822c68..1e3221d3 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -12,6 +12,7 @@ import ( "io" "log" "net/http" + "os" "regexp" "sort" "strings" @@ -38,6 +39,21 @@ const ( maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 ) +func (s *GatewayService) debugModelRoutingEnabled() bool { + v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) + return v == "1" || v == "true" || v == "yes" || v == "on" +} + +func shortSessionHash(sessionHash string) string { + if sessionHash == "" { + return "" + } + if len(sessionHash) <= 8 { + return sessionHash + } + return sessionHash[:8] +} + // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). var ( @@ -407,6 +423,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } ctx = s.withGroupContext(ctx, group) + if s.debugModelRoutingEnabled() && requestedModel != "" { + groupPlatform := "" + if group != nil { + groupPlatform = group.Platform + } + log.Printf("[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v", + derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil) + } + if s.concurrencyService == nil || !cfg.LoadBatchEnabled { account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) if err != nil { @@ -450,6 +475,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return nil, err } preferOAuth := platform == PlatformGemini + if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" { + log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) + } accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err != nil { @@ -467,15 +495,206 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return excluded } - // ============ Layer 1: 粘性会话优先 ============ - if sessionHash != "" && s.cache != nil { + // 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用) + accountByID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + accountByID[accounts[i].ID] = &accounts[i] + } + + // 获取模型路由配置(仅 anthropic 平台) + var routingAccountIDs []int64 + if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic { + routingAccountIDs = group.GetRoutingAccountIDs(requestedModel) + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d", + group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID) + if len(routingAccountIDs) == 0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 { + keys := make([]string, 0, len(group.ModelRouting)) + for k := range group.ModelRouting { + keys = append(keys, k) + } + sort.Strings(keys) + const maxKeys = 20 + if len(keys) > maxKeys { + keys = keys[:maxKeys] + } + log.Printf("[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys) + } + } + } + + // ============ Layer 1: 模型路由优先选择(优先级高于粘性会话) ============ + if len(routingAccountIDs) > 0 && s.concurrencyService != nil { + // 1. 过滤出路由列表中可调度的账号 + var routingCandidates []*Account + var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping int + for _, routingAccountID := range routingAccountIDs { + if isExcluded(routingAccountID) { + filteredExcluded++ + continue + } + account, ok := accountByID[routingAccountID] + if !ok || !account.IsSchedulable() { + if !ok { + filteredMissing++ + } else { + filteredUnsched++ + } + continue + } + if !s.isAccountAllowedForPlatform(account, platform, useMixed) { + filteredPlatform++ + continue + } + if !account.IsSchedulableForModel(requestedModel) { + filteredModelScope++ + continue + } + if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) { + filteredModelMapping++ + continue + } + routingCandidates = append(routingCandidates, account) + } + + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d)", + derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates), + filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping) + } + + if len(routingCandidates) > 0 { + // 1.5. 在路由账号范围内检查粘性会话 + if sessionHash != "" && s.cache != nil { + stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { + // 粘性账号在路由列表中,优先使用 + if stickyAccount, ok := accountByID[stickyAccountID]; ok { + if stickyAccount.IsSchedulable() && + s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && + stickyAccount.IsSchedulableForModel(requestedModel) && + (requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) { + result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) + if err == nil && result.Acquired { + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) + } + return &AccountSelectionResult{ + Account: stickyAccount, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: stickyAccount, + WaitPlan: &AccountWaitPlan{ + AccountID: stickyAccountID, + MaxConcurrency: stickyAccount.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 + } + } + } + } + + // 2. 批量获取负载信息 + routingLoads := make([]AccountWithConcurrency, 0, len(routingCandidates)) + for _, acc := range routingCandidates { + routingLoads = append(routingLoads, AccountWithConcurrency{ + ID: acc.ID, + MaxConcurrency: acc.Concurrency, + }) + } + routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads) + + // 3. 按负载感知排序 + type accountWithLoad struct { + account *Account + loadInfo *AccountLoadInfo + } + var routingAvailable []accountWithLoad + for _, acc := range routingCandidates { + loadInfo := routingLoadMap[acc.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: acc.ID} + } + if loadInfo.LoadRate < 100 { + routingAvailable = append(routingAvailable, accountWithLoad{account: acc, loadInfo: loadInfo}) + } + } + + if len(routingAvailable) > 0 { + // 排序:优先级 > 负载率 > 最后使用时间 + sort.SliceStable(routingAvailable, func(i, j int) bool { + a, b := routingAvailable[i], routingAvailable[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + return false + default: + return a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + + // 4. 尝试获取槽位 + for _, item := range routingAvailable { + result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" && s.cache != nil { + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) + } + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) + } + return &AccountSelectionResult{ + Account: item.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + + // 5. 所有路由账号槽位满,返回等待计划(选择负载最低的) + acc := routingAvailable[0].account + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID) + } + return &AccountSelectionResult{ + Account: acc, + WaitPlan: &AccountWaitPlan{ + AccountID: acc.ID, + MaxConcurrency: acc.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + // 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退 + log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) + } + } + + // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============ + if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil { accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && accountID > 0 && !isExcluded(accountID) { - // 粘性命中仅在当前可调度候选集中生效。 - accountByID := make(map[int64]*Account, len(accounts)) - for i := range accounts { - accountByID[accounts[i].ID] = &accounts[i] - } account, ok := accountByID[accountID] if ok && s.isAccountInGroup(account, groupID) && s.isAccountAllowedForPlatform(account, platform, useMixed) && @@ -687,6 +906,32 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (* return group, nil } +func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 { + if groupID == nil || requestedModel == "" || platform != PlatformAnthropic { + return nil + } + group, err := s.resolveGroupByID(ctx, *groupID) + if err != nil || group == nil { + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err) + } + return nil + } + // Preserve existing behavior: model routing only applies to anthropic groups. + if group.Platform != PlatformAnthropic { + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel) + } + return nil + } + ids := group.GetRoutingAccountIDs(requestedModel) + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v", + group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids) + } + return ids +} + func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) { if groupID == nil { return nil, nil, nil @@ -868,6 +1113,116 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { preferOAuth := platform == PlatformGemini + routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) + + var accounts []Account + accountsLoaded := false + + // ============ Model Routing (legacy path): apply before sticky session ============ + // When load-awareness is disabled (e.g. concurrency service not configured), we still honor model routing + // so switching model can switch upstream account within the same sticky session. + if len(routingAccountIDs) > 0 { + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", + derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs) + } + // 1) Sticky session only applies if the bound account is within the routing set. + if sessionHash != "" && s.cache != nil { + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) { + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.getSchedulableAccount(ctx, accountID) + // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) + if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { + log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + } + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) + } + return account, nil + } + } + } + } + + // 2) Select an account from the routed candidates. + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform == "" { + hasForcePlatform = false + } + var err error + accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + accountsLoaded = true + + routingSet := make(map[int64]struct{}, len(routingAccountIDs)) + for _, id := range routingAccountIDs { + if id > 0 { + routingSet[id] = struct{}{} + } + } + + var selected *Account + for i := range accounts { + acc := &accounts[i] + if _, ok := routingSet[acc.ID]; !ok { + continue + } + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + // Scheduler snapshots can be temporarily stale; re-check schedulability here to + // avoid selecting accounts that were recently rate-limited/overloaded. + if !acc.IsSchedulable() { + continue + } + if !acc.IsSchedulableForModel(requestedModel) { + continue + } + if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { + continue + } + if selected == nil { + selected = acc + continue + } + if acc.Priority < selected.Priority { + selected = acc + } else if acc.Priority == selected.Priority { + switch { + case acc.LastUsedAt == nil && selected.LastUsedAt != nil: + selected = acc + case acc.LastUsedAt != nil && selected.LastUsedAt == nil: + // keep selected (never used is preferred) + case acc.LastUsedAt == nil && selected.LastUsedAt == nil: + if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { + selected = acc + } + default: + if acc.LastUsedAt.Before(*selected.LastUsedAt) { + selected = acc + } + } + } + } + + if selected != nil { + if sessionHash != "" && s.cache != nil { + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { + log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + } + } + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) + } + return selected, nil + } + log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) + } + // 1. 查询粘性会话 if sessionHash != "" && s.cache != nil { accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) @@ -886,13 +1241,16 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } // 2. 获取可调度账号列表(单平台) - forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) - if hasForcePlatform && forcePlatform == "" { - hasForcePlatform = false - } - accounts, _, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) - if err != nil { - return nil, fmt.Errorf("query accounts failed: %w", err) + if !accountsLoaded { + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform == "" { + hasForcePlatform = false + } + var err error + accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } } // 3. 按优先级+最久未用选择(考虑模型支持) @@ -958,6 +1316,115 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户 func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) { preferOAuth := nativePlatform == PlatformGemini + routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform) + + var accounts []Account + accountsLoaded := false + + // ============ Model Routing (legacy path): apply before sticky session ============ + if len(routingAccountIDs) > 0 { + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", + derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs) + } + // 1) Sticky session only applies if the bound account is within the routing set. + if sessionHash != "" && s.cache != nil { + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) { + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.getSchedulableAccount(ctx, accountID) + // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 + if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { + if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { + log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + } + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) + } + return account, nil + } + } + } + } + } + + // 2) Select an account from the routed candidates. + var err error + accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + accountsLoaded = true + + routingSet := make(map[int64]struct{}, len(routingAccountIDs)) + for _, id := range routingAccountIDs { + if id > 0 { + routingSet[id] = struct{}{} + } + } + + var selected *Account + for i := range accounts { + acc := &accounts[i] + if _, ok := routingSet[acc.ID]; !ok { + continue + } + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + // Scheduler snapshots can be temporarily stale; re-check schedulability here to + // avoid selecting accounts that were recently rate-limited/overloaded. + if !acc.IsSchedulable() { + continue + } + // 过滤:原生平台直接通过,antigravity 需要启用混合调度 + if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { + continue + } + if !acc.IsSchedulableForModel(requestedModel) { + continue + } + if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { + continue + } + if selected == nil { + selected = acc + continue + } + if acc.Priority < selected.Priority { + selected = acc + } else if acc.Priority == selected.Priority { + switch { + case acc.LastUsedAt == nil && selected.LastUsedAt != nil: + selected = acc + case acc.LastUsedAt != nil && selected.LastUsedAt == nil: + // keep selected (never used is preferred) + case acc.LastUsedAt == nil && selected.LastUsedAt == nil: + if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { + selected = acc + } + default: + if acc.LastUsedAt.Before(*selected.LastUsedAt) { + selected = acc + } + } + } + } + + if selected != nil { + if sessionHash != "" && s.cache != nil { + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { + log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + } + } + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) + } + return selected, nil + } + log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) + } // 1. 查询粘性会话 if sessionHash != "" && s.cache != nil { @@ -979,9 +1446,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } // 2. 获取可调度账号列表 - accounts, _, err := s.listSchedulableAccounts(ctx, groupID, nativePlatform, false) - if err != nil { - return nil, fmt.Errorf("query accounts failed: %w", err) + if !accountsLoaded { + var err error + accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } } // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 8e8d47d6..d6d1269b 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -1,6 +1,9 @@ package service -import "time" +import ( + "strings" + "time" +) type Group struct { ID int64 @@ -27,6 +30,12 @@ type Group struct { ClaudeCodeOnly bool FallbackGroupID *int64 + // 模型路由配置 + // key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*") + // value: 优先账号 ID 列表 + ModelRouting map[string][]int64 + ModelRoutingEnabled bool + CreatedAt time.Time UpdatedAt time.Time @@ -90,3 +99,41 @@ func IsGroupContextValid(group *Group) bool { } return true } + +// GetRoutingAccountIDs 根据请求模型获取路由账号 ID 列表 +// 返回匹配的优先账号 ID 列表,如果没有匹配规则则返回 nil +func (g *Group) GetRoutingAccountIDs(requestedModel string) []int64 { + if !g.ModelRoutingEnabled || len(g.ModelRouting) == 0 || requestedModel == "" { + return nil + } + + // 1. 精确匹配优先 + if accountIDs, ok := g.ModelRouting[requestedModel]; ok && len(accountIDs) > 0 { + return accountIDs + } + + // 2. 通配符匹配(前缀匹配) + for pattern, accountIDs := range g.ModelRouting { + if matchModelPattern(pattern, requestedModel) && len(accountIDs) > 0 { + return accountIDs + } + } + + return nil +} + +// matchModelPattern 检查模型是否匹配模式 +// 支持 * 通配符,如 "claude-opus-*" 匹配 "claude-opus-4-20250514" +func matchModelPattern(pattern, model string) bool { + if pattern == model { + return true + } + + // 处理 * 通配符(仅支持末尾通配符) + if strings.HasSuffix(pattern, "*") { + prefix := strings.TrimSuffix(pattern, "*") + return strings.HasPrefix(model, prefix) + } + + return false +} diff --git a/backend/migrations/040_add_group_model_routing.sql b/backend/migrations/040_add_group_model_routing.sql new file mode 100644 index 00000000..303fcb2a --- /dev/null +++ b/backend/migrations/040_add_group_model_routing.sql @@ -0,0 +1,11 @@ +-- 040_add_group_model_routing.sql +-- 添加分组级别的模型路由配置功能 + +-- 添加 model_routing 字段:模型路由配置(JSONB 格式) +-- 格式: {"model_pattern": [account_id1, account_id2], ...} +-- 例如: {"claude-opus-*": [1, 2], "claude-sonnet-*": [3, 4, 5]} +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS model_routing JSONB DEFAULT '{}'; + +-- 添加字段注释 +COMMENT ON COLUMN groups.model_routing IS '模型路由配置:{"model_pattern": [account_id1, account_id2], ...},支持通配符匹配'; diff --git a/backend/migrations/041_add_model_routing_enabled.sql b/backend/migrations/041_add_model_routing_enabled.sql new file mode 100644 index 00000000..8691cf1f --- /dev/null +++ b/backend/migrations/041_add_model_routing_enabled.sql @@ -0,0 +1,2 @@ +-- Add model_routing_enabled field to groups table +ALTER TABLE groups ADD COLUMN model_routing_enabled BOOLEAN NOT NULL DEFAULT false; diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index c31ddbc4..e4fe1bd1 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -916,6 +916,26 @@ export default { fallbackGroup: 'Fallback Group', fallbackHint: 'Non-Claude Code requests will use this group. Leave empty to reject directly.', noFallback: 'No Fallback (Reject)' + }, + modelRouting: { + title: 'Model Routing', + tooltip: 'Configure specific model requests to be routed to designated accounts. Supports wildcard matching, e.g., claude-opus-* matches all opus models.', + enabled: 'Enabled', + disabled: 'Disabled', + disabledHint: 'Routing rules will only take effect when enabled', + addRule: 'Add Routing Rule', + modelPattern: 'Model Pattern', + modelPatternPlaceholder: 'claude-opus-*', + modelPatternHint: 'Supports * wildcard, e.g., claude-opus-* matches all opus models', + accounts: 'Priority Accounts', + selectAccounts: 'Select accounts', + noAccounts: 'No accounts in this group', + loadingAccounts: 'Loading accounts...', + removeRule: 'Remove Rule', + noRules: 'No routing rules', + noRulesHint: 'Add routing rules to route specific model requests to designated accounts', + searchAccountPlaceholder: 'Search accounts...', + accountsHint: 'Select accounts to prioritize for this model pattern' } }, diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 43aeee41..35242c69 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -992,6 +992,26 @@ export default { fallbackGroup: '降级分组', fallbackHint: '非 Claude Code 请求将使用此分组,留空则直接拒绝', noFallback: '不降级(直接拒绝)' + }, + modelRouting: { + title: '模型路由配置', + tooltip: '配置特定模型请求优先路由到指定账号。支持通配符匹配,如 claude-opus-* 匹配所有 opus 模型。', + enabled: '已启用', + disabled: '已禁用', + disabledHint: '启用后,配置的路由规则才会生效', + addRule: '添加路由规则', + modelPattern: '模型模式', + modelPatternPlaceholder: 'claude-opus-*', + modelPatternHint: '支持 * 通配符,如 claude-opus-* 匹配所有 opus 模型', + accounts: '优先账号', + selectAccounts: '选择账号', + noAccounts: '此分组暂无账号', + loadingAccounts: '加载账号中...', + removeRule: '删除规则', + noRules: '暂无路由规则', + noRulesHint: '添加路由规则以将特定模型请求优先路由到指定账号', + searchAccountPlaceholder: '搜索账号...', + accountsHint: '选择此模型模式优先使用的账号' } }, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 5c1e307c..523033c2 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -269,6 +269,9 @@ export interface Group { // Claude Code 客户端限制 claude_code_only: boolean fallback_group_id: number | null + // 模型路由配置(仅 anthropic 平台使用) + model_routing: Record | null + model_routing_enabled: boolean account_count?: number created_at: string updated_at: string diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index d8322154..96457172 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -460,6 +460,149 @@ + +
+
+ + +
+ +
+
+

+ {{ t('admin.groups.modelRouting.tooltip') }} +

+
+
+
+
+
+ +
+ + + {{ createForm.model_routing_enabled ? t('admin.groups.modelRouting.enabled') : t('admin.groups.modelRouting.disabled') }} + +
+

+ {{ t('admin.groups.modelRouting.disabledHint') }} +

+

+ {{ t('admin.groups.modelRouting.noRulesHint') }} +

+ +
+
+
+
+
+ + +
+
+ + +
+ + {{ account.name }} + + +
+ + +

{{ t('admin.groups.modelRouting.accountsHint') }}

+
+
+ +
+
+
+ + +
+