diff --git a/backend/ent/group.go b/backend/ent/group.go index b15ac15d..f10b50c3 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -11,6 +11,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/internal/domain" ) // Group is the model entity for the Group schema. @@ -76,6 +77,8 @@ type Group struct { RequirePrivacySet bool `json:"require_privacy_set,omitempty"` // 默认映射模型 ID,当账号级映射找不到时使用此值 DefaultMappedModel string `json:"default_mapped_model,omitempty"` + // OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型 + MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -182,7 +185,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case group.FieldModelRouting, group.FieldSupportedModelScopes: + case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig: values[i] = new([]byte) case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: values[i] = new(sql.NullBool) @@ -403,6 +406,14 @@ func (_m *Group) assignValues(columns []string, values []any) error { } else if value.Valid { _m.DefaultMappedModel = value.String } + case group.FieldMessagesDispatchModelConfig: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field messages_dispatch_model_config", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.MessagesDispatchModelConfig); err != nil { + return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err) + } + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -585,6 +596,9 @@ func (_m *Group) String() string { builder.WriteString(", ") builder.WriteString("default_mapped_model=") builder.WriteString(_m.DefaultMappedModel) + builder.WriteString(", ") + builder.WriteString("messages_dispatch_model_config=") + builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 21a7c2cb..b1371630 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -8,6 +8,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/internal/domain" ) const ( @@ -73,6 +74,8 @@ const ( FieldRequirePrivacySet = "require_privacy_set" // FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database. FieldDefaultMappedModel = "default_mapped_model" + // FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database. + FieldMessagesDispatchModelConfig = "messages_dispatch_model_config" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -177,6 +180,7 @@ var Columns = []string{ FieldRequireOauthOnly, FieldRequirePrivacySet, FieldDefaultMappedModel, + FieldMessagesDispatchModelConfig, } var ( @@ -252,6 +256,8 @@ var ( DefaultDefaultMappedModel string // DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. DefaultMappedModelValidator func(string) error + // DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field. + DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig ) // OrderOption defines the ordering options for the Group queries. diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index a8c30b18..f412fa40 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -18,6 +18,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" + "github.com/Wei-Shaw/sub2api/internal/domain" ) // GroupCreate is the builder for creating a Group entity. @@ -410,6 +411,20 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate { return _c } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (_c *GroupCreate) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupCreate { + _c.mutation.SetMessagesDispatchModelConfig(v) + return _c +} + +// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil. +func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupCreate { + if v != nil { + _c.SetMessagesDispatchModelConfig(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -611,6 +626,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultDefaultMappedModel _c.mutation.SetDefaultMappedModel(v) } + if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok { + v := group.DefaultMessagesDispatchModelConfig + _c.mutation.SetMessagesDispatchModelConfig(v) + } return nil } @@ -695,6 +714,9 @@ func (_c *GroupCreate) check() error { return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)} } } + if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok { + return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)} + } return nil } @@ -838,6 +860,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) _node.DefaultMappedModel = value } + if value, ok := _c.mutation.MessagesDispatchModelConfig(); ok { + _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) + _node.MessagesDispatchModelConfig = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1462,6 +1488,18 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert { return u } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (u *GroupUpsert) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsert { + u.Set(group.FieldMessagesDispatchModelConfig, v) + return u +} + +// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create. +func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert { + u.SetExcluded(group.FieldMessagesDispatchModelConfig) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -2053,6 +2091,20 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne { }) } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (u *GroupUpsertOne) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetMessagesDispatchModelConfig(v) + }) +} + +// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateMessagesDispatchModelConfig() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2810,6 +2862,20 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk { }) } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (u *GroupUpsertBulk) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetMessagesDispatchModelConfig(v) + }) +} + +// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateMessagesDispatchModelConfig() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index aa1a83d4..7b6d6193 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -20,6 +20,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/usersubscription" + "github.com/Wei-Shaw/sub2api/internal/domain" ) // GroupUpdate is the builder for updating Group entities. @@ -552,6 +553,20 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate { return _u } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (_u *GroupUpdate) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpdate { + _u.mutation.SetMessagesDispatchModelConfig(v) + return _u +} + +// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupUpdate { + if v != nil { + _u.SetMessagesDispatchModelConfig(*v) + } + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -1012,6 +1027,9 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) } + if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { + _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1843,6 +1861,20 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO return _u } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (_u *GroupUpdateOne) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpdateOne { + _u.mutation.SetMessagesDispatchModelConfig(v) + return _u +} + +// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupUpdateOne { + if v != nil { + _u.SetMessagesDispatchModelConfig(*v) + } + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -2333,6 +2365,9 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.DefaultMappedModel(); ok { _spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value) } + if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { + _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 5400bf93..a7ae4af0 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -407,6 +407,7 @@ var ( {Name: "require_oauth_only", Type: field.TypeBool, Default: false}, {Name: "require_privacy_set", Type: field.TypeBool, Default: false}, {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, + {Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index d206039a..594e5199 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -8246,6 +8246,7 @@ type GroupMutation struct { require_oauth_only *bool require_privacy_set *bool default_mapped_model *string + messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -9798,6 +9799,42 @@ func (m *GroupMutation) ResetDefaultMappedModel() { m.default_mapped_model = nil } +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) { + m.messages_dispatch_model_config = &damdmc +} + +// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation. +func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) { + v := m.messages_dispatch_model_config + if v == nil { + return + } + return *v, true +} + +// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err) + } + return oldValue.MessagesDispatchModelConfig, nil +} + +// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field. +func (m *GroupMutation) ResetMessagesDispatchModelConfig() { + m.messages_dispatch_model_config = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -10156,7 +10193,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 29) + fields := make([]string, 0, 30) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -10244,6 +10281,9 @@ func (m *GroupMutation) Fields() []string { if m.default_mapped_model != nil { fields = append(fields, group.FieldDefaultMappedModel) } + if m.messages_dispatch_model_config != nil { + fields = append(fields, group.FieldMessagesDispatchModelConfig) + } return fields } @@ -10310,6 +10350,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.RequirePrivacySet() case group.FieldDefaultMappedModel: return m.DefaultMappedModel() + case group.FieldMessagesDispatchModelConfig: + return m.MessagesDispatchModelConfig() } return nil, false } @@ -10377,6 +10419,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldRequirePrivacySet(ctx) case group.FieldDefaultMappedModel: return m.OldDefaultMappedModel(ctx) + case group.FieldMessagesDispatchModelConfig: + return m.OldMessagesDispatchModelConfig(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -10589,6 +10633,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetDefaultMappedModel(v) return nil + case group.FieldMessagesDispatchModelConfig: + v, ok := value.(domain.OpenAIMessagesDispatchModelConfig) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMessagesDispatchModelConfig(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -10929,6 +10980,9 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldDefaultMappedModel: m.ResetDefaultMappedModel() return nil + case group.FieldMessagesDispatchModelConfig: + m.ResetMessagesDispatchModelConfig() + return nil } return fmt.Errorf("unknown Group field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 803b7bc2..792f0566 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -28,6 +28,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/userattributedefinition" "github.com/Wei-Shaw/sub2api/ent/userattributevalue" "github.com/Wei-Shaw/sub2api/ent/usersubscription" + "github.com/Wei-Shaw/sub2api/internal/domain" ) // The init function reads all schema descriptors with runtime code @@ -468,6 +469,10 @@ func init() { group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error) + // groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field. + groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor() + // group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field. + group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig) idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() _ = idempotencyrecordMixinFields0 diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 0eb89c18..d78a6898 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -141,6 +141,10 @@ func (Group) Fields() []ent.Field { MaxLen(100). Default(""). Comment("默认映射模型 ID,当账号级映射找不到时使用此值"), + field.JSON("messages_dispatch_model_config", domain.OpenAIMessagesDispatchModelConfig{}). + Default(domain.OpenAIMessagesDispatchModelConfig{}). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}). + Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"), } } diff --git a/backend/internal/domain/openai_messages_dispatch.go b/backend/internal/domain/openai_messages_dispatch.go new file mode 100644 index 00000000..6b018f1c --- /dev/null +++ b/backend/internal/domain/openai_messages_dispatch.go @@ -0,0 +1,10 @@ +package domain + +// OpenAIMessagesDispatchModelConfig controls how Anthropic /v1/messages +// requests are mapped onto OpenAI/Codex models. +type OpenAIMessagesDispatchModelConfig struct { + OpusMappedModel string `json:"opus_mapped_model,omitempty"` + SonnetMappedModel string `json:"sonnet_mapped_model,omitempty"` + HaikuMappedModel string `json:"haiku_mapped_model,omitempty"` + ExactModelMappings map[string]string `json:"exact_model_mappings,omitempty"` +} diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 458ed35d..8b6b056d 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -105,10 +105,11 @@ type CreateGroupRequest struct { // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch bool `json:"allow_messages_dispatch"` - RequireOAuthOnly bool `json:"require_oauth_only"` - RequirePrivacySet bool `json:"require_privacy_set"` - DefaultMappedModel string `json:"default_mapped_model"` + AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + RequireOAuthOnly bool `json:"require_oauth_only"` + RequirePrivacySet bool `json:"require_privacy_set"` + DefaultMappedModel string `json:"default_mapped_model"` + MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -139,10 +140,11 @@ type UpdateGroupRequest struct { // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string `json:"supported_model_scopes"` // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` - RequireOAuthOnly *bool `json:"require_oauth_only"` - RequirePrivacySet *bool `json:"require_privacy_set"` - DefaultMappedModel *string `json:"default_mapped_model"` + AllowMessagesDispatch *bool `json:"allow_messages_dispatch"` + RequireOAuthOnly *bool `json:"require_oauth_only"` + RequirePrivacySet *bool `json:"require_privacy_set"` + DefaultMappedModel *string `json:"default_mapped_model"` + MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -257,6 +259,7 @@ func (h *GroupHandler) Create(c *gin.Context) { RequireOAuthOnly: req.RequireOAuthOnly, RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, + MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -307,6 +310,7 @@ func (h *GroupHandler) Update(c *gin.Context) { RequireOAuthOnly: req.RequireOAuthOnly, RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, + MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 2eab670e..478600eb 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -133,16 +133,17 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { return nil } out := &AdminGroup{ - Group: groupFromServiceBase(g), - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - MCPXMLInject: g.MCPXMLInject, - DefaultMappedModel: g.DefaultMappedModel, - SupportedModelScopes: g.SupportedModelScopes, - AccountCount: g.AccountCount, - ActiveAccountCount: g.ActiveAccountCount, - RateLimitedAccountCount: g.RateLimitedAccountCount, - SortOrder: g.SortOrder, + Group: groupFromServiceBase(g), + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + MCPXMLInject: g.MCPXMLInject, + DefaultMappedModel: g.DefaultMappedModel, + MessagesDispatchModelConfig: g.MessagesDispatchModelConfig, + SupportedModelScopes: g.SupportedModelScopes, + AccountCount: g.AccountCount, + ActiveAccountCount: g.ActiveAccountCount, + RateLimitedAccountCount: g.RateLimitedAccountCount, + SortOrder: g.SortOrder, } if len(g.AccountGroups) > 0 { out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 82065deb..e026ca65 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -1,6 +1,10 @@ package dto -import "time" +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/domain" +) type User struct { ID int64 `json:"id"` @@ -112,7 +116,8 @@ type AdminGroup struct { MCPXMLInject bool `json:"mcp_xml_inject"` // OpenAI Messages 调度配置(仅 openai 平台使用) - DefaultMappedModel string `json:"default_mapped_model"` + DefaultMappedModel string `json:"default_mapped_model"` + MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index a075b586..1803cf30 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -58,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). - SetDefaultMappedModel(groupIn.DefaultMappedModel) + SetDefaultMappedModel(groupIn.DefaultMappedModel). + SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig) // 设置模型路由配置 if groupIn.ModelRouting != nil { @@ -124,7 +125,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch). SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). - SetDefaultMappedModel(groupIn.DefaultMappedModel) + SetDefaultMappedModel(groupIn.DefaultMappedModel). + SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig) // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 if groupIn.DailyLimitUSD != nil { diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 8032f871..c2553eee 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -152,10 +152,11 @@ type CreateGroupInput struct { // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch bool - DefaultMappedModel string - RequireOAuthOnly bool - RequirePrivacySet bool + AllowMessagesDispatch bool + DefaultMappedModel string + RequireOAuthOnly bool + RequirePrivacySet bool + MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -186,10 +187,11 @@ type UpdateGroupInput struct { // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch *bool - DefaultMappedModel *string - RequireOAuthOnly *bool - RequirePrivacySet *bool + AllowMessagesDispatch *bool + DefaultMappedModel *string + RequireOAuthOnly *bool + RequirePrivacySet *bool + MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -908,7 +910,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn RequireOAuthOnly: input.RequireOAuthOnly, RequirePrivacySet: input.RequirePrivacySet, DefaultMappedModel: input.DefaultMappedModel, + MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig), } + sanitizeGroupMessagesDispatchFields(group) if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err } @@ -1135,6 +1139,10 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.DefaultMappedModel != nil { group.DefaultMappedModel = *input.DefaultMappedModel } + if input.MessagesDispatchModelConfig != nil { + group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig) + } + sanitizeGroupMessagesDispatchFields(group) if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 536be0b5..fa676601 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -245,6 +245,116 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { require.Nil(t, repo.updated.ImagePrice4K) } +func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) { + repo := &groupRepoStubForAdmin{} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "dispatch-group", + Description: "dispatch config", + Platform: PlatformOpenAI, + RateMultiplier: 1.0, + MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: " gpt-5.4-high ", + SonnetMappedModel: " gpt-5.3-codex ", + HaikuMappedModel: " gpt-5.4-mini-medium ", + ExactModelMappings: map[string]string{ + " claude-sonnet-4-5-20250929 ": " gpt-5.2-high ", + }, + }, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.created) + require.Equal(t, OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: "gpt-5.4", + SonnetMappedModel: "gpt-5.3-codex", + HaikuMappedModel: "gpt-5.4-mini", + ExactModelMappings: map[string]string{ + "claude-sonnet-4-5-20250929": "gpt-5.2", + }, + }, repo.created.MessagesDispatchModelConfig) +} + +func TestAdminService_UpdateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) { + existingGroup := &Group{ + ID: 1, + Name: "existing-group", + Platform: PlatformOpenAI, + Status: StatusActive, + } + repo := &groupRepoStubForAdmin{getByID: existingGroup} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{ + MessagesDispatchModelConfig: &OpenAIMessagesDispatchModelConfig{ + SonnetMappedModel: " gpt-5.4-medium ", + ExactModelMappings: map[string]string{ + " claude-haiku-4-5-20251001 ": " gpt-5.4-mini-high ", + }, + }, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Equal(t, OpenAIMessagesDispatchModelConfig{ + SonnetMappedModel: "gpt-5.4", + ExactModelMappings: map[string]string{ + "claude-haiku-4-5-20251001": "gpt-5.4-mini", + }, + }, repo.updated.MessagesDispatchModelConfig) +} + +func TestAdminService_CreateGroup_ClearsMessagesDispatchFieldsForNonOpenAIPlatform(t *testing.T) { + repo := &groupRepoStubForAdmin{} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "anthropic-group", + Description: "non-openai", + Platform: PlatformAnthropic, + RateMultiplier: 1.0, + AllowMessagesDispatch: true, + DefaultMappedModel: "gpt-5.4", + MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: "gpt-5.4", + }, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.created) + require.False(t, repo.created.AllowMessagesDispatch) + require.Empty(t, repo.created.DefaultMappedModel) + require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.created.MessagesDispatchModelConfig) +} + +func TestAdminService_UpdateGroup_ClearsMessagesDispatchFieldsWhenPlatformChangesAwayFromOpenAI(t *testing.T) { + existingGroup := &Group{ + ID: 1, + Name: "existing-openai-group", + Platform: PlatformOpenAI, + Status: StatusActive, + AllowMessagesDispatch: true, + DefaultMappedModel: "gpt-5.4", + MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{ + SonnetMappedModel: "gpt-5.3-codex", + }, + } + repo := &groupRepoStubForAdmin{getByID: existingGroup} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{ + Platform: ptrString(PlatformAnthropic), + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Equal(t, PlatformAnthropic, repo.updated.Platform) + require.False(t, repo.updated.AllowMessagesDispatch) + require.Empty(t, repo.updated.DefaultMappedModel) + require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.updated.MessagesDispatchModelConfig) +} + func TestAdminService_ListGroups_WithSearch(t *testing.T) { // 测试: // 1. search 参数正常传递到 repository 层 diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index d59af9e1..12262613 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -3,8 +3,12 @@ package service import ( "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/domain" ) +type OpenAIMessagesDispatchModelConfig = domain.OpenAIMessagesDispatchModelConfig + type Group struct { ID int64 Name string @@ -49,10 +53,11 @@ type Group struct { SortOrder int // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch bool - RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini) - RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini) - DefaultMappedModel string + AllowMessagesDispatch bool + RequireOAuthOnly bool // 仅允许非 apikey 类型账号关联(OpenAI/Antigravity/Anthropic/Gemini) + RequirePrivacySet bool // 调度时仅允许 privacy 已成功设置的账号(OpenAI/Antigravity/Anthropic/Gemini) + DefaultMappedModel string + MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/openai_messages_dispatch.go b/backend/internal/service/openai_messages_dispatch.go new file mode 100644 index 00000000..f2c1ad3c --- /dev/null +++ b/backend/internal/service/openai_messages_dispatch.go @@ -0,0 +1,100 @@ +package service + +import "strings" + +const ( + defaultOpenAIMessagesDispatchOpusMappedModel = "gpt-5.4" + defaultOpenAIMessagesDispatchSonnetMappedModel = "gpt-5.3-codex" + defaultOpenAIMessagesDispatchHaikuMappedModel = "gpt-5.4-mini" +) + +func normalizeOpenAIMessagesDispatchMappedModel(model string) string { + model = NormalizeOpenAICompatRequestedModel(strings.TrimSpace(model)) + return strings.TrimSpace(model) +} + +func normalizeOpenAIMessagesDispatchModelConfig(cfg OpenAIMessagesDispatchModelConfig) OpenAIMessagesDispatchModelConfig { + out := OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.OpusMappedModel), + SonnetMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.SonnetMappedModel), + HaikuMappedModel: normalizeOpenAIMessagesDispatchMappedModel(cfg.HaikuMappedModel), + } + + if len(cfg.ExactModelMappings) > 0 { + out.ExactModelMappings = make(map[string]string, len(cfg.ExactModelMappings)) + for requestedModel, mappedModel := range cfg.ExactModelMappings { + requestedModel = strings.TrimSpace(requestedModel) + mappedModel = normalizeOpenAIMessagesDispatchMappedModel(mappedModel) + if requestedModel == "" || mappedModel == "" { + continue + } + out.ExactModelMappings[requestedModel] = mappedModel + } + if len(out.ExactModelMappings) == 0 { + out.ExactModelMappings = nil + } + } + + return out +} + +func claudeMessagesDispatchFamily(model string) string { + normalized := strings.ToLower(strings.TrimSpace(model)) + if !strings.HasPrefix(normalized, "claude") { + return "" + } + switch { + case strings.Contains(normalized, "opus"): + return "opus" + case strings.Contains(normalized, "sonnet"): + return "sonnet" + case strings.Contains(normalized, "haiku"): + return "haiku" + default: + return "" + } +} + +func (g *Group) ResolveMessagesDispatchModel(requestedModel string) string { + if g == nil { + return "" + } + requestedModel = strings.TrimSpace(requestedModel) + if requestedModel == "" { + return "" + } + + cfg := normalizeOpenAIMessagesDispatchModelConfig(g.MessagesDispatchModelConfig) + if mappedModel := strings.TrimSpace(cfg.ExactModelMappings[requestedModel]); mappedModel != "" { + return mappedModel + } + + switch claudeMessagesDispatchFamily(requestedModel) { + case "opus": + if mappedModel := strings.TrimSpace(cfg.OpusMappedModel); mappedModel != "" { + return mappedModel + } + return defaultOpenAIMessagesDispatchOpusMappedModel + case "sonnet": + if mappedModel := strings.TrimSpace(cfg.SonnetMappedModel); mappedModel != "" { + return mappedModel + } + return defaultOpenAIMessagesDispatchSonnetMappedModel + case "haiku": + if mappedModel := strings.TrimSpace(cfg.HaikuMappedModel); mappedModel != "" { + return mappedModel + } + return defaultOpenAIMessagesDispatchHaikuMappedModel + default: + return "" + } +} + +func sanitizeGroupMessagesDispatchFields(g *Group) { + if g == nil || g.Platform == PlatformOpenAI { + return + } + g.AllowMessagesDispatch = false + g.DefaultMappedModel = "" + g.MessagesDispatchModelConfig = OpenAIMessagesDispatchModelConfig{} +} diff --git a/backend/internal/service/openai_messages_dispatch_test.go b/backend/internal/service/openai_messages_dispatch_test.go new file mode 100644 index 00000000..a625aadd --- /dev/null +++ b/backend/internal/service/openai_messages_dispatch_test.go @@ -0,0 +1,27 @@ +package service + +import "testing" + +import "github.com/stretchr/testify/require" + +func TestNormalizeOpenAIMessagesDispatchModelConfig(t *testing.T) { + t.Parallel() + + cfg := normalizeOpenAIMessagesDispatchModelConfig(OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: " gpt-5.4-high ", + SonnetMappedModel: "gpt-5.3-codex", + HaikuMappedModel: " gpt-5.4-mini-medium ", + ExactModelMappings: map[string]string{ + " claude-sonnet-4-5-20250929 ": " gpt-5.2-high ", + "": "gpt-5.4", + "claude-opus-4-6": " ", + }, + }) + + require.Equal(t, "gpt-5.4", cfg.OpusMappedModel) + require.Equal(t, "gpt-5.3-codex", cfg.SonnetMappedModel) + require.Equal(t, "gpt-5.4-mini", cfg.HaikuMappedModel) + require.Equal(t, map[string]string{ + "claude-sonnet-4-5-20250929": "gpt-5.2", + }, cfg.ExactModelMappings) +} diff --git a/backend/migrations/091_add_group_messages_dispatch_model_config.sql b/backend/migrations/091_add_group_messages_dispatch_model_config.sql new file mode 100644 index 00000000..8ddfcb0f --- /dev/null +++ b/backend/migrations/091_add_group_messages_dispatch_model_config.sql @@ -0,0 +1,2 @@ +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS messages_dispatch_model_config JSONB NOT NULL DEFAULT '{}'::jsonb;