diff --git a/backend/ent/group.go b/backend/ent/group.go index 5d9ae2ed..a4f52c73 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -47,6 +47,12 @@ type Group struct { MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"` // DefaultValidityDays holds the value of the "default_validity_days" field. DefaultValidityDays int `json:"default_validity_days,omitempty"` + // 是否允许该分组使用图片生成能力 + AllowImageGeneration bool `json:"allow_image_generation,omitempty"` + // 图片生成是否使用独立倍率;false 表示共享分组有效倍率 + ImageRateIndependent bool `json:"image_rate_independent,omitempty"` + // 图片生成独立倍率,仅 image_rate_independent=true 时生效 + ImageRateMultiplier float64 `json:"image_rate_multiplier,omitempty"` // ImagePrice1k holds the value of the "image_price_1k" field. ImagePrice1k *float64 `json:"image_price_1k,omitempty"` // ImagePrice2k holds the value of the "image_price_2k" field. @@ -189,9 +195,9 @@ func (*Group) scanValues(columns []string) ([]any, error) { switch columns[i] { case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig: values[i] = new([]byte) - case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: + case group.FieldIsExclusive, group.FieldAllowImageGeneration, group.FieldImageRateIndependent, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet: values[i] = new(sql.NullBool) - case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: + case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImageRateMultiplier, group.FieldImagePrice1k, 
group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit: values[i] = new(sql.NullInt64) @@ -309,6 +315,24 @@ func (_m *Group) assignValues(columns []string, values []any) error { } else if value.Valid { _m.DefaultValidityDays = int(value.Int64) } + case group.FieldAllowImageGeneration: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field allow_image_generation", values[i]) + } else if value.Valid { + _m.AllowImageGeneration = value.Bool + } + case group.FieldImageRateIndependent: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field image_rate_independent", values[i]) + } else if value.Valid { + _m.ImageRateIndependent = value.Bool + } + case group.FieldImageRateMultiplier: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field image_rate_multiplier", values[i]) + } else if value.Valid { + _m.ImageRateMultiplier = value.Float64 + } case group.FieldImagePrice1k: if value, ok := values[i].(*sql.NullFloat64); !ok { return fmt.Errorf("unexpected type %T for field image_price_1k", values[i]) @@ -550,6 +574,15 @@ func (_m *Group) String() string { builder.WriteString("default_validity_days=") builder.WriteString(fmt.Sprintf("%v", _m.DefaultValidityDays)) builder.WriteString(", ") + builder.WriteString("allow_image_generation=") + builder.WriteString(fmt.Sprintf("%v", _m.AllowImageGeneration)) + builder.WriteString(", ") + builder.WriteString("image_rate_independent=") + builder.WriteString(fmt.Sprintf("%v", _m.ImageRateIndependent)) + builder.WriteString(", ") + builder.WriteString("image_rate_multiplier=") + builder.WriteString(fmt.Sprintf("%v", _m.ImageRateMultiplier)) + builder.WriteString(", ") if v := _m.ImagePrice1k; v != nil { 
builder.WriteString("image_price_1k=") builder.WriteString(fmt.Sprintf("%v", *v)) diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 24bd9c13..4e9ba6b6 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -44,6 +44,12 @@ const ( FieldMonthlyLimitUsd = "monthly_limit_usd" // FieldDefaultValidityDays holds the string denoting the default_validity_days field in the database. FieldDefaultValidityDays = "default_validity_days" + // FieldAllowImageGeneration holds the string denoting the allow_image_generation field in the database. + FieldAllowImageGeneration = "allow_image_generation" + // FieldImageRateIndependent holds the string denoting the image_rate_independent field in the database. + FieldImageRateIndependent = "image_rate_independent" + // FieldImageRateMultiplier holds the string denoting the image_rate_multiplier field in the database. + FieldImageRateMultiplier = "image_rate_multiplier" // FieldImagePrice1k holds the string denoting the image_price_1k field in the database. FieldImagePrice1k = "image_price_1k" // FieldImagePrice2k holds the string denoting the image_price_2k field in the database. @@ -167,6 +173,9 @@ var Columns = []string{ FieldWeeklyLimitUsd, FieldMonthlyLimitUsd, FieldDefaultValidityDays, + FieldAllowImageGeneration, + FieldImageRateIndependent, + FieldImageRateMultiplier, FieldImagePrice1k, FieldImagePrice2k, FieldImagePrice4k, @@ -239,6 +248,12 @@ var ( SubscriptionTypeValidator func(string) error // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. DefaultDefaultValidityDays int + // DefaultAllowImageGeneration holds the default value on creation for the "allow_image_generation" field. + DefaultAllowImageGeneration bool + // DefaultImageRateIndependent holds the default value on creation for the "image_rate_independent" field. 
+ DefaultImageRateIndependent bool + // DefaultImageRateMultiplier holds the default value on creation for the "image_rate_multiplier" field. + DefaultImageRateMultiplier float64 // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. DefaultClaudeCodeOnly bool // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. @@ -343,6 +358,21 @@ func ByDefaultValidityDays(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldDefaultValidityDays, opts...).ToFunc() } +// ByAllowImageGeneration orders the results by the allow_image_generation field. +func ByAllowImageGeneration(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAllowImageGeneration, opts...).ToFunc() +} + +// ByImageRateIndependent orders the results by the image_rate_independent field. +func ByImageRateIndependent(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldImageRateIndependent, opts...).ToFunc() +} + +// ByImageRateMultiplier orders the results by the image_rate_multiplier field. +func ByImageRateMultiplier(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldImageRateMultiplier, opts...).ToFunc() +} + // ByImagePrice1k orders the results by the image_price_1k field. func ByImagePrice1k(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImagePrice1k, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 2814d130..d3223a92 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -125,6 +125,21 @@ func DefaultValidityDays(v int) predicate.Group { return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v)) } +// AllowImageGeneration applies equality check predicate on the "allow_image_generation" field. It's identical to AllowImageGenerationEQ. 
+func AllowImageGeneration(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldAllowImageGeneration, v)) +} + +// ImageRateIndependent applies equality check predicate on the "image_rate_independent" field. It's identical to ImageRateIndependentEQ. +func ImageRateIndependent(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldImageRateIndependent, v)) +} + +// ImageRateMultiplier applies equality check predicate on the "image_rate_multiplier" field. It's identical to ImageRateMultiplierEQ. +func ImageRateMultiplier(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldImageRateMultiplier, v)) +} + // ImagePrice1k applies equality check predicate on the "image_price_1k" field. It's identical to ImagePrice1kEQ. func ImagePrice1k(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v)) @@ -900,6 +915,66 @@ func DefaultValidityDaysLTE(v int) predicate.Group { return predicate.Group(sql.FieldLTE(FieldDefaultValidityDays, v)) } +// AllowImageGenerationEQ applies the EQ predicate on the "allow_image_generation" field. +func AllowImageGenerationEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldAllowImageGeneration, v)) +} + +// AllowImageGenerationNEQ applies the NEQ predicate on the "allow_image_generation" field. +func AllowImageGenerationNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldAllowImageGeneration, v)) +} + +// ImageRateIndependentEQ applies the EQ predicate on the "image_rate_independent" field. +func ImageRateIndependentEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldImageRateIndependent, v)) +} + +// ImageRateIndependentNEQ applies the NEQ predicate on the "image_rate_independent" field. +func ImageRateIndependentNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldImageRateIndependent, v)) +} + +// ImageRateMultiplierEQ applies the EQ predicate on the "image_rate_multiplier" field. 
+func ImageRateMultiplierEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldImageRateMultiplier, v)) +} + +// ImageRateMultiplierNEQ applies the NEQ predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierNEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldImageRateMultiplier, v)) +} + +// ImageRateMultiplierIn applies the In predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldImageRateMultiplier, vs...)) +} + +// ImageRateMultiplierNotIn applies the NotIn predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierNotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldImageRateMultiplier, vs...)) +} + +// ImageRateMultiplierGT applies the GT predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierGT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldImageRateMultiplier, v)) +} + +// ImageRateMultiplierGTE applies the GTE predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierGTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldImageRateMultiplier, v)) +} + +// ImageRateMultiplierLT applies the LT predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierLT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldImageRateMultiplier, v)) +} + +// ImageRateMultiplierLTE applies the LTE predicate on the "image_rate_multiplier" field. +func ImageRateMultiplierLTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldImageRateMultiplier, v)) +} + // ImagePrice1kEQ applies the EQ predicate on the "image_price_1k" field. 
func ImagePrice1kEQ(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 20ea0a0f..44b905bd 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -217,6 +217,48 @@ func (_c *GroupCreate) SetNillableDefaultValidityDays(v *int) *GroupCreate { return _c } +// SetAllowImageGeneration sets the "allow_image_generation" field. +func (_c *GroupCreate) SetAllowImageGeneration(v bool) *GroupCreate { + _c.mutation.SetAllowImageGeneration(v) + return _c +} + +// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil. +func (_c *GroupCreate) SetNillableAllowImageGeneration(v *bool) *GroupCreate { + if v != nil { + _c.SetAllowImageGeneration(*v) + } + return _c +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (_c *GroupCreate) SetImageRateIndependent(v bool) *GroupCreate { + _c.mutation.SetImageRateIndependent(v) + return _c +} + +// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil. +func (_c *GroupCreate) SetNillableImageRateIndependent(v *bool) *GroupCreate { + if v != nil { + _c.SetImageRateIndependent(*v) + } + return _c +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. +func (_c *GroupCreate) SetImageRateMultiplier(v float64) *GroupCreate { + _c.mutation.SetImageRateMultiplier(v) + return _c +} + +// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil. +func (_c *GroupCreate) SetNillableImageRateMultiplier(v *float64) *GroupCreate { + if v != nil { + _c.SetImageRateMultiplier(*v) + } + return _c +} + // SetImagePrice1k sets the "image_price_1k" field. 
func (_c *GroupCreate) SetImagePrice1k(v float64) *GroupCreate { _c.mutation.SetImagePrice1k(v) @@ -604,6 +646,18 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultDefaultValidityDays _c.mutation.SetDefaultValidityDays(v) } + if _, ok := _c.mutation.AllowImageGeneration(); !ok { + v := group.DefaultAllowImageGeneration + _c.mutation.SetAllowImageGeneration(v) + } + if _, ok := _c.mutation.ImageRateIndependent(); !ok { + v := group.DefaultImageRateIndependent + _c.mutation.SetImageRateIndependent(v) + } + if _, ok := _c.mutation.ImageRateMultiplier(); !ok { + v := group.DefaultImageRateMultiplier + _c.mutation.SetImageRateMultiplier(v) + } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { v := group.DefaultClaudeCodeOnly _c.mutation.SetClaudeCodeOnly(v) @@ -700,6 +754,15 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.DefaultValidityDays(); !ok { return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} } + if _, ok := _c.mutation.AllowImageGeneration(); !ok { + return &ValidationError{Name: "allow_image_generation", err: errors.New(`ent: missing required field "Group.allow_image_generation"`)} + } + if _, ok := _c.mutation.ImageRateIndependent(); !ok { + return &ValidationError{Name: "image_rate_independent", err: errors.New(`ent: missing required field "Group.image_rate_independent"`)} + } + if _, ok := _c.mutation.ImageRateMultiplier(); !ok { + return &ValidationError{Name: "image_rate_multiplier", err: errors.New(`ent: missing required field "Group.image_rate_multiplier"`)} + } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)} } @@ -821,6 +884,18 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value) _node.DefaultValidityDays = value } + if value, ok := 
_c.mutation.AllowImageGeneration(); ok { + _spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value) + _node.AllowImageGeneration = value + } + if value, ok := _c.mutation.ImageRateIndependent(); ok { + _spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value) + _node.ImageRateIndependent = value + } + if value, ok := _c.mutation.ImageRateMultiplier(); ok { + _spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value) + _node.ImageRateMultiplier = value + } if value, ok := _c.mutation.ImagePrice1k(); ok { _spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value) _node.ImagePrice1k = &value @@ -1261,6 +1336,48 @@ func (u *GroupUpsert) AddDefaultValidityDays(v int) *GroupUpsert { return u } +// SetAllowImageGeneration sets the "allow_image_generation" field. +func (u *GroupUpsert) SetAllowImageGeneration(v bool) *GroupUpsert { + u.Set(group.FieldAllowImageGeneration, v) + return u +} + +// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create. +func (u *GroupUpsert) UpdateAllowImageGeneration() *GroupUpsert { + u.SetExcluded(group.FieldAllowImageGeneration) + return u +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (u *GroupUpsert) SetImageRateIndependent(v bool) *GroupUpsert { + u.Set(group.FieldImageRateIndependent, v) + return u +} + +// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create. +func (u *GroupUpsert) UpdateImageRateIndependent() *GroupUpsert { + u.SetExcluded(group.FieldImageRateIndependent) + return u +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. +func (u *GroupUpsert) SetImageRateMultiplier(v float64) *GroupUpsert { + u.Set(group.FieldImageRateMultiplier, v) + return u +} + +// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create. 
+func (u *GroupUpsert) UpdateImageRateMultiplier() *GroupUpsert { + u.SetExcluded(group.FieldImageRateMultiplier) + return u +} + +// AddImageRateMultiplier adds v to the "image_rate_multiplier" field. +func (u *GroupUpsert) AddImageRateMultiplier(v float64) *GroupUpsert { + u.Add(group.FieldImageRateMultiplier, v) + return u +} + // SetImagePrice1k sets the "image_price_1k" field. func (u *GroupUpsert) SetImagePrice1k(v float64) *GroupUpsert { u.Set(group.FieldImagePrice1k, v) @@ -1840,6 +1957,55 @@ func (u *GroupUpsertOne) UpdateDefaultValidityDays() *GroupUpsertOne { }) } +// SetAllowImageGeneration sets the "allow_image_generation" field. +func (u *GroupUpsertOne) SetAllowImageGeneration(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetAllowImageGeneration(v) + }) +} + +// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateAllowImageGeneration() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateAllowImageGeneration() + }) +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (u *GroupUpsertOne) SetImageRateIndependent(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetImageRateIndependent(v) + }) +} + +// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateImageRateIndependent() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateImageRateIndependent() + }) +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. +func (u *GroupUpsertOne) SetImageRateMultiplier(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetImageRateMultiplier(v) + }) +} + +// AddImageRateMultiplier adds v to the "image_rate_multiplier" field. 
+func (u *GroupUpsertOne) AddImageRateMultiplier(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddImageRateMultiplier(v) + }) +} + +// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateImageRateMultiplier() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateImageRateMultiplier() + }) +} + // SetImagePrice1k sets the "image_price_1k" field. func (u *GroupUpsertOne) SetImagePrice1k(v float64) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -2632,6 +2798,55 @@ func (u *GroupUpsertBulk) UpdateDefaultValidityDays() *GroupUpsertBulk { }) } +// SetAllowImageGeneration sets the "allow_image_generation" field. +func (u *GroupUpsertBulk) SetAllowImageGeneration(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetAllowImageGeneration(v) + }) +} + +// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateAllowImageGeneration() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateAllowImageGeneration() + }) +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (u *GroupUpsertBulk) SetImageRateIndependent(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetImageRateIndependent(v) + }) +} + +// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateImageRateIndependent() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateImageRateIndependent() + }) +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. +func (u *GroupUpsertBulk) SetImageRateMultiplier(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetImageRateMultiplier(v) + }) +} + +// AddImageRateMultiplier adds v to the "image_rate_multiplier" field. 
+func (u *GroupUpsertBulk) AddImageRateMultiplier(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddImageRateMultiplier(v) + }) +} + +// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateImageRateMultiplier() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateImageRateMultiplier() + }) +} + // SetImagePrice1k sets the "image_price_1k" field. func (u *GroupUpsertBulk) SetImagePrice1k(v float64) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index cc14f897..fe55982c 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -275,6 +275,55 @@ func (_u *GroupUpdate) AddDefaultValidityDays(v int) *GroupUpdate { return _u } +// SetAllowImageGeneration sets the "allow_image_generation" field. +func (_u *GroupUpdate) SetAllowImageGeneration(v bool) *GroupUpdate { + _u.mutation.SetAllowImageGeneration(v) + return _u +} + +// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableAllowImageGeneration(v *bool) *GroupUpdate { + if v != nil { + _u.SetAllowImageGeneration(*v) + } + return _u +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (_u *GroupUpdate) SetImageRateIndependent(v bool) *GroupUpdate { + _u.mutation.SetImageRateIndependent(v) + return _u +} + +// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableImageRateIndependent(v *bool) *GroupUpdate { + if v != nil { + _u.SetImageRateIndependent(*v) + } + return _u +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. 
+func (_u *GroupUpdate) SetImageRateMultiplier(v float64) *GroupUpdate { + _u.mutation.ResetImageRateMultiplier() + _u.mutation.SetImageRateMultiplier(v) + return _u +} + +// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableImageRateMultiplier(v *float64) *GroupUpdate { + if v != nil { + _u.SetImageRateMultiplier(*v) + } + return _u +} + +// AddImageRateMultiplier adds value to the "image_rate_multiplier" field. +func (_u *GroupUpdate) AddImageRateMultiplier(v float64) *GroupUpdate { + _u.mutation.AddImageRateMultiplier(v) + return _u +} + // SetImagePrice1k sets the "image_price_1k" field. func (_u *GroupUpdate) SetImagePrice1k(v float64) *GroupUpdate { _u.mutation.ResetImagePrice1k() @@ -962,6 +1011,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AddedDefaultValidityDays(); ok { _spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value) } + if value, ok := _u.mutation.AllowImageGeneration(); ok { + _spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value) + } + if value, ok := _u.mutation.ImageRateIndependent(); ok { + _spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value) + } + if value, ok := _u.mutation.ImageRateMultiplier(); ok { + _spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedImageRateMultiplier(); ok { + _spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value) + } if value, ok := _u.mutation.ImagePrice1k(); ok { _spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value) } @@ -1610,6 +1671,55 @@ func (_u *GroupUpdateOne) AddDefaultValidityDays(v int) *GroupUpdateOne { return _u } +// SetAllowImageGeneration sets the "allow_image_generation" field. 
+func (_u *GroupUpdateOne) SetAllowImageGeneration(v bool) *GroupUpdateOne { + _u.mutation.SetAllowImageGeneration(v) + return _u +} + +// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableAllowImageGeneration(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetAllowImageGeneration(*v) + } + return _u +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (_u *GroupUpdateOne) SetImageRateIndependent(v bool) *GroupUpdateOne { + _u.mutation.SetImageRateIndependent(v) + return _u +} + +// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableImageRateIndependent(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetImageRateIndependent(*v) + } + return _u +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. +func (_u *GroupUpdateOne) SetImageRateMultiplier(v float64) *GroupUpdateOne { + _u.mutation.ResetImageRateMultiplier() + _u.mutation.SetImageRateMultiplier(v) + return _u +} + +// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableImageRateMultiplier(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetImageRateMultiplier(*v) + } + return _u +} + +// AddImageRateMultiplier adds value to the "image_rate_multiplier" field. +func (_u *GroupUpdateOne) AddImageRateMultiplier(v float64) *GroupUpdateOne { + _u.mutation.AddImageRateMultiplier(v) + return _u +} + // SetImagePrice1k sets the "image_price_1k" field. 
func (_u *GroupUpdateOne) SetImagePrice1k(v float64) *GroupUpdateOne { _u.mutation.ResetImagePrice1k() @@ -2327,6 +2437,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.AddedDefaultValidityDays(); ok { _spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value) } + if value, ok := _u.mutation.AllowImageGeneration(); ok { + _spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value) + } + if value, ok := _u.mutation.ImageRateIndependent(); ok { + _spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value) + } + if value, ok := _u.mutation.ImageRateMultiplier(); ok { + _spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedImageRateMultiplier(); ok { + _spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value) + } if value, ok := _u.mutation.ImagePrice1k(); ok { _spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value) } diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 178ae170..525ff092 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -638,6 +638,9 @@ var ( {Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "default_validity_days", Type: field.TypeInt, Default: 30}, + {Name: "allow_image_generation", Type: field.TypeBool, Default: false}, + {Name: "image_rate_independent", Type: field.TypeBool, Default: false}, + {Name: "image_rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}}, {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, 
Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, @@ -690,7 +693,7 @@ var ( { Name: "group_sort_order", Unique: false, - Columns: []*schema.Column{GroupsColumns[25]}, + Columns: []*schema.Column{GroupsColumns[28]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index d616e4ae..13f6193d 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -14764,6 +14764,10 @@ type GroupMutation struct { addmonthly_limit_usd *float64 default_validity_days *int adddefault_validity_days *int + allow_image_generation *bool + image_rate_independent *bool + image_rate_multiplier *float64 + addimage_rate_multiplier *float64 image_price_1k *float64 addimage_price_1k *float64 image_price_2k *float64 @@ -15583,6 +15587,134 @@ func (m *GroupMutation) ResetDefaultValidityDays() { m.adddefault_validity_days = nil } +// SetAllowImageGeneration sets the "allow_image_generation" field. +func (m *GroupMutation) SetAllowImageGeneration(b bool) { + m.allow_image_generation = &b +} + +// AllowImageGeneration returns the value of the "allow_image_generation" field in the mutation. +func (m *GroupMutation) AllowImageGeneration() (r bool, exists bool) { + v := m.allow_image_generation + if v == nil { + return + } + return *v, true +} + +// OldAllowImageGeneration returns the old "allow_image_generation" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *GroupMutation) OldAllowImageGeneration(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAllowImageGeneration is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAllowImageGeneration requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAllowImageGeneration: %w", err) + } + return oldValue.AllowImageGeneration, nil +} + +// ResetAllowImageGeneration resets all changes to the "allow_image_generation" field. +func (m *GroupMutation) ResetAllowImageGeneration() { + m.allow_image_generation = nil +} + +// SetImageRateIndependent sets the "image_rate_independent" field. +func (m *GroupMutation) SetImageRateIndependent(b bool) { + m.image_rate_independent = &b +} + +// ImageRateIndependent returns the value of the "image_rate_independent" field in the mutation. +func (m *GroupMutation) ImageRateIndependent() (r bool, exists bool) { + v := m.image_rate_independent + if v == nil { + return + } + return *v, true +} + +// OldImageRateIndependent returns the old "image_rate_independent" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *GroupMutation) OldImageRateIndependent(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImageRateIndependent is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImageRateIndependent requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImageRateIndependent: %w", err) + } + return oldValue.ImageRateIndependent, nil +} + +// ResetImageRateIndependent resets all changes to the "image_rate_independent" field. +func (m *GroupMutation) ResetImageRateIndependent() { + m.image_rate_independent = nil +} + +// SetImageRateMultiplier sets the "image_rate_multiplier" field. +func (m *GroupMutation) SetImageRateMultiplier(f float64) { + m.image_rate_multiplier = &f + m.addimage_rate_multiplier = nil +} + +// ImageRateMultiplier returns the value of the "image_rate_multiplier" field in the mutation. +func (m *GroupMutation) ImageRateMultiplier() (r float64, exists bool) { + v := m.image_rate_multiplier + if v == nil { + return + } + return *v, true +} + +// OldImageRateMultiplier returns the old "image_rate_multiplier" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *GroupMutation) OldImageRateMultiplier(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImageRateMultiplier is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImageRateMultiplier requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImageRateMultiplier: %w", err) + } + return oldValue.ImageRateMultiplier, nil +} + +// AddImageRateMultiplier adds f to the "image_rate_multiplier" field. +func (m *GroupMutation) AddImageRateMultiplier(f float64) { + if m.addimage_rate_multiplier != nil { + *m.addimage_rate_multiplier += f + } else { + m.addimage_rate_multiplier = &f + } +} + +// AddedImageRateMultiplier returns the value that was added to the "image_rate_multiplier" field in this mutation. +func (m *GroupMutation) AddedImageRateMultiplier() (r float64, exists bool) { + v := m.addimage_rate_multiplier + if v == nil { + return + } + return *v, true +} + +// ResetImageRateMultiplier resets all changes to the "image_rate_multiplier" field. +func (m *GroupMutation) ResetImageRateMultiplier() { + m.image_rate_multiplier = nil + m.addimage_rate_multiplier = nil +} + // SetImagePrice1k sets the "image_price_1k" field. func (m *GroupMutation) SetImagePrice1k(f float64) { m.image_price_1k = &f @@ -16791,7 +16923,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 31) + fields := make([]string, 0, 34) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -16834,6 +16966,15 @@ func (m *GroupMutation) Fields() []string { if m.default_validity_days != nil { fields = append(fields, group.FieldDefaultValidityDays) } + if m.allow_image_generation != nil { + fields = append(fields, group.FieldAllowImageGeneration) + } + if m.image_rate_independent != nil { + fields = append(fields, group.FieldImageRateIndependent) + } + if m.image_rate_multiplier != nil { + fields = append(fields, group.FieldImageRateMultiplier) + } if m.image_price_1k != nil { fields = append(fields, group.FieldImagePrice1k) } @@ -16921,6 +17062,12 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.MonthlyLimitUsd() case group.FieldDefaultValidityDays: return m.DefaultValidityDays() + case group.FieldAllowImageGeneration: + return m.AllowImageGeneration() + case group.FieldImageRateIndependent: + return m.ImageRateIndependent() + case group.FieldImageRateMultiplier: + return m.ImageRateMultiplier() case group.FieldImagePrice1k: return m.ImagePrice1k() case group.FieldImagePrice2k: @@ -16992,6 +17139,12 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldMonthlyLimitUsd(ctx) case group.FieldDefaultValidityDays: return m.OldDefaultValidityDays(ctx) + case group.FieldAllowImageGeneration: + return m.OldAllowImageGeneration(ctx) + case group.FieldImageRateIndependent: + return m.OldImageRateIndependent(ctx) + case group.FieldImageRateMultiplier: + return m.OldImageRateMultiplier(ctx) case group.FieldImagePrice1k: return m.OldImagePrice1k(ctx) case group.FieldImagePrice2k: @@ -17133,6 +17286,27 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetDefaultValidityDays(v) return nil + case group.FieldAllowImageGeneration: + v, ok := value.(bool) + if !ok { + return 
fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowImageGeneration(v) + return nil + case group.FieldImageRateIndependent: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImageRateIndependent(v) + return nil + case group.FieldImageRateMultiplier: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImageRateMultiplier(v) + return nil case group.FieldImagePrice1k: v, ok := value.(float64) if !ok { @@ -17275,6 +17449,9 @@ func (m *GroupMutation) AddedFields() []string { if m.adddefault_validity_days != nil { fields = append(fields, group.FieldDefaultValidityDays) } + if m.addimage_rate_multiplier != nil { + fields = append(fields, group.FieldImageRateMultiplier) + } if m.addimage_price_1k != nil { fields = append(fields, group.FieldImagePrice1k) } @@ -17314,6 +17491,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedMonthlyLimitUsd() case group.FieldDefaultValidityDays: return m.AddedDefaultValidityDays() + case group.FieldImageRateMultiplier: + return m.AddedImageRateMultiplier() case group.FieldImagePrice1k: return m.AddedImagePrice1k() case group.FieldImagePrice2k: @@ -17372,6 +17551,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddDefaultValidityDays(v) return nil + case group.FieldImageRateMultiplier: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddImageRateMultiplier(v) + return nil case group.FieldImagePrice1k: v, ok := value.(float64) if !ok { @@ -17559,6 +17745,15 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldDefaultValidityDays: m.ResetDefaultValidityDays() return nil + case group.FieldAllowImageGeneration: + m.ResetAllowImageGeneration() + return nil + case group.FieldImageRateIndependent: + m.ResetImageRateIndependent() + return nil + 
case group.FieldImageRateMultiplier: + m.ResetImageRateMultiplier() + return nil case group.FieldImagePrice1k: m.ResetImagePrice1k() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 6b344a55..a282d9ba 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -803,50 +803,62 @@ func init() { groupDescDefaultValidityDays := groupFields[10].Descriptor() // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) + // groupDescAllowImageGeneration is the schema descriptor for allow_image_generation field. + groupDescAllowImageGeneration := groupFields[11].Descriptor() + // group.DefaultAllowImageGeneration holds the default value on creation for the allow_image_generation field. + group.DefaultAllowImageGeneration = groupDescAllowImageGeneration.Default.(bool) + // groupDescImageRateIndependent is the schema descriptor for image_rate_independent field. + groupDescImageRateIndependent := groupFields[12].Descriptor() + // group.DefaultImageRateIndependent holds the default value on creation for the image_rate_independent field. + group.DefaultImageRateIndependent = groupDescImageRateIndependent.Default.(bool) + // groupDescImageRateMultiplier is the schema descriptor for image_rate_multiplier field. + groupDescImageRateMultiplier := groupFields[13].Descriptor() + // group.DefaultImageRateMultiplier holds the default value on creation for the image_rate_multiplier field. + group.DefaultImageRateMultiplier = groupDescImageRateMultiplier.Default.(float64) // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. - groupDescClaudeCodeOnly := groupFields[14].Descriptor() + groupDescClaudeCodeOnly := groupFields[17].Descriptor() // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. 
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[18].Descriptor() + groupDescModelRoutingEnabled := groupFields[21].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. - groupDescMcpXMLInject := groupFields[19].Descriptor() + groupDescMcpXMLInject := groupFields[22].Descriptor() // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. - groupDescSupportedModelScopes := groupFields[20].Descriptor() + groupDescSupportedModelScopes := groupFields[23].Descriptor() // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) // groupDescSortOrder is the schema descriptor for sort_order field. - groupDescSortOrder := groupFields[21].Descriptor() + groupDescSortOrder := groupFields[24].Descriptor() // group.DefaultSortOrder holds the default value on creation for the sort_order field. group.DefaultSortOrder = groupDescSortOrder.Default.(int) // groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field. - groupDescAllowMessagesDispatch := groupFields[22].Descriptor() + groupDescAllowMessagesDispatch := groupFields[25].Descriptor() // group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field. 
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool) // groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field. - groupDescRequireOauthOnly := groupFields[23].Descriptor() + groupDescRequireOauthOnly := groupFields[26].Descriptor() // group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field. group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool) // groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field. - groupDescRequirePrivacySet := groupFields[24].Descriptor() + groupDescRequirePrivacySet := groupFields[27].Descriptor() // group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field. group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool) // groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field. - groupDescDefaultMappedModel := groupFields[25].Descriptor() + groupDescDefaultMappedModel := groupFields[28].Descriptor() // group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field. group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string) // group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save. group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error) // groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field. - groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor() + groupDescMessagesDispatchModelConfig := groupFields[29].Descriptor() // group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field. 
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig) // groupDescRpmLimit is the schema descriptor for rpm_limit field. - groupDescRpmLimit := groupFields[27].Descriptor() + groupDescRpmLimit := groupFields[30].Descriptor() // group.DefaultRpmLimit holds the default value on creation for the rpm_limit field. group.DefaultRpmLimit = groupDescRpmLimit.Default.(int) idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 11f38d66..d47e8710 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -74,6 +74,16 @@ func (Group) Fields() []ent.Field { Default(30), // 图片生成计费配置(antigravity 和 gemini 平台使用) + field.Bool("allow_image_generation"). + Default(false). + Comment("是否允许该分组使用图片生成能力"), + field.Bool("image_rate_independent"). + Default(false). + Comment("图片生成是否使用独立倍率;false 表示共享分组有效倍率"), + field.Float("image_rate_multiplier"). + SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}). + Default(1.0). + Comment("图片生成独立倍率,仅 image_rate_independent=true 时生效"), field.Float("image_price_1k"). Optional(). Nillable(). 
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 87263db0..e3dc2109 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -575,6 +575,24 @@ type ConcurrencyConfig struct { PingInterval int `mapstructure:"ping_interval"` } +type ImageConcurrencyConfig struct { + // Enabled: 是否启用图片生成独立并发限制,默认关闭以保持现有行为 + Enabled bool `mapstructure:"enabled"` + // MaxConcurrentRequests: 当前进程允许同时处理的图片生成请求数,0表示不限制 + MaxConcurrentRequests int `mapstructure:"max_concurrent_requests"` + // OverflowMode: 图片并发达到上限后的处理方式:reject/wait + OverflowMode string `mapstructure:"overflow_mode"` + // WaitTimeoutSeconds: overflow_mode=wait 时等待图片并发槽位的超时时间(秒) + WaitTimeoutSeconds int `mapstructure:"wait_timeout_seconds"` + // MaxWaitingRequests: overflow_mode=wait 时当前进程允许排队等待的图片请求数 + MaxWaitingRequests int `mapstructure:"max_waiting_requests"` +} + +const ( + ImageConcurrencyOverflowModeReject = "reject" + ImageConcurrencyOverflowModeWait = "wait" +) + // GatewayConfig API网关相关配置 type GatewayConfig struct { // 等待上游响应头的超时时间(秒),0表示无超时 @@ -604,6 +622,8 @@ type GatewayConfig struct { OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` // OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP) OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"` + // ImageConcurrency: 图片生成独立并发限制配置(默认关闭) + ImageConcurrency ImageConcurrencyConfig `mapstructure:"image_concurrency"` // HTTP 上游连接池配置(性能优化:支持高并发场景调优) // MaxIdleConns: 所有主机的最大空闲连接总数 @@ -635,6 +655,10 @@ type GatewayConfig struct { StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"` // StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用 StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"` + // ImageStreamDataIntervalTimeout: 图片流数据间隔超时(秒),0表示禁用 + ImageStreamDataIntervalTimeout int `mapstructure:"image_stream_data_interval_timeout"` + // ImageStreamKeepaliveInterval: 图片流式 keepalive 间隔(秒),0表示禁用 + 
ImageStreamKeepaliveInterval int `mapstructure:"image_stream_keepalive_interval"` // MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值) MaxLineSize int `mapstructure:"max_line_size"` @@ -1672,6 +1696,11 @@ func setDefaults() { viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7) viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8) viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5) + viper.SetDefault("gateway.image_concurrency.enabled", false) + viper.SetDefault("gateway.image_concurrency.max_concurrent_requests", 0) + viper.SetDefault("gateway.image_concurrency.overflow_mode", ImageConcurrencyOverflowModeReject) + viper.SetDefault("gateway.image_concurrency.wait_timeout_seconds", 30) + viper.SetDefault("gateway.image_concurrency.max_waiting_requests", 100) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_extra_retries", 10) viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) @@ -1689,6 +1718,8 @@ func setDefaults() { viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) viper.SetDefault("gateway.stream_data_interval_timeout", 180) viper.SetDefault("gateway.stream_keepalive_interval", 10) + viper.SetDefault("gateway.image_stream_data_interval_timeout", 900) + viper.SetDefault("gateway.image_stream_keepalive_interval", 10) viper.SetDefault("gateway.max_line_size", 500*1024*1024) viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second) @@ -2239,6 +2270,21 @@ func (c *Config) Validate() error { ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy) } } + if c.Gateway.ImageConcurrency.MaxConcurrentRequests < 0 { + return fmt.Errorf("gateway.image_concurrency.max_concurrent_requests must be non-negative") + } + switch strings.TrimSpace(c.Gateway.ImageConcurrency.OverflowMode) { + case "", 
ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait: + default: + return fmt.Errorf("gateway.image_concurrency.overflow_mode must be one of: %s/%s", + ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait) + } + if c.Gateway.ImageConcurrency.WaitTimeoutSeconds < 0 { + return fmt.Errorf("gateway.image_concurrency.wait_timeout_seconds must be non-negative") + } + if c.Gateway.ImageConcurrency.MaxWaitingRequests < 0 { + return fmt.Errorf("gateway.image_concurrency.max_waiting_requests must be non-negative") + } if c.Gateway.MaxIdleConns <= 0 { return fmt.Errorf("gateway.max_idle_conns must be positive") } @@ -2277,6 +2323,20 @@ func (c *Config) Validate() error { (c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) { return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds") } + if c.Gateway.ImageStreamDataIntervalTimeout < 0 { + return fmt.Errorf("gateway.image_stream_data_interval_timeout must be non-negative") + } + if c.Gateway.ImageStreamDataIntervalTimeout != 0 && + (c.Gateway.ImageStreamDataIntervalTimeout < 60 || c.Gateway.ImageStreamDataIntervalTimeout > 1800) { + return fmt.Errorf("gateway.image_stream_data_interval_timeout must be 0 or between 60-1800 seconds") + } + if c.Gateway.ImageStreamKeepaliveInterval < 0 { + return fmt.Errorf("gateway.image_stream_keepalive_interval must be non-negative") + } + if c.Gateway.ImageStreamKeepaliveInterval != 0 && + (c.Gateway.ImageStreamKeepaliveInterval < 5 || c.Gateway.ImageStreamKeepaliveInterval > 60) { + return fmt.Errorf("gateway.image_stream_keepalive_interval must be 0 or between 5-60 seconds") + } // 兼容旧键 sticky_previous_response_ttl_seconds if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds diff --git a/backend/internal/config/config_test.go 
b/backend/internal/config/config_test.go index 6ba86aa1..a47de2f8 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1282,6 +1282,46 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 }, wantErr: "gateway.stream_data_interval_timeout must be non-negative", }, + { + name: "gateway image stream keepalive range", + mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = 4 }, + wantErr: "gateway.image_stream_keepalive_interval", + }, + { + name: "gateway image stream keepalive negative", + mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = -1 }, + wantErr: "gateway.image_stream_keepalive_interval must be non-negative", + }, + { + name: "gateway image stream data interval range", + mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = 30 }, + wantErr: "gateway.image_stream_data_interval_timeout", + }, + { + name: "gateway image stream data interval negative", + mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = -1 }, + wantErr: "gateway.image_stream_data_interval_timeout must be non-negative", + }, + { + name: "gateway image concurrency max negative", + mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxConcurrentRequests = -1 }, + wantErr: "gateway.image_concurrency.max_concurrent_requests must be non-negative", + }, + { + name: "gateway image concurrency overflow mode invalid", + mutate: func(c *Config) { c.Gateway.ImageConcurrency.OverflowMode = "queue" }, + wantErr: "gateway.image_concurrency.overflow_mode", + }, + { + name: "gateway image concurrency wait timeout negative", + mutate: func(c *Config) { c.Gateway.ImageConcurrency.WaitTimeoutSeconds = -1 }, + wantErr: "gateway.image_concurrency.wait_timeout_seconds must be non-negative", + }, + { + name: "gateway image concurrency max waiting negative", + mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxWaitingRequests = -1 }, + 
wantErr: "gateway.image_concurrency.max_waiting_requests must be non-negative", + }, { name: "gateway max line size", mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 }, @@ -1754,3 +1794,41 @@ func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) { t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds) } } + +func TestLoad_DefaultGatewayImageStreamConfig(t *testing.T) { + resetViperWithJWTSecret(t) + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + if cfg.Gateway.StreamDataIntervalTimeout != 180 { + t.Fatalf("stream_data_interval_timeout = %d, want 180", cfg.Gateway.StreamDataIntervalTimeout) + } + if cfg.Gateway.StreamKeepaliveInterval != 10 { + t.Fatalf("stream_keepalive_interval = %d, want 10", cfg.Gateway.StreamKeepaliveInterval) + } + if cfg.Gateway.ImageStreamDataIntervalTimeout != 900 { + t.Fatalf("image_stream_data_interval_timeout = %d, want 900", cfg.Gateway.ImageStreamDataIntervalTimeout) + } + if cfg.Gateway.ImageStreamKeepaliveInterval != 10 { + t.Fatalf("image_stream_keepalive_interval = %d, want 10", cfg.Gateway.ImageStreamKeepaliveInterval) + } + if cfg.Gateway.ImageConcurrency.Enabled { + t.Fatalf("image_concurrency.enabled = true, want false") + } + if cfg.Gateway.ImageConcurrency.MaxConcurrentRequests != 0 { + t.Fatalf("image_concurrency.max_concurrent_requests = %d, want 0", cfg.Gateway.ImageConcurrency.MaxConcurrentRequests) + } + if cfg.Gateway.ImageConcurrency.OverflowMode != ImageConcurrencyOverflowModeReject { + t.Fatalf("image_concurrency.overflow_mode = %q, want %q", cfg.Gateway.ImageConcurrency.OverflowMode, ImageConcurrencyOverflowModeReject) + } + if cfg.Gateway.ImageConcurrency.WaitTimeoutSeconds != 30 { + t.Fatalf("image_concurrency.wait_timeout_seconds = %d, want 30", cfg.Gateway.ImageConcurrency.WaitTimeoutSeconds) + } + if cfg.Gateway.ImageConcurrency.MaxWaitingRequests != 100 { + t.Fatalf("image_concurrency.max_waiting_requests = %d, want 
100", cfg.Gateway.ImageConcurrency.MaxWaitingRequests) + } + if cfg.Gateway.ImageStreamDataIntervalTimeout <= cfg.Gateway.StreamDataIntervalTimeout { + t.Fatalf("image stream timeout = %d, want greater than ordinary stream timeout %d", cfg.Gateway.ImageStreamDataIntervalTimeout, cfg.Gateway.StreamDataIntervalTimeout) + } +} diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 65e5ec78..3667bbcd 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -92,6 +92,9 @@ type CreateGroupRequest struct { WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"` MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) + AllowImageGeneration bool `json:"allow_image_generation"` + ImageRateIndependent bool `json:"image_rate_independent"` + ImageRateMultiplier *float64 `json:"image_rate_multiplier"` ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` @@ -129,6 +132,9 @@ type UpdateGroupRequest struct { WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"` MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) + AllowImageGeneration *bool `json:"allow_image_generation"` + ImageRateIndependent *bool `json:"image_rate_independent"` + ImageRateMultiplier *float64 `json:"image_rate_multiplier"` ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` @@ -251,6 +257,9 @@ func (h *GroupHandler) Create(c *gin.Context) { DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(), WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(), MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(), + AllowImageGeneration: req.AllowImageGeneration, + ImageRateIndependent: req.ImageRateIndependent, + 
ImageRateMultiplier: req.ImageRateMultiplier, ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, @@ -303,6 +312,9 @@ func (h *GroupHandler) Update(c *gin.Context) { DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(), WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(), MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(), + AllowImageGeneration: req.AllowImageGeneration, + ImageRateIndependent: req.ImageRateIndependent, + ImageRateMultiplier: req.ImageRateMultiplier, ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index f7503c2e..2559b112 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -176,6 +176,9 @@ func groupFromServiceBase(g *service.Group) Group { DailyLimitUSD: g.DailyLimitUSD, WeeklyLimitUSD: g.WeeklyLimitUSD, MonthlyLimitUSD: g.MonthlyLimitUSD, + AllowImageGeneration: g.AllowImageGeneration, + ImageRateIndependent: g.ImageRateIndependent, + ImageRateMultiplier: g.ImageRateMultiplier, ImagePrice1K: g.ImagePrice1K, ImagePrice2K: g.ImagePrice2K, ImagePrice4K: g.ImagePrice4K, diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 5cc2f8e4..e15a916e 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -94,9 +94,12 @@ type Group struct { MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` + AllowImageGeneration bool `json:"allow_image_generation"` + ImageRateIndependent bool `json:"image_rate_independent"` + ImageRateMultiplier float64 `json:"image_rate_multiplier"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + 
ImagePrice4K *float64 `json:"image_price_4k"` // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` diff --git a/backend/internal/handler/image_concurrency_limiter.go b/backend/internal/handler/image_concurrency_limiter.go new file mode 100644 index 00000000..6e7cbb67 --- /dev/null +++ b/backend/internal/handler/image_concurrency_limiter.go @@ -0,0 +1,126 @@ +package handler + +import ( + "context" + "sync" + "time" +) + +type imageConcurrencyLimiter struct { + mu sync.Mutex + notify chan struct{} + limit int + active int + waiting int + enabled bool +} + +func (l *imageConcurrencyLimiter) TryAcquire(enabled bool, limit int) (func(), bool) { + return l.acquire(context.Background(), enabled, limit, false, 0, 0) +} + +func (l *imageConcurrencyLimiter) Acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) { + return l.acquire(ctx, enabled, limit, wait, timeout, maxWaiting) +} + +func (l *imageConcurrencyLimiter) acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) { + if !enabled || limit <= 0 { + return nil, true + } + if ctx == nil { + ctx = context.Background() + } + if wait { + if timeout <= 0 { + return nil, false + } + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + ctx = waitCtx + } + if maxWaiting < 0 { + maxWaiting = 0 + } + for { + release, acquired, waitRelease, notify := l.tryAcquireLocked(enabled, limit, wait, maxWaiting) + if acquired { + return release, acquired + } + if !wait || notify == nil { + return nil, false + } + if !l.waitForSlot(ctx, notify) { + if waitRelease != nil { + waitRelease() + } + return nil, false + } + if waitRelease != nil { + waitRelease() + } + } +} + +func (l *imageConcurrencyLimiter) tryAcquireLocked(enabled bool, limit int, wait bool, maxWaiting int) (func(), bool, func(), <-chan struct{}) { + l.mu.Lock() + defer l.mu.Unlock() + + if l.notify == nil { + 
l.notify = make(chan struct{}) + } + if l.enabled != enabled || l.limit != limit { + l.enabled = enabled + l.limit = limit + } + if l.active < l.limit { + l.active++ + return l.releaseFunc(), true, nil, nil + } + if !wait { + return nil, false, nil, nil + } + if maxWaiting > 0 && l.waiting >= maxWaiting { + return nil, false, nil, nil + } + l.waiting++ + return nil, false, l.waiterReleaseFunc(), l.notify +} + +func (l *imageConcurrencyLimiter) waitForSlot(ctx context.Context, notify <-chan struct{}) bool { + select { + case <-notify: + return true + case <-ctx.Done(): + return false + } +} + +func (l *imageConcurrencyLimiter) releaseFunc() func() { + var once sync.Once + return func() { + once.Do(func() { + l.mu.Lock() + if l.active > 0 { + l.active-- + } + if l.notify != nil { + close(l.notify) + l.notify = make(chan struct{}) + } + l.mu.Unlock() + }) + } +} + +func (l *imageConcurrencyLimiter) waiterReleaseFunc() func() { + var once sync.Once + return func() { + once.Do(func() { + l.mu.Lock() + if l.waiting > 0 { + l.waiting-- + } + l.mu.Unlock() + }) + } +} diff --git a/backend/internal/handler/image_concurrency_limiter_test.go b/backend/internal/handler/image_concurrency_limiter_test.go new file mode 100644 index 00000000..20147f16 --- /dev/null +++ b/backend/internal/handler/image_concurrency_limiter_test.go @@ -0,0 +1,230 @@ +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestImageConcurrencyLimiter_DefaultDisabledAllowsRequests(t *testing.T) { + limiter := &imageConcurrencyLimiter{} + + release, acquired := limiter.TryAcquire(false, 1) + + require.True(t, acquired) + require.Nil(t, release) +} + +func 
TestImageConcurrencyLimiter_RejectsWhenLimitReachedAndAllowsAfterRelease(t *testing.T) { + limiter := &imageConcurrencyLimiter{} + + release, acquired := limiter.TryAcquire(true, 1) + require.True(t, acquired) + require.NotNil(t, release) + + secondRelease, secondAcquired := limiter.TryAcquire(true, 1) + require.False(t, secondAcquired) + require.Nil(t, secondRelease) + + release() + thirdRelease, thirdAcquired := limiter.TryAcquire(true, 1) + require.True(t, thirdAcquired) + require.NotNil(t, thirdRelease) + thirdRelease() +} + +func TestImageConcurrencyLimiter_WaitsUntilSlotReleased(t *testing.T) { + limiter := &imageConcurrencyLimiter{} + release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1) + require.True(t, acquired) + require.NotNil(t, release) + + acquiredCh := make(chan func(), 1) + go func() { + waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1) + require.True(t, waitAcquired) + acquiredCh <- waitRelease + }() + + time.Sleep(20 * time.Millisecond) + release() + + select { + case waitRelease := <-acquiredCh: + require.NotNil(t, waitRelease) + waitRelease() + case <-time.After(time.Second): + t.Fatal("timed out waiting for image concurrency slot") + } +} + +func TestImageConcurrencyLimiter_WaitTimesOut(t *testing.T) { + limiter := &imageConcurrencyLimiter{} + release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1) + require.True(t, acquired) + require.NotNil(t, release) + defer release() + + waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, 10*time.Millisecond, 1) + + require.False(t, waitAcquired) + require.Nil(t, waitRelease) +} + +func TestImageConcurrencyLimiter_MaxWaitingRequestsRejectsOverflow(t *testing.T) { + limiter := &imageConcurrencyLimiter{} + release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1) + require.True(t, acquired) + require.NotNil(t, release) + defer 
release() + + waitingStarted := make(chan struct{}) + waitingDone := make(chan struct{}) + go func() { + close(waitingStarted) + waitRelease, waitAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1) + if waitAcquired && waitRelease != nil { + waitRelease() + } + close(waitingDone) + }() + <-waitingStarted + time.Sleep(20 * time.Millisecond) + + overflowRelease, overflowAcquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1) + + require.False(t, overflowAcquired) + require.Nil(t, overflowRelease) + release() + <-waitingDone +} + +func TestOpenAIGatewayHandlerAcquireImageGenerationSlot_Returns429WhenFull(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil) + + h := &OpenAIGatewayHandler{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + ImageConcurrency: config.ImageConcurrencyConfig{ + Enabled: true, + MaxConcurrentRequests: 1, + OverflowMode: config.ImageConcurrencyOverflowModeReject, + }, + }, + }, + imageLimiter: &imageConcurrencyLimiter{}, + } + release, acquired := h.acquireImageGenerationSlot(c, false) + require.True(t, acquired) + require.NotNil(t, release) + defer release() + + blockedRelease, blocked := h.acquireImageGenerationSlot(c, false) + + require.False(t, blocked) + require.Nil(t, blockedRelease) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String()) + require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded") +} + +func TestOpenAIGatewayHandlerResponses_ImageIntentRejectedByImageConcurrency(t *testing.T) { + gin.SetMode(gin.TestMode) + body := `{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}` + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = 
httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body)) + groupID := int64(1) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + ID: 10, + GroupID: &groupID, + Group: &service.Group{ + ID: groupID, + AllowImageGeneration: true, + }, + User: &service.User{ID: 20}, + }) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1}) + + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})}, + errorPassthroughService: nil, + cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{ + Enabled: true, + MaxConcurrentRequests: 1, + OverflowMode: config.ImageConcurrencyOverflowModeReject, + }}}, + imageLimiter: &imageConcurrencyLimiter{}, + } + release, acquired := h.acquireImageGenerationSlot(c, false) + require.True(t, acquired) + require.NotNil(t, release) + defer release() + rec.Body.Reset() + rec.Code = 0 + + h.Responses(c) + + require.Equal(t, http.StatusTooManyRequests, rec.Code) + require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String()) + require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded") +} + +func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t *testing.T) { + gin.SetMode(gin.TestMode) + body := `{"model":"gpt-5.4","input":"write code"}` + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(body)) + groupID := int64(1) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + ID: 10, + GroupID: &groupID, + Group: &service.Group{ + ID: groupID, + AllowImageGeneration: true, + }, + User: 
&service.User{ID: 20}, + }) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1}) + + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}), + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})}, + cfg: &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: config.ImageConcurrencyConfig{ + Enabled: true, + MaxConcurrentRequests: 1, + OverflowMode: config.ImageConcurrencyOverflowModeReject, + }}}, + imageLimiter: &imageConcurrencyLimiter{}, + } + release, acquired := h.acquireImageGenerationSlot(c, false) + require.True(t, acquired) + require.NotNil(t, release) + defer release() + rec.Body.Reset() + rec.Code = 0 + + h.Responses(c) + + require.NotEqual(t, http.StatusTooManyRequests, rec.Code) + require.NotContains(t, rec.Body.String(), "Image generation concurrency limit exceeded") +} diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 23844508..06ab9d52 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -187,52 +187,60 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) } if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - // Pool mode: retry on the same account - if failoverErr.RetryableOnSameAccount { - retryLimit := account.GetPoolModeRetryCount() - if sameAccountRetryCount[account.ID] < retryLimit { - sameAccountRetryCount[account.ID]++ - 
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("retry_limit", retryLimit), - zap.Int("retry_count", sameAccountRetryCount[account.ID]), - ) - select { - case <-c.Request.Context().Done(): - return - case <-time.After(sameAccountRetryDelay): - } - continue - } - } - h.gatewayService.RecordOpenAIAccountSwitch() - failedAccountIDs[account.ID] = struct{}{} - lastFailoverErr = failoverErr - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, streamStarted) - return - } - switchCount++ - reqLog.Warn("openai_chat_completions.upstream_failover_switching", + if result != nil && result.ImageCount > 0 { + reqLog.Warn("openai_chat_completions.forward_partial_error_with_image_result", zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), + zap.Int("image_count", result.ImageCount), + zap.Error(err), ) - continue + } else { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // Pool mode: retry on the same account + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr 
+ if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai_chat_completions.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Warn("openai_chat_completions.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) + return } - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) - reqLog.Warn("openai_chat_completions.forward_failed", - zap.Int64("account_id", account.ID), - zap.Bool("fallback_error_response_written", wroteFallback), - zap.Error(err), - ) - return } if result != nil { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) @@ -242,16 +250,18 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account) - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: GetInboundEndpoint(c), - UpstreamEndpoint: resolveRawCCUpstreamEndpoint(c, account), + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, APIKeyService: h.apiKeyService, diff --git 
a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index b5eec393..5966c163 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -33,6 +33,7 @@ type OpenAIGatewayHandler struct { usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper + imageLimiter *imageConcurrencyLimiter maxAccountSwitches int cfg *config.Config } @@ -69,6 +70,7 @@ func NewOpenAIGatewayHandler( usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + imageLimiter: &imageConcurrencyLimiter{}, maxAccountSwitches: maxAccountSwitches, cfg: cfg, } @@ -187,6 +189,23 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body) + if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) { + h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage()) + return + } + var imageReleaseFunc func() + if imageIntent { + var imageAcquired bool + imageReleaseFunc, imageAcquired = h.acquireImageGenerationSlot(c, streamStarted) + if !imageAcquired { + return + } + if imageReleaseFunc != nil { + defer imageReleaseFunc() + } + } + // 解析渠道级模型映射 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) @@ -318,57 +337,65 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) } if err != nil { - var failoverErr *service.UpstreamFailoverError - if 
errors.As(err, &failoverErr) { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - // 池模式:同账号重试 - if failoverErr.RetryableOnSameAccount { - retryLimit := account.GetPoolModeRetryCount() - if sameAccountRetryCount[account.ID] < retryLimit { - sameAccountRetryCount[account.ID]++ - reqLog.Warn("openai.pool_mode_same_account_retry", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("retry_limit", retryLimit), - zap.Int("retry_count", sameAccountRetryCount[account.ID]), - ) - select { - case <-c.Request.Context().Done(): - return - case <-time.After(sameAccountRetryDelay): + if result != nil && result.ImageCount > 0 { + reqLog.Warn("openai.forward_partial_error_with_image_result", + zap.Int64("account_id", account.ID), + zap.Int("image_count", result.ImageCount), + zap.Error(err), + ) + } else { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // 池模式:同账号重试 + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue } - continue } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", 
failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue } - h.gatewayService.RecordOpenAIAccountSwitch() - failedAccountIDs[account.ID] = struct{}{} - lastFailoverErr = failoverErr - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, streamStarted) + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + } + if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { + reqLog.Warn("openai.forward_failed", fields...) return } - switchCount++ - reqLog.Warn("openai.upstream_failover_switching", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - ) - continue - } - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) - fields := []zap.Field{ - zap.Int64("account_id", account.ID), - zap.Bool("fallback_error_response_written", wroteFallback), - zap.Error(err), - } - if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { - reqLog.Warn("openai.forward_failed", fields...) + reqLog.Error("openai.forward_failed", fields...) return } - reqLog.Error("openai.forward_failed", fields...) 
- return } if result != nil { if account.Type == service.AccountTypeOAuth { @@ -383,17 +410,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: GetInboundEndpoint(c), - UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, @@ -701,52 +730,60 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) } if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - // 池模式:同账号重试 - if failoverErr.RetryableOnSameAccount { - retryLimit := account.GetPoolModeRetryCount() - if sameAccountRetryCount[account.ID] < retryLimit { - sameAccountRetryCount[account.ID]++ - reqLog.Warn("openai_messages.pool_mode_same_account_retry", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("retry_limit", retryLimit), - zap.Int("retry_count", sameAccountRetryCount[account.ID]), - ) - select { - case <-c.Request.Context().Done(): - return - case <-time.After(sameAccountRetryDelay): - } - continue - } - } - h.gatewayService.RecordOpenAIAccountSwitch() - 
failedAccountIDs[account.ID] = struct{}{} - lastFailoverErr = failoverErr - if switchCount >= maxAccountSwitches { - h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted) - return - } - switchCount++ - reqLog.Warn("openai_messages.upstream_failover_switching", + if result != nil && result.ImageCount > 0 { + reqLog.Warn("openai_messages.forward_partial_error_with_image_result", zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), + zap.Int("image_count", result.ImageCount), + zap.Error(err), ) - continue + } else { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + // 池模式:同账号重试 + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai_messages.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue + } + } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai_messages.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + 
wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted) + reqLog.Warn("openai_messages.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) + return } - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted) - reqLog.Warn("openai_messages.forward_failed", - zap.Int64("account_id", account.ID), - zap.Bool("fallback_error_response_written", wroteFallback), - zap.Error(err), - ) - return } if result != nil { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) @@ -757,16 +794,18 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) requestPayloadHash := service.HashUsageRequestPayload(body) + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) - h.submitUsageRecordTask(func(ctx context.Context) { + h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: GetInboundEndpoint(c), - UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, @@ -1114,6 +1153,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { setOpsRequestContext(c, reqModel, true, firstMessage) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) + if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage()) 
+ return + } + // 解析渠道级模型映射 channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel) @@ -1257,22 +1301,34 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { }, AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) { releaseTurnSlots() - if turnErr != nil || result == nil { + if turnErr != nil { + if result == nil || result.ImageCount <= 0 { + return + } + reqLog.Warn("openai.websocket_partial_error_with_image_result", + zap.Int64("account_id", account.ID), + zap.Int("image_count", result.ImageCount), + zap.Error(turnErr), + ) + } + if result == nil { return } if account.Type == service.AccountTypeOAuth { h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders) } h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) - h.submitUsageRecordTask(func(taskCtx context.Context) { + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) + h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) { if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: GetInboundEndpoint(c), - UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), @@ -1440,6 +1496,60 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas task(ctx) } +func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(result *service.OpenAIForwardResult, task service.UsageRecordTask) { + if result != nil && result.ImageCount > 0 { + h.submitMandatoryUsageRecordTask(task) + return + } + h.submitUsageRecordTask(task) +} + +func (h 
*OpenAIGatewayHandler) submitMandatoryUsageRecordTask(task service.UsageRecordTask) { + if task == nil { + return + } + if h.usageRecordWorkerPool != nil { + if mode := h.usageRecordWorkerPool.Submit(task); mode != service.UsageRecordSubmitModeDropped { + return + } + logger.L().With( + zap.String("component", "handler.openai_gateway.usage"), + ).Warn("openai.usage_record_task_mandatory_sync_fallback") + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + defer func() { + if recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.usage"), + zap.Any("panic", recovered), + ).Error("openai.usage_record_task_panic_recovered") + } + }() + task(ctx) +} + +func (h *OpenAIGatewayHandler) acquireImageGenerationSlot(c *gin.Context, streamStarted bool) (func(), bool) { + if h == nil || h.cfg == nil || h.imageLimiter == nil { + return nil, true + } + imageConcurrency := h.cfg.Gateway.ImageConcurrency + wait := strings.TrimSpace(imageConcurrency.OverflowMode) == config.ImageConcurrencyOverflowModeWait + release, acquired := h.imageLimiter.Acquire( + c.Request.Context(), + imageConcurrency.Enabled, + imageConcurrency.MaxConcurrentRequests, + wait, + time.Duration(imageConcurrency.WaitTimeoutSeconds)*time.Second, + imageConcurrency.MaxWaitingRequests, + ) + if acquired { + return release, true + } + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Image generation concurrency limit exceeded, please retry later", streamStarted) + return nil, false +} + // handleConcurrencyError handles concurrency-related errors with proper 429 response func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index 
4d0078a7..eba701f1 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -81,6 +81,18 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { zap.String("capability", string(parsed.RequiredCapability)), ) + if !service.GroupAllowsImageGeneration(apiKey.Group) { + h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage()) + return + } + imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted) + if !acquired { + return + } + if imageReleaseFunc != nil { + defer imageReleaseFunc() + } + if parsed.Multipart { setOpsRequestContext(c, parsed.Model, parsed.Stream, nil) } else { @@ -188,62 +200,69 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { responseLatencyMs = forwardDurationMs - upstreamLatencyMs } service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) - if err == nil && result != nil && result.FirstTokenMs != nil { + if result != nil && result.FirstTokenMs != nil { service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) } if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - if failoverErr.RetryableOnSameAccount { - retryLimit := account.GetPoolModeRetryCount() - if sameAccountRetryCount[account.ID] < retryLimit { - sameAccountRetryCount[account.ID]++ - reqLog.Warn("openai.images.pool_mode_same_account_retry", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("retry_limit", retryLimit), - zap.Int("retry_count", sameAccountRetryCount[account.ID]), - ) - select { - case <-c.Request.Context().Done(): - return - case <-time.After(sameAccountRetryDelay): + if result != nil && result.ImageCount > 0 { + reqLog.Warn("openai.images.forward_partial_error_with_image_result", + zap.Int64("account_id", account.ID), + 
zap.Int("image_count", result.ImageCount), + zap.Error(err), + ) + } else { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + if failoverErr.RetryableOnSameAccount { + retryLimit := account.GetPoolModeRetryCount() + if sameAccountRetryCount[account.ID] < retryLimit { + sameAccountRetryCount[account.ID]++ + reqLog.Warn("openai.images.pool_mode_same_account_retry", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("retry_limit", retryLimit), + zap.Int("retry_count", sameAccountRetryCount[account.ID]), + ) + select { + case <-c.Request.Context().Done(): + return + case <-time.After(sameAccountRetryDelay): + } + continue } - continue } + h.gatewayService.RecordOpenAIAccountSwitch() + failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } + switchCount++ + reqLog.Warn("openai.images.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) + continue } - h.gatewayService.RecordOpenAIAccountSwitch() - failedAccountIDs[account.ID] = struct{}{} - lastFailoverErr = failoverErr - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, streamStarted) + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + } + if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { + reqLog.Warn("openai.images.forward_failed", fields...) 
return } - switchCount++ - reqLog.Warn("openai.images.upstream_failover_switching", - zap.Int64("account_id", account.ID), - zap.Int("upstream_status", failoverErr.StatusCode), - zap.Int("switch_count", switchCount), - zap.Int("max_switches", maxAccountSwitches), - ) - continue - } - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) - fields := []zap.Field{ - zap.Int64("account_id", account.ID), - zap.Bool("fallback_error_response_written", wroteFallback), - zap.Error(err), - } - if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { - reqLog.Warn("openai.images.forward_failed", fields...) + reqLog.Error("openai.images.forward_failed", fields...) return } - reqLog.Error("openai.images.forward_failed", fields...) - return } - if result != nil { if account.Type == service.AccountTypeOAuth { h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders) @@ -259,21 +278,27 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { if parsed.Multipart { requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed())) } + inboundEndpoint := GetInboundEndpoint(c) + upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) - h.submitUsageRecordTask(func(ctx context.Context) { + upstreamModel := "" + if result != nil { + upstreamModel = result.UpstreamModel + } + h.submitMandatoryUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, Account: account, Subscription: subscription, - InboundEndpoint: GetInboundEndpoint(c), - UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform), + InboundEndpoint: inboundEndpoint, + UpstreamEndpoint: upstreamEndpoint, UserAgent: userAgent, IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, - ChannelUsageFields: 
channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel), + ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, upstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.images"), diff --git a/backend/internal/handler/openai_images_controls_test.go b/backend/internal/handler/openai_images_controls_test.go new file mode 100644 index 00000000..cebcccac --- /dev/null +++ b/backend/internal/handler/openai_images_controls_test.go @@ -0,0 +1,49 @@ +package handler + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayHandlerImages_DisabledGroupRejectsBeforeScheduling(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"gpt-image-2","prompt":"draw","size":"1024x1024"}`) + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + groupID := int64(111) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + ID: 222, + GroupID: &groupID, + Group: &service.Group{ + ID: groupID, + AllowImageGeneration: false, + }, + User: &service.User{ID: 333}, + }) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 333, Concurrency: 1}) + + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{concurrencyService: &service.ConcurrencyService{}}, + } + + h.Images(c) + + require.Equal(t, http.StatusForbidden, rec.Code) + require.Equal(t, "permission_error", gjson.GetBytes(rec.Body.Bytes(), 
"error.type").String()) + require.Contains(t, rec.Body.String(), service.ImageGenerationPermissionMessage()) +} diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go index 5c945815..e4c2837a 100644 --- a/backend/internal/handler/usage_record_submit_task_test.go +++ b/backend/internal/handler/usage_record_submit_task_test.go @@ -129,3 +129,63 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere }) require.True(t, called.Load(), "panic 后后续任务应仍可执行") } + +func TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallback(t *testing.T) { + pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{ + WorkerCount: 1, + QueueSize: 1, + TaskTimeout: time.Second, + OverflowPolicy: "drop", + OverflowSamplePercent: 0, + AutoScaleEnabled: false, + }) + t.Cleanup(pool.Stop) + h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool} + + block := make(chan struct{}) + release := make(chan struct{}) + pool.Submit(func(ctx context.Context) { + close(block) + <-release + }) + <-block + pool.Submit(func(ctx context.Context) {}) + + var called atomic.Bool + h.submitMandatoryUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + close(release) + + require.True(t, called.Load(), "mandatory usage task must run synchronously when async submit is dropped") +} + +func TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandatoryFallback(t *testing.T) { + pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{ + WorkerCount: 1, + QueueSize: 1, + TaskTimeout: time.Second, + OverflowPolicy: "drop", + OverflowSamplePercent: 0, + AutoScaleEnabled: false, + }) + t.Cleanup(pool.Stop) + h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool} + + block := make(chan struct{}) + release := make(chan struct{}) + pool.Submit(func(ctx context.Context) { + close(block) + <-release + }) + 
<-block + pool.Submit(func(ctx context.Context) {}) + + var called atomic.Bool + h.submitOpenAIUsageRecordTask(&service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) { + called.Store(true) + }) + close(release) + + require.True(t, called.Load(), "image usage task must be mandatory when async submit is dropped") +} diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 3a527405..68895475 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -166,6 +166,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, + group.FieldAllowImageGeneration, + group.FieldImageRateIndependent, + group.FieldImageRateMultiplier, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, @@ -699,6 +702,9 @@ func groupEntityToService(g *dbent.Group) *service.Group { DailyLimitUSD: g.DailyLimitUsd, WeeklyLimitUSD: g.WeeklyLimitUsd, MonthlyLimitUSD: g.MonthlyLimitUsd, + AllowImageGeneration: g.AllowImageGeneration, + ImageRateIndependent: g.ImageRateIndependent, + ImageRateMultiplier: g.ImageRateMultiplier, ImagePrice1K: g.ImagePrice1k, ImagePrice2K: g.ImagePrice2k, ImagePrice4K: g.ImagePrice4k, diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 5e16475a..112575f4 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -50,6 +50,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD). + SetAllowImageGeneration(groupIn.AllowImageGeneration). + SetImageRateIndependent(groupIn.ImageRateIndependent). + SetImageRateMultiplier(groupIn.ImageRateMultiplier). 
SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). @@ -120,6 +123,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableDailyLimitUsd(groupIn.DailyLimitUSD). SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD). SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD). + SetAllowImageGeneration(groupIn.AllowImageGeneration). + SetImageRateIndependent(groupIn.ImageRateIndependent). + SetImageRateMultiplier(groupIn.ImageRateMultiplier). SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 607b93dc..34f560fc 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -328,6 +328,9 @@ func TestAPIContracts(t *testing.T) { "image_price_1k": null, "image_price_2k": null, "image_price_4k": null, + "allow_image_generation": false, + "image_rate_independent": false, + "image_rate_multiplier": 0, "claude_code_only": false, "allow_messages_dispatch": false, "fallback_group_id": null, diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index 90ff450f..221021d8 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -230,7 +230,11 @@ func applyAccountStatsCost( if model == "" { model = requestedModel } + requestCount := 1 + if usageLog != nil && usageLog.ImageCount > 0 { + requestCount = usageLog.ImageCount + } usageLog.AccountStatsCost = resolveAccountStatsCost( - ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost, + ctx, cs, bs, accountID, groupID, model, tokens, requestCount, totalCost, ) } diff --git a/backend/internal/service/admin_service.go 
b/backend/internal/service/admin_service.go index be4c23dc..793d60d8 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -189,11 +189,14 @@ type CreateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + AllowImageGeneration bool + ImageRateIndependent bool + ImageRateMultiplier *float64 + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -226,11 +229,14 @@ type UpdateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + AllowImageGeneration *bool + ImageRateIndependent *bool + ImageRateMultiplier *float64 + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -1557,6 +1563,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice4K := normalizePrice(input.ImagePrice4K) + imageRateMultiplier := 1.0 + if input.ImageRateMultiplier != nil { + if *input.ImageRateMultiplier < 0 { + return nil, errors.New("image_rate_multiplier must be >= 0") + } + imageRateMultiplier = *input.ImageRateMultiplier + } // 校验降级分组 if 
input.FallbackGroupID != nil { @@ -1624,6 +1637,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn DailyLimitUSD: dailyLimit, WeeklyLimitUSD: weeklyLimit, MonthlyLimitUSD: monthlyLimit, + AllowImageGeneration: input.AllowImageGeneration, + ImageRateIndependent: input.ImageRateIndependent, + ImageRateMultiplier: imageRateMultiplier, ImagePrice1K: imagePrice1K, ImagePrice2K: imagePrice2K, ImagePrice4K: imagePrice4K, @@ -1800,6 +1816,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) // 图片生成计费配置:负数表示清除(使用默认价格) + if input.AllowImageGeneration != nil { + group.AllowImageGeneration = *input.AllowImageGeneration + } + if input.ImageRateIndependent != nil { + group.ImageRateIndependent = *input.ImageRateIndependent + } + if input.ImageRateMultiplier != nil { + if *input.ImageRateMultiplier < 0 { + return nil, errors.New("image_rate_multiplier must be >= 0") + } + group.ImageRateMultiplier = *input.ImageRateMultiplier + } if input.ImagePrice1K != nil { group.ImagePrice1K = normalizePrice(input.ImagePrice1K) } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index eef02240..0a2020ea 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -266,6 +266,50 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { require.Nil(t, repo.updated.ImagePrice4K) } +func TestAdminService_UpdateGroup_PreservesImageGenerationControlsWhenOmitted(t *testing.T) { + imageMultiplier := 0.5 + existingGroup := &Group{ + ID: 1, + Name: "existing-group", + Platform: PlatformOpenAI, + Status: StatusActive, + AllowImageGeneration: true, + ImageRateIndependent: true, + ImageRateMultiplier: imageMultiplier, + } + repo := &groupRepoStubForAdmin{getByID: 
existingGroup} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{ + Description: "updated", + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.True(t, repo.updated.AllowImageGeneration) + require.True(t, repo.updated.ImageRateIndependent) + require.InDelta(t, 0.5, repo.updated.ImageRateMultiplier, 1e-12) +} + +func TestAdminService_UpdateGroup_RejectsNegativeImageRateMultiplier(t *testing.T) { + existingGroup := &Group{ + ID: 1, + Name: "existing-group", + Platform: PlatformOpenAI, + Status: StatusActive, + ImageRateMultiplier: 1, + } + repo := &groupRepoStubForAdmin{getByID: existingGroup} + svc := &adminServiceImpl{groupRepo: repo} + negative := -0.1 + + _, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{ + ImageRateMultiplier: &negative, + }) + require.Error(t, err) + require.Nil(t, repo.updated) +} + func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) { existingGroup := &Group{ ID: 1, diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 1a1c78b8..4432ad7d 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -63,6 +63,9 @@ type APIKeyAuthGroupSnapshot struct { DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` + AllowImageGeneration bool `json:"allow_image_generation"` + ImageRateIndependent bool `json:"image_rate_independent"` + ImageRateMultiplier float64 `json:"image_rate_multiplier"` ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go 
b/backend/internal/service/api_key_auth_cache_impl.go index 974ea66e..0f9d4214 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -14,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot +const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls type apiKeyAuthCacheConfig struct { l1Size int @@ -255,6 +255,9 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) DailyLimitUSD: apiKey.Group.DailyLimitUSD, WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, + AllowImageGeneration: apiKey.Group.AllowImageGeneration, + ImageRateIndependent: apiKey.Group.ImageRateIndependent, + ImageRateMultiplier: apiKey.Group.ImageRateMultiplier, ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice4K: apiKey.Group.ImagePrice4K, @@ -321,6 +324,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho DailyLimitUSD: snapshot.Group.DailyLimitUSD, WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, + AllowImageGeneration: snapshot.Group.AllowImageGeneration, + ImageRateIndependent: snapshot.Group.ImageRateIndependent, + ImageRateMultiplier: snapshot.Group.ImageRateMultiplier, ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice4K: snapshot.Group.ImagePrice4K, diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index cb502a2e..a9c21884 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -294,8 +294,7 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { } // OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。 - if strings.Contains(modelLower, "gpt-5") || 
strings.Contains(modelLower, "codex") { - normalized := normalizeCodexModel(modelLower) + if normalized := normalizeKnownOpenAICodexModel(modelLower); normalized != "" { switch normalized { case "gpt-5.5": return s.fallbackPrices["gpt-5.5"] @@ -644,13 +643,10 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens } func isOpenAIGPT54Model(model string) bool { - trimmed := strings.TrimSpace(strings.ToLower(model)) - // 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel - // 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)误识别为 gpt-5.4。 - if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") { - return false - } - normalized := normalizeCodexModel(trimmed) + // 仅当模型字符串实际属于已知 GPT-5/Codex 族时才做归一判定,避免 + // normalizeCodexModel 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o) + // 误识别为 gpt-5.4。 + normalized := normalizeKnownOpenAICodexModel(model) return normalized == "gpt-5.4" || normalized == "gpt-5.5" } diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 222abd69..df3e3a0a 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -137,6 +137,35 @@ func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) { require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12) } +func TestGetModelPricing_OpenAICompactAliasesFallback(t *testing.T) { + svc := newTestBillingService() + + tests := []struct { + model string + inputPrice float64 + outputPrice float64 + cacheRead float64 + longContext int + }{ + {model: "gpt5.5", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000}, + {model: "openai/gpt5.4", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000}, + {model: "gpt5.4-mini", inputPrice: 7.5e-7, outputPrice: 4.5e-6, cacheRead: 7.5e-8, longContext: 0}, + {model: "gpt5.3codexspark", inputPrice: 1.5e-6, outputPrice: 12e-6, cacheRead: 0.15e-6, 
longContext: 0}, + } + + for _, tt := range tests { + t.Run(tt.model, func(t *testing.T) { + pricing, err := svc.GetModelPricing(tt.model) + require.NoError(t, err) + require.NotNil(t, pricing) + require.InDelta(t, tt.inputPrice, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, tt.outputPrice, pricing.OutputPricePerToken, 1e-12) + require.InDelta(t, tt.cacheRead, pricing.CacheReadPricePerToken, 1e-12) + require.Equal(t, tt.longContext, pricing.LongContextInputThreshold) + }) + } +} + func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) { svc := newTestBillingService() diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 67d19720..b9bd992e 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -8367,6 +8367,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage groupDefault := apiKey.Group.RateMultiplier multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) } + imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier) // 确定计费模型 billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) @@ -8384,7 +8385,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage } // 计算费用 - cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts) + cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, imageMultiplier, opts) // 判断计费方式:订阅模式 vs 余额模式 isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() @@ -8396,7 +8397,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage // 创建使用日志 accountRateMultiplier := account.BillingRateMultiplier() usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription, - requestedModel, multiplier, accountRateMultiplier, billingType, 
cacheTTLOverridden, cost, opts) + requestedModel, multiplier, imageMultiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts) // 计算账号统计定价费用(使用最终上游模型匹配自定义规则) if apiKey.GroupID != nil { @@ -8450,11 +8451,12 @@ func (s *GatewayService) calculateRecordUsageCost( apiKey *APIKey, billingModel string, multiplier float64, + imageMultiplier float64, opts *recordUsageOpts, ) *CostBreakdown { // 图片生成计费 if result.ImageCount > 0 { - return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier) + return s.calculateImageCost(ctx, result, apiKey, billingModel, imageMultiplier) } // Token 计费 @@ -8495,7 +8497,8 @@ func (s *GatewayService) calculateImageCost( Model: billingModel, GroupID: &gid, Tokens: tokens, - RequestCount: 1, + RequestCount: result.ImageCount, + SizeTier: result.ImageSize, RateMultiplier: multiplier, Resolver: s.resolver, Resolved: resolved, @@ -8580,6 +8583,7 @@ func (s *GatewayService) buildRecordUsageLog( subscription *UserSubscription, requestedModel string, multiplier float64, + imageMultiplier float64, accountRateMultiplier float64, billingType int8, cacheTTLOverridden bool, @@ -8624,6 +8628,9 @@ func (s *GatewayService) buildRecordUsageLog( SubscriptionID: optionalSubscriptionID(subscription), CreatedAt: time.Now(), } + if result.ImageCount > 0 { + usageLog.RateMultiplier = imageMultiplier + } if cost != nil { usageLog.InputCost = cost.InputCost usageLog.OutputCost = cost.OutputCost diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index bb4c5aa1..f6155352 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -26,9 +26,12 @@ type Group struct { DefaultValidityDays int // 图片生成计费配置(antigravity 和 gemini 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 + AllowImageGeneration bool + ImageRateIndependent bool + ImageRateMultiplier float64 + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 // Claude Code 客户端限制 
ClaudeCodeOnly bool diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 87174e03..93078aa6 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -45,19 +45,25 @@ type GroupSortOrderUpdate struct { // CreateGroupRequest 创建分组请求 type CreateGroupRequest struct { - Name string `json:"name"` - Description string `json:"description"` - RateMultiplier float64 `json:"rate_multiplier"` - IsExclusive bool `json:"is_exclusive"` + Name string `json:"name"` + Description string `json:"description"` + RateMultiplier float64 `json:"rate_multiplier"` + IsExclusive bool `json:"is_exclusive"` + AllowImageGeneration bool `json:"allow_image_generation"` + ImageRateIndependent bool `json:"image_rate_independent"` + ImageRateMultiplier *float64 `json:"image_rate_multiplier"` } // UpdateGroupRequest 更新分组请求 type UpdateGroupRequest struct { - Name *string `json:"name"` - Description *string `json:"description"` - RateMultiplier *float64 `json:"rate_multiplier"` - IsExclusive *bool `json:"is_exclusive"` - Status *string `json:"status"` + Name *string `json:"name"` + Description *string `json:"description"` + RateMultiplier *float64 `json:"rate_multiplier"` + IsExclusive *bool `json:"is_exclusive"` + Status *string `json:"status"` + AllowImageGeneration *bool `json:"allow_image_generation"` + ImageRateIndependent *bool `json:"image_rate_independent"` + ImageRateMultiplier *float64 `json:"image_rate_multiplier"` } // GroupService 分组管理服务 @@ -76,6 +82,13 @@ func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthC // Create 创建分组 func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) { + imageRateMultiplier := 1.0 + if req.ImageRateMultiplier != nil { + if *req.ImageRateMultiplier < 0 { + return nil, fmt.Errorf("image_rate_multiplier must be >= 0") + } + imageRateMultiplier = *req.ImageRateMultiplier + } // 检查名称是否已存在 exists, err := 
s.groupRepo.ExistsByName(ctx, req.Name) if err != nil { @@ -87,13 +100,16 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Gro // 创建分组 group := &Group{ - Name: req.Name, - Description: req.Description, - Platform: PlatformAnthropic, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - Status: StatusActive, - SubscriptionType: SubscriptionTypeStandard, + Name: req.Name, + Description: req.Description, + Platform: PlatformAnthropic, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + Status: StatusActive, + SubscriptionType: SubscriptionTypeStandard, + AllowImageGeneration: req.AllowImageGeneration, + ImageRateIndependent: req.ImageRateIndependent, + ImageRateMultiplier: imageRateMultiplier, } if err := s.groupRepo.Create(ctx, group); err != nil { @@ -165,6 +181,18 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ if req.Status != nil { group.Status = *req.Status } + if req.AllowImageGeneration != nil { + group.AllowImageGeneration = *req.AllowImageGeneration + } + if req.ImageRateIndependent != nil { + group.ImageRateIndependent = *req.ImageRateIndependent + } + if req.ImageRateMultiplier != nil { + if *req.ImageRateMultiplier < 0 { + return nil, fmt.Errorf("image_rate_multiplier must be >= 0") + } + group.ImageRateMultiplier = *req.ImageRateMultiplier + } if err := s.groupRepo.Update(ctx, group); err != nil { return nil, fmt.Errorf("update group: %w", err) diff --git a/backend/internal/service/image_billing_multiplier.go b/backend/internal/service/image_billing_multiplier.go new file mode 100644 index 00000000..23ec5ac1 --- /dev/null +++ b/backend/internal/service/image_billing_multiplier.go @@ -0,0 +1,11 @@ +package service + +func resolveImageRateMultiplier(apiKey *APIKey, effectiveGroupMultiplier float64) float64 { + if apiKey != nil && apiKey.Group != nil && apiKey.Group.ImageRateIndependent { + if apiKey.Group.ImageRateMultiplier < 0 { + return 0 + } + 
return apiKey.Group.ImageRateMultiplier + } + return effectiveGroupMultiplier +} diff --git a/backend/internal/service/image_generation_intent.go b/backend/internal/service/image_generation_intent.go new file mode 100644 index 00000000..b6ef1065 --- /dev/null +++ b/backend/internal/service/image_generation_intent.go @@ -0,0 +1,220 @@ +package service + +import ( + "encoding/json" + "strings" + + "github.com/tidwall/gjson" +) + +const ( + openAIResponsesEndpoint = "/v1/responses" + openAIResponsesCompactEndpoint = "/v1/responses/compact" + imageGenerationPermissionMessage = "Image generation is not enabled for this group" +) + +// ImageGenerationPermissionMessage returns the stable end-user error text for disabled groups. +func ImageGenerationPermissionMessage() string { + return imageGenerationPermissionMessage +} + +// GroupAllowsImageGeneration preserves ungrouped-key behavior and enforces the flag when a group is present. +func GroupAllowsImageGeneration(group *Group) bool { + return group == nil || group.AllowImageGeneration +} + +// IsImageGenerationIntent classifies requests that can produce generated images. +func IsImageGenerationIntent(endpoint string, requestedModel string, body []byte) bool { + if IsImageGenerationEndpoint(endpoint) { + return true + } + if isOpenAIImageGenerationModel(requestedModel) { + return true + } + if len(body) == 0 || !gjson.ValidBytes(body) { + return false + } + if model := strings.TrimSpace(gjson.GetBytes(body, "model").String()); isOpenAIImageGenerationModel(model) { + return true + } + if openAIJSONToolsContainImageGeneration(gjson.GetBytes(body, "tools")) { + return true + } + return openAIJSONToolChoiceSelectsImageGeneration(gjson.GetBytes(body, "tool_choice")) +} + +// IsImageGenerationIntentMap is the map-backed variant used after service-side request mutation. 
+func IsImageGenerationIntentMap(endpoint string, requestedModel string, reqBody map[string]any) bool { + if IsImageGenerationEndpoint(endpoint) { + return true + } + if isOpenAIImageGenerationModel(requestedModel) { + return true + } + if reqBody == nil { + return false + } + if isOpenAIImageGenerationModel(firstNonEmptyString(reqBody["model"])) { + return true + } + if hasOpenAIImageGenerationTool(reqBody) { + return true + } + return openAIAnyToolChoiceSelectsImageGeneration(reqBody["tool_choice"]) +} + +// IsImageGenerationEndpoint identifies dedicated generated-image endpoints. +func IsImageGenerationEndpoint(endpoint string) bool { + switch normalizeImageGenerationEndpoint(endpoint) { + case "/v1/images/generations", "/v1/images/edits", "/images/generations", "/images/edits": + return true + default: + return false + } +} + +func normalizeImageGenerationEndpoint(endpoint string) string { + endpoint = strings.TrimSpace(strings.ToLower(endpoint)) + if endpoint == "" { + return "" + } + endpoint = strings.TrimPrefix(endpoint, "https://api.openai.com") + if idx := strings.IndexByte(endpoint, '?'); idx >= 0 { + endpoint = endpoint[:idx] + } + return strings.TrimRight(endpoint, "/") +} + +func openAIJSONToolsContainImageGeneration(tools gjson.Result) bool { + if !tools.IsArray() { + return false + } + found := false + tools.ForEach(func(_, item gjson.Result) bool { + if strings.TrimSpace(item.Get("type").String()) == "image_generation" { + found = true + return false + } + return true + }) + return found +} + +func openAIJSONToolChoiceSelectsImageGeneration(choice gjson.Result) bool { + if !choice.Exists() { + return false + } + if choice.Type == gjson.String { + return strings.TrimSpace(choice.String()) == "image_generation" + } + if !choice.IsObject() { + return false + } + if strings.TrimSpace(choice.Get("type").String()) == "image_generation" { + return true + } + if strings.TrimSpace(choice.Get("tool.type").String()) == "image_generation" { + return true + } + 
if strings.TrimSpace(choice.Get("function.name").String()) == "image_generation" { + return true + } + return false +} + +func openAIAnyToolChoiceSelectsImageGeneration(choice any) bool { + switch v := choice.(type) { + case string: + return strings.TrimSpace(v) == "image_generation" + case map[string]any: + if strings.TrimSpace(firstNonEmptyString(v["type"])) == "image_generation" { + return true + } + if tool, ok := v["tool"].(map[string]any); ok && strings.TrimSpace(firstNonEmptyString(tool["type"])) == "image_generation" { + return true + } + if fn, ok := v["function"].(map[string]any); ok && strings.TrimSpace(firstNonEmptyString(fn["name"])) == "image_generation" { + return true + } + } + return false +} + +func getAPIKeyFromContext(c interface{ Get(string) (any, bool) }) *APIKey { + if c == nil { + return nil + } + v, exists := c.Get("api_key") + if !exists { + return nil + } + apiKey, _ := v.(*APIKey) + return apiKey +} + +func apiKeyGroup(apiKey *APIKey) *Group { + if apiKey == nil { + return nil + } + return apiKey.Group +} + +func cloneRequestMapForImageIntent(body []byte) map[string]any { + if len(body) == 0 { + return nil + } + var out map[string]any + if err := json.Unmarshal(body, &out); err != nil { + return nil + } + return out +} + +func resolveOpenAIResponsesImageBillingConfig(reqBody map[string]any, fallbackModel string) (string, string, error) { + imageModel := "" + imageSize := "" + hasImageTool := false + if reqBody != nil { + rawTools, _ := reqBody["tools"].([]any) + for _, rawTool := range rawTools { + toolMap, ok := rawTool.(map[string]any) + if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" { + continue + } + hasImageTool = true + imageModel = strings.TrimSpace(firstNonEmptyString(toolMap["model"])) + imageSize = strings.TrimSpace(firstNonEmptyString(toolMap["size"])) + break + } + if imageSize == "" { + imageSize = strings.TrimSpace(firstNonEmptyString(reqBody["size"])) + } + } + if imageModel == "" 
&& reqBody != nil { + bodyModel := strings.TrimSpace(firstNonEmptyString(reqBody["model"])) + if isOpenAIImageBillingModelAlias(bodyModel) || !hasImageTool { + imageModel = bodyModel + } + } + if imageModel == "" && hasImageTool { + imageModel = "gpt-image-2" + } + if imageModel == "" { + imageModel = strings.TrimSpace(fallbackModel) + } + sizeTier := normalizeOpenAIImageSizeTier(imageSize) + return imageModel, sizeTier, nil +} + +func resolveOpenAIResponsesImageBillingConfigFromBody(body []byte, fallbackModel string) (string, string, error) { + reqBody := cloneRequestMapForImageIntent(body) + return resolveOpenAIResponsesImageBillingConfig(reqBody, fallbackModel) +} + +func isOpenAIImageBillingModelAlias(model string) bool { + normalized := strings.ToLower(strings.TrimSpace(model)) + if normalized == "" { + return false + } + return isOpenAIImageGenerationModel(normalized) || strings.Contains(normalized, "image") +} diff --git a/backend/internal/service/image_generation_intent_test.go b/backend/internal/service/image_generation_intent_test.go new file mode 100644 index 00000000..5e7bec79 --- /dev/null +++ b/backend/internal/service/image_generation_intent_test.go @@ -0,0 +1,184 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsImageGenerationIntent(t *testing.T) { + tests := []struct { + name string + endpoint string + model string + body []byte + want bool + }{ + { + name: "images endpoint", + endpoint: "/v1/images/generations", + body: []byte(`{"model":"gpt-image-2"}`), + want: true, + }, + { + name: "image model", + endpoint: "/v1/responses", + model: "gpt-image-2", + body: []byte(`{"model":"gpt-image-2"}`), + want: true, + }, + { + name: "image tool", + endpoint: "/v1/responses", + model: "gpt-5.4", + body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation"}]}`), + want: true, + }, + { + name: "image tool choice", + endpoint: "/v1/responses", + model: "gpt-5.4", + body: 
[]byte(`{"model":"gpt-5.4","tool_choice":{"type":"image_generation"}}`), + want: true, + }, + { + name: "required tool choice alone is text", + endpoint: "/v1/responses", + model: "gpt-5.4", + body: []byte(`{"model":"gpt-5.4","tool_choice":"required"}`), + want: false, + }, + { + name: "text only gpt 5.4", + endpoint: "/v1/responses", + model: "gpt-5.4", + body: []byte(`{"model":"gpt-5.4","input":"write code"}`), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, IsImageGenerationIntent(tt.endpoint, tt.model, tt.body)) + }) + } +} + +func TestResolveOpenAIResponsesImageBillingConfigUsesCurrentBodyModel(t *testing.T) { + imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody( + []byte(`{"model":"mapped-image-model","tools":[{"type":"image_generation","size":"1024x1024"}]}`), + "requested-model", + ) + require.NoError(t, err) + require.Equal(t, "mapped-image-model", imageModel) + require.Equal(t, "1K", imageSize) +} + +func TestResolveOpenAIResponsesImageBillingConfigToolModelWins(t *testing.T) { + imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody( + []byte(`{"model":"mapped-text-model","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1536x1024"}]}`), + "requested-model", + ) + require.NoError(t, err) + require.Equal(t, "gpt-image-2", imageModel) + require.Equal(t, "2K", imageSize) +} + +func TestResolveOpenAIResponsesImageBillingConfigSupportsOfficialAndCustomSizes(t *testing.T) { + tests := []struct { + name string + body []byte + wantTier string + }{ + { + name: "official 2k landscape", + body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"2048x1152"}]}`), + wantTier: "2K", + }, + { + name: "official 4k landscape", + body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"3840x2160"}]}`), + wantTier: "4K", + }, + { + name: "custom valid 
2k", + body: []byte(`{"model":"gpt-5.5","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1280x768"}]}`), + wantTier: "2K", + }, + { + name: "default image tool model supports flexible size", + body: []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","size":"2048x1152"}]}`), + wantTier: "2K", + }, + { + name: "top level image size is moved into billing", + body: []byte(`{"model":"gpt-image-2","size":"2048x2048","tools":[{"type":"image_generation","model":"gpt-image-2"}]}`), + wantTier: "2K", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody(tt.body, "requested-model") + require.NoError(t, err) + require.NotEmpty(t, imageModel) + require.Equal(t, tt.wantTier, imageSize) + }) + } +} + +func TestResolveOpenAIResponsesImageBillingConfigDoesNotRejectUnknownSizes(t *testing.T) { + imageModel, imageSize, err := resolveOpenAIResponsesImageBillingConfigFromBody( + []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-1.5","size":"2048x1152"}]}`), + "requested-model", + ) + require.NoError(t, err) + require.Equal(t, "gpt-image-1.5", imageModel) + require.Equal(t, "2K", imageSize) +} + +func TestOpenAIImageOutputCounterDeduplicatesFinalImages(t *testing.T) { + counter := newOpenAIImageOutputCounter() + counter.AddSSEData([]byte(`{"type":"response.image_generation_call.partial_image","partial_image_b64":"abc"}`)) + counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_1","type":"image_generation_call","result":"final-a"}}`)) + counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_1","type":"image_generation_call","result":"final-a"},{"id":"ig_2","type":"image_generation_call","result":"final-b"}]}}`)) + require.Equal(t, 2, counter.Count()) +} + +func TestOpenAIImageOutputCounterCountsImagesAPIStreamShapes(t *testing.T) { + counter := 
newOpenAIImageOutputCounter() + counter.AddSSEData([]byte(`{"type":"image_generation.completed","id":"ig_complete","b64_json":"final-a"}`)) + counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_item","type":"image_generation_call","result":"final-b"}}`)) + counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_done","type":"image_generation_call","result":"final-c"}]}}`)) + require.Equal(t, 3, counter.Count()) + + dataCounter := newOpenAIImageOutputCounter() + dataCounter.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"}]}`)) + dataCounter.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"},{"b64_json":"c"}]}`)) + require.Equal(t, 3, dataCounter.Count()) +} + +func TestOpenAIImageOutputCounterCountsMultilineSSEDataPayload(t *testing.T) { + counter := newOpenAIImageOutputCounter() + counter.AddSSEData([]byte("{\"type\":\"image_generation.completed\",\n\"b64_json\":\"final-a\"}")) + require.Equal(t, 1, counter.Count()) +} + +func TestOpenAIImageOutputCounterCountsMultilineSSEBodyPayload(t *testing.T) { + counter := newOpenAIImageOutputCounter() + counter.AddSSEBody( + "data: {\"type\":\"image_generation.completed\",\n" + + "data: \"b64_json\":\"final-a\"}\n\n" + + "data: [DONE]\n\n", + ) + require.Equal(t, 1, counter.Count()) +} + +func TestOpenAIImageOutputCounterFallsBackForInvalidMultilineSSEBody(t *testing.T) { + counter := newOpenAIImageOutputCounter() + counter.AddSSEBody( + "data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-a\"}\n" + + "data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-b\"}\n\n", + ) + require.Equal(t, 2, counter.Count()) +} diff --git a/backend/internal/service/image_output_accounting.go b/backend/internal/service/image_output_accounting.go new file mode 100644 index 00000000..219c0c59 --- /dev/null +++ b/backend/internal/service/image_output_accounting.go @@ -0,0 +1,149 @@ +package service + +import ( + "crypto/sha256" + 
"encoding/hex" + "strings" + + "github.com/tidwall/gjson" +) + +type openAIImageOutputCounter struct { + seen map[string]struct{} + count int + maxDataCount int +} + +func newOpenAIImageOutputCounter() *openAIImageOutputCounter { + return &openAIImageOutputCounter{seen: make(map[string]struct{})} +} + +func (c *openAIImageOutputCounter) Count() int { + if c == nil { + return 0 + } + if c.maxDataCount > c.count { + return c.maxDataCount + } + return c.count +} + +func (c *openAIImageOutputCounter) AddJSONResponse(body []byte) { + if c == nil || len(body) == 0 || !gjson.ValidBytes(body) { + return + } + c.addDataArray(gjson.GetBytes(body, "data")) + c.addOutputArray(gjson.GetBytes(body, "output")) + c.addOutputArray(gjson.GetBytes(body, "response.output")) +} + +func (c *openAIImageOutputCounter) AddSSEData(data []byte) { + if c == nil || len(data) == 0 || strings.TrimSpace(string(data)) == "[DONE]" || !gjson.ValidBytes(data) { + return + } + root := gjson.ParseBytes(data) + c.addDataArray(root.Get("data")) + eventType := strings.TrimSpace(root.Get("type").String()) + switch eventType { + case "response.output_item.done": + c.addImageOutputItem(root.Get("item")) + case "response.completed", "response.done": + c.addOutputArray(root.Get("response.output")) + case "image_generation.completed": + if item := root.Get("item"); item.Exists() { + c.addImageOutputItem(item) + return + } + if output := root.Get("output"); output.Exists() { + c.addImageOutputItem(output) + return + } + c.addImageOutputItem(root) + } +} + +func (c *openAIImageOutputCounter) AddSSEBody(body string) { + if c == nil || strings.TrimSpace(body) == "" { + return + } + forEachOpenAISSEDataPayload(body, c.AddSSEData) +} + +func (c *openAIImageOutputCounter) addDataArray(data gjson.Result) { + if !data.IsArray() { + return + } + count := len(data.Array()) + if count > c.maxDataCount { + c.maxDataCount = count + } +} + +func (c *openAIImageOutputCounter) addOutputArray(output gjson.Result) { + if 
!output.IsArray() { + return + } + output.ForEach(func(_, item gjson.Result) bool { + c.addImageOutputItem(item) + return true + }) +} + +func (c *openAIImageOutputCounter) addImageOutputItem(item gjson.Result) { + if !item.Exists() || !item.IsObject() { + return + } + itemType := strings.TrimSpace(item.Get("type").String()) + if itemType != "" && itemType != "image_generation_call" && itemType != "image_generation.completed" { + return + } + if strings.Contains(strings.ToLower(item.Raw), "partial_image") { + return + } + result := strings.TrimSpace(item.Get("result").String()) + if result == "" { + result = strings.TrimSpace(item.Get("b64_json").String()) + } + if result == "" { + result = strings.TrimSpace(item.Get("url").String()) + } + if result == "" && itemType != "image_generation.completed" { + return + } + key := strings.TrimSpace(item.Get("id").String()) + if key == "" { + key = strings.TrimSpace(item.Get("call_id").String()) + } + if key == "" { + key = hashOpenAIImageOutputResult(result) + } + if key == "" { + return + } + if _, exists := c.seen[key]; exists { + return + } + c.seen[key] = struct{}{} + c.count++ +} + +func hashOpenAIImageOutputResult(result string) string { + result = strings.TrimSpace(result) + if result == "" { + return "" + } + sum := sha256.Sum256([]byte(result)) + return hex.EncodeToString(sum[:]) +} + +func countOpenAIResponseImageOutputsFromJSONBytes(body []byte) int { + counter := newOpenAIImageOutputCounter() + counter.AddJSONResponse(body) + return counter.Count() +} + +func countOpenAIImageOutputsFromSSEBody(body string) int { + counter := newOpenAIImageOutputCounter() + counter.AddSSEBody(body) + return counter.Count() +} diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 7a0a6636..f29d3607 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -5,6 +5,7 @@ import ( "context" 
"fmt" "hash/fnv" + "log/slog" "math" "sort" "strconv" @@ -345,7 +346,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) return nil, nil } - if !s.isAccountRequestCompatible(account, req) { + if !s.isAccountRequestCompatible(ctx, account, req) { return nil, nil } if !s.isAccountTransportCompatible(account, req.RequiredTransport) { @@ -621,7 +622,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) continue } - if !s.isAccountRequestCompatible(account, req) { + if !s.isAccountRequestCompatible(ctx, account, req) { continue } if !s.isAccountTransportCompatible(account, req.RequiredTransport) { @@ -822,11 +823,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( for i := 0; i < len(selectionOrder); i++ { candidate := selectionOrder[i] fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } if req.RequireCompact && openAICompactSupportTier(fresh) == 0 { @@ -853,11 +854,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( // WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。 for _, candidate := range selectionOrder { fresh := 
s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) { + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { continue } if req.RequireCompact && openAICompactSupportTier(fresh) == 0 { @@ -888,13 +889,18 @@ func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Ac return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport) } -func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Account, req OpenAIAccountScheduleRequest) bool { +func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.Context, account *Account, req OpenAIAccountScheduleRequest) bool { if account == nil { return false } if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { return false } + if req.GroupID != nil && s != nil && s.service != nil && + s.service.needsUpstreamChannelRestrictionCheck(ctx, req.GroupID) && + s.service.isUpstreamModelRestrictedByChannel(ctx, *req.GroupID, account, req.RequestedModel, req.RequireCompact) { + return false + } return account.SupportsOpenAIImageCapability(req.RequiredImageCapability) } @@ -1106,6 +1112,13 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler( } } + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + slog.Warn("channel pricing restriction blocked request", + "group_id", derefGroupID(groupID), + "model", 
requestedModel) + return nil, decision, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + var stickyAccountID int64 if sessionHash != "" && s.cache != nil { if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 { diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index de98b50d..f96bf81f 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -485,12 +485,14 @@ func normalizeKnownCodexModel(model string) (string, bool) { return model, true } - modelID := model - if strings.Contains(modelID, "/") { - parts := strings.Split(modelID, "/") - modelID = parts[len(parts)-1] - } + modelID := lastOpenAIModelSegment(model) + if normalized := canonicalizeOpenAIModelAliasSpelling(modelID); normalized != "" { + modelID = normalized + } + if mapped := normalizeKnownOpenAICodexModel(modelID); mapped != "" { + return mapped, true + } key := codexModelLookupKey(modelID) if key == "" { return "", false diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 87bb7162..3add4779 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -804,15 +804,25 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { func TestNormalizeCodexModel_Gpt53(t *testing.T) { cases := map[string]string{ "gpt-5.4": "gpt-5.4", + "gpt5.5": "gpt-5.5", + "openai/gpt5.5": "gpt-5.5", + "gpt5.4": "gpt-5.4", "gpt-5.4-high": "gpt-5.4", "gpt-5.4-chat-latest": "gpt-5.4", "gpt 5.4": "gpt-5.4", "gpt-5.4-mini": "gpt-5.4-mini", + "gpt5.4-mini": "gpt-5.4-mini", + "gpt5.4mini": "gpt-5.4-mini", "gpt 5.4 mini": "gpt-5.4-mini", "gpt-5.3": "gpt-5.3-codex", + "gpt5.3": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex", + 
"gpt5.3-codex": "gpt-5.3-codex", + "gpt5.3codex": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt5.3codexspark": "gpt-5.3-codex-spark", "gpt 5.3 codex spark": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 4722c82d..3791c5a8 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -52,6 +52,12 @@ func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *Usage return &UsageBillingApplyResult{Applied: true}, nil } +func TestOpenAIGatewayServiceRecordUsage_RejectsNilInput(t *testing.T) { + svc := &OpenAIGatewayService{} + require.Error(t, svc.RecordUsage(context.Background(), nil)) + require.Error(t, svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{})) +} + type openAIRecordUsageUserRepoStub struct { UserRepository @@ -1081,6 +1087,101 @@ func TestOpenAIGatewayServiceRecordUsage_ChannelMappedOverridesBillingModelWhenM require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero") } +func TestOpenAIGatewayServiceRecordUsage_BillsCompactOpenAIModelAlias(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} + + expectedCost, err := svc.billingService.CalculateCost("gpt-5.5", UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + 
RequestID: "resp_compact_openai_alias", + Model: "gpt5.5", + UpstreamModel: "gpt-5.4", + Usage: usage, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, "gpt5.5", usageRepo.lastLog.Model) + require.NotNil(t, usageRepo.lastLog.UpstreamModel) + require.Equal(t, "gpt-5.4", *usageRepo.lastLog.UpstreamModel) + require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12) + require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero") + require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12) +} + +func TestOpenAIGatewayServiceRecordUsage_FallsBackToUpstreamModelWhenPrimaryUnpriceable(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} + + expectedCost, err := svc.billingService.CalculateCost("gpt-5.4", UsageTokens{ + InputTokens: 20, + OutputTokens: 10, + }, 1.1) + require.NoError(t, err) + + err = svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_unpriceable_primary_upstream_fallback", + Model: "not-priceable-alias", + BillingModel: "not-priceable-alias", + UpstreamModel: "gpt-5.4", + Usage: usage, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.InDelta(t, expectedCost.ActualCost, usageRepo.lastLog.ActualCost, 1e-12) + require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero") + require.InDelta(t, expectedCost.ActualCost, userRepo.lastAmount, 1e-12) +} + +func 
TestOpenAIGatewayServiceRecordUsage_ReturnsErrorWhenTokenModelCannotBePriced(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_unpriceable_without_upstream", + Model: "not-priceable-alias", + Usage: OpenAIUsage{InputTokens: 20, OutputTokens: 10}, + Duration: time.Second, + }, + APIKey: &APIKey{ID: 10}, + User: &User{ID: 20}, + Account: &Account{ID: 30}, + }) + + require.Error(t, err) + require.Contains(t, err.Error(), "calculate OpenAI usage cost failed") + require.Equal(t, 0, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} userRepo := &openAIRecordUsageUserRepoStub{} @@ -1209,3 +1310,278 @@ func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTo require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12) require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12) } + +func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierPreservesExistingBehavior(t *testing.T) { + imagePrice := 0.2 + groupID := int64(121) + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_shared_multiplier", + Model: "gpt-image-2", + ImageCount: 1, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10121, + GroupID: i64p(groupID), + 
Group: &Group{ + ID: groupID, + RateMultiplier: 0.15, + ImageRateIndependent: false, + ImageRateMultiplier: 1, + ImagePrice1K: &imagePrice, + }, + }, + User: &User{ID: 20121}, + Account: &Account{ID: 30121}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.InDelta(t, 0.2, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.03, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 0.15, usageRepo.lastLog.RateMultiplier, 1e-12) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) +} + +func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierUsesUserGroupOverride(t *testing.T) { + imagePrice := 0.5 + userRate := 0.2 + groupID := int64(125) + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newOpenAIRecordUsageServiceForTest( + usageRepo, + &openAIRecordUsageUserRepoStub{}, + &openAIRecordUsageSubRepoStub{}, + &openAIUserGroupRateRepoStub{rate: &userRate}, + ) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_user_group_override", + Model: "gpt-image-2", + ImageCount: 1, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10125, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 0.15, + ImageRateIndependent: false, + ImageRateMultiplier: 1, + ImagePrice1K: &imagePrice, + }, + }, + User: &User{ID: 20125}, + Account: &Account{ID: 30125}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.InDelta(t, 0.5, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.1, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 0.2, usageRepo.lastLog.RateMultiplier, 1e-12) +} + +func TestOpenAIGatewayServiceRecordUsage_ImageIndependentMultiplierUsesImageRate(t *testing.T) { + imagePrice := 0.2 + groupID := int64(122) + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} 
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_independent_multiplier", + Model: "gpt-image-2", + ImageCount: 1, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10122, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 0.15, + ImageRateIndependent: true, + ImageRateMultiplier: 1, + ImagePrice1K: &imagePrice, + }, + }, + User: &User{ID: 20122}, + Account: &Account{ID: 30122}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.InDelta(t, 0.2, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.2, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 1.0, usageRepo.lastLog.RateMultiplier, 1e-12) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) +} + +func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndSharedMultiplier(t *testing.T) { + groupID := int64(123) + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + svc.resolver = newOpenAIImageChannelPricingResolverForTest(t, groupID, "gpt-image-2", 0.25) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_channel_shared", + Model: "gpt-image-2", + ImageCount: 3, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10123, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 0.15, + ImageRateIndependent: false, + ImageRateMultiplier: 1, + }, + }, + User: &User{ID: 20123}, + Account: &Account{ID: 30123}, + }) + + require.NoError(t, err) + require.NotNil(t, 
usageRepo.lastLog) + require.InDelta(t, 0.75, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.1125, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 0.15, usageRepo.lastLog.RateMultiplier, 1e-12) + require.Equal(t, 3, usageRepo.lastLog.ImageCount) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) +} + +func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndIndependentMultiplier(t *testing.T) { + groupID := int64(124) + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + svc.resolver = newOpenAIImageChannelPricingResolverForTest(t, groupID, "gpt-image-2", 0.25) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_channel_independent", + Model: "gpt-image-2", + ImageCount: 3, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 10124, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 0.15, + ImageRateIndependent: true, + ImageRateMultiplier: 1, + }, + }, + User: &User{ID: 20124}, + Account: &Account{ID: 30124}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.InDelta(t, 0.75, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.75, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 1.0, usageRepo.lastLog.RateMultiplier, 1e-12) + require.Equal(t, 3, usageRepo.lastLog.ImageCount) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) +} + +func newOpenAIImageChannelPricingResolverForTest(t *testing.T, groupID int64, model string, price float64) *ModelPricingResolver { + t.Helper() + cache := newEmptyChannelCache() + cache.pricingByGroupModel[channelModelKey{groupID: groupID, 
model: model}] = &ChannelModelPricing{ + BillingMode: BillingModeImage, + PerRequestPrice: &price, + } + cache.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive} + cache.groupPlatform[groupID] = "" + cache.loadedAt = time.Now() + cs := &ChannelService{} + cs.cache.Store(cache) + return NewModelPricingResolver(cs, NewBillingService(&config.Config{}, nil)) +} + +func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesImageCount(t *testing.T) { + groupID := int64(126) + billingService := NewBillingService(&config.Config{}, nil) + svc := &GatewayService{ + billingService: billingService, + resolver: newOpenAIImageChannelPricingResolverForTest(t, groupID, "gemini-image", 0.25), + } + + cost := svc.calculateRecordUsageCost( + context.Background(), + &ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "1K"}, + &APIKey{GroupID: i64p(groupID), Group: &Group{ID: groupID}}, + "gemini-image", + 0.15, + 1.0, + nil, + ) + + require.NotNil(t, cost) + require.Equal(t, string(BillingModeImage), cost.BillingMode) + require.InDelta(t, 0.5, cost.TotalCost, 1e-12) + require.InDelta(t, 0.5, cost.ActualCost, 1e-12) +} + +func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesSizeTier(t *testing.T) { + groupID := int64(127) + defaultPrice := 0.10 + price4K := 0.40 + cache := newEmptyChannelCache() + cache.pricingByGroupModel[channelModelKey{groupID: groupID, model: "gemini-image"}] = &ChannelModelPricing{ + BillingMode: BillingModeImage, + PerRequestPrice: &defaultPrice, + Intervals: []PricingInterval{{ + TierLabel: "4K", + PerRequestPrice: &price4K, + }}, + } + cache.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive} + cache.loadedAt = time.Now() + channelService := &ChannelService{} + channelService.cache.Store(cache) + + svc := &GatewayService{ + billingService: NewBillingService(&config.Config{}, nil), + resolver: NewModelPricingResolver(channelService, NewBillingService(&config.Config{}, nil)), + } + 
+ cost := svc.calculateRecordUsageCost( + context.Background(), + &ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "4K"}, + &APIKey{GroupID: i64p(groupID), Group: &Group{ID: groupID}}, + "gemini-image", + 1.0, + 1.0, + nil, + ) + + require.NotNil(t, cost) + require.Equal(t, string(BillingModeImage), cost.BillingMode) + require.InDelta(t, 0.80, cost.TotalCost, 1e-12) + require.InDelta(t, 0.80, cost.ActualCost, 1e-12) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index b818fa4a..edd821ce 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -2049,6 +2049,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco promptCacheKey = strings.TrimSpace(v) } } + apiKey := getAPIKeyFromContext(c) + imageGenerationAllowed := GroupAllowsImageGeneration(nil) + if apiKey != nil { + imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group) + } + if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed { + setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "permission_error", + "message": ImageGenerationPermissionMessage(), + }, + }) + return nil, errors.New("image generation disabled for group") + } // Track if body needs re-serialization bodyModified := false @@ -2108,7 +2123,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco markPatchSet("instructions", "You are a helpful coding assistant.") } - if isCodexCLI && ensureOpenAIResponsesImageGenerationTool(reqBody) { + if isCodexCLI && imageGenerationAllowed && ensureOpenAIResponsesImageGenerationTool(reqBody) { bodyModified = true disablePatch() logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client") @@ 
-2119,7 +2134,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco disablePatch() logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload") } - if isCodexCLI && applyCodexImageGenerationBridgeInstructions(reqBody) { + if isCodexCLI && imageGenerationAllowed && applyCodexImageGenerationBridgeInstructions(reqBody) { bodyModified = true disablePatch() logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions") @@ -2134,7 +2149,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco markPatchSet("model", billingModel) } upstreamModel := billingModel - if normalizeOpenAIResponsesImageOnlyModel(reqBody) { + if imageGenerationAllowed && normalizeOpenAIResponsesImageOnlyModel(reqBody) { bodyModified = true disablePatch() if model, ok := reqBody["model"].(string); ok { @@ -2355,6 +2370,34 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } + if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed { + setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "permission_error", + "message": ImageGenerationPermissionMessage(), + }, + }) + return nil, errors.New("image generation disabled for group") + } + imageBillingModel := "" + imageSizeTier := "" + if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) { + var imageCfgErr error + imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfig(reqBody, billingModel) + if imageCfgErr != nil { + setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "") + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": imageCfgErr.Error(), + "param": "size", + }, + }) + return nil, imageCfgErr + } + } + 
// Re-serialize body only if modified if bodyModified { serializedByPatch := false @@ -2592,6 +2635,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco wsAttempts, ) wsResult.UpstreamModel = upstreamModel + if wsResult.ImageCount > 0 { + wsResult.ImageSize = imageSizeTier + wsResult.BillingModel = imageBillingModel + } return wsResult, nil } s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) @@ -2695,6 +2742,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco // Handle normal response var usage *OpenAIUsage var firstTokenMs *int + imageCount := 0 if reqStream { streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel) if err != nil { @@ -2702,11 +2750,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } usage = streamResult.usage firstTokenMs = streamResult.firstTokenMs + imageCount = streamResult.imageCount } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel) + nonStreamResult, err := s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel) if err != nil { return nil, err } + usage = nonStreamResult.usage + imageCount = nonStreamResult.imageCount } // Extract and save Codex usage snapshot from response headers (for OAuth accounts) @@ -2723,7 +2774,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) serviceTier := extractOpenAIServiceTier(reqBody) - return &OpenAIForwardResult{ + forwardResult := &OpenAIForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, @@ -2734,7 +2785,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco OpenAIWSMode: false, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, - }, nil + } + if imageCount > 0 { + 
forwardResult.ImageCount = imageCount + forwardResult.ImageSize = imageSizeTier + forwardResult.BillingModel = imageBillingModel + } + return forwardResult, nil } } @@ -2823,6 +2880,35 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( } body = updatedBody + apiKey := getAPIKeyFromContext(c) + if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) && !GroupAllowsImageGeneration(apiKeyGroup(apiKey)) { + setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "permission_error", + "message": ImageGenerationPermissionMessage(), + }, + }) + return nil, errors.New("image generation disabled for group") + } + imageBillingModel := "" + imageSizeTier := "" + if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) { + var imageCfgErr error + imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(body, reqModel) + if imageCfgErr != nil { + setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "") + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": imageCfgErr.Error(), + "param": "size", + }, + }) + return nil, imageCfgErr + } + } + logger.LegacyPrintf("service.openai_gateway", "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", account.ID, @@ -2905,6 +2991,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( var usage *OpenAIUsage var firstTokenMs *int + imageCount := 0 if reqStream { result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel) if err != nil { @@ -2912,11 +2999,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( } usage = result.usage firstTokenMs = result.firstTokenMs + imageCount = result.imageCount } else { - usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel) + result, 
err := s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel) if err != nil { return nil, err } + usage = result.usage + imageCount = result.imageCount } if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { @@ -2927,7 +3017,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( usage = &OpenAIUsage{} } - return &OpenAIForwardResult{ + forwardResult := &OpenAIForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: reqModel, @@ -2938,7 +3028,13 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( OpenAIWSMode: false, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, - }, nil + } + if imageCount > 0 { + forwardResult.ImageCount = imageCount + forwardResult.ImageSize = imageSizeTier + forwardResult.BillingModel = imageBillingModel + } + return forwardResult, nil } func logOpenAIPassthroughInstructionsRejected( @@ -3233,6 +3329,13 @@ func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string { type openaiStreamingResultPassthrough struct { usage *OpenAIUsage firstTokenMs *int + imageCount int +} + +type openaiNonStreamingResultPassthrough struct { + *OpenAIUsage + usage *OpenAIUsage + imageCount int } func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool { @@ -3369,6 +3472,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( } usage := &OpenAIUsage{} + imageCounter := newOpenAIImageOutputCounter() var firstTokenMs *int clientDisconnected := false sawDone := false @@ -3400,6 +3504,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( defer putSSEScannerBuf64K(scanBuf) needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel) + resultWithUsage := func() *openaiStreamingResultPassthrough { + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs, imageCount: 
imageCounter.Count()} + } for scanner.Scan() { line := scanner.Text() @@ -3419,7 +3526,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( if eventType == "response.failed" { failedMessage = extractOpenAISSEErrorMessage(dataBytes) if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, + return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage) } forceFlushFailedEvent = true @@ -3431,6 +3538,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( if openAIStreamEventIsTerminal(trimmedData) { sawTerminalEvent = true } + imageCounter.AddSSEData(dataBytes) lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType) if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) @@ -3460,28 +3568,28 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( } if err := scanner.Err(); err != nil { if sawTerminalEvent && !sawFailedEvent { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + return resultWithUsage(), nil } if sawFailedEvent { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage) + return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage) } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err) + return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err) } if errors.Is(err, bufio.ErrTooLong) { logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] 
SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err + return resultWithUsage(), err } if !openAIStreamClientOutputStarted(c, clientOutputStarted) { msg := "OpenAI stream disconnected before completion" if errText := strings.TrimSpace(err.Error()); errText != "" { msg += ": " + errText } - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, + return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg) } if clientDisconnected { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err) + return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", err) } logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v", @@ -3489,10 +3597,10 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( upstreamRequestID, err, ) - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) + return resultWithUsage(), fmt.Errorf("stream read error: %w", err) } if sawFailedEvent { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage) + return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage) } if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil { logger.FromContext(ctx).With( @@ -3501,13 +3609,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( zap.String("upstream_request_id", upstreamRequestID), ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") if !openAIStreamClientOutputStarted(c, clientOutputStarted) { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: 
firstTokenMs}, + return resultWithUsage(), s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event") } - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event") + return resultWithUsage(), errors.New("stream usage incomplete: missing terminal event") } - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + return resultWithUsage(), nil } func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( @@ -3516,7 +3624,7 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( c *gin.Context, originalModel string, mappedModel string, -) (*OpenAIUsage, error) { +) (*openaiNonStreamingResultPassthrough, error) { body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { return nil, err @@ -3553,14 +3661,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } c.Data(resp.StatusCode, contentType, body) - return usage, nil + return &openaiNonStreamingResultPassthrough{ + OpenAIUsage: usage, + usage: usage, + imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body), + }, nil } // handlePassthroughSSEToJSON converts an SSE response body into a JSON // response for the passthrough path. It mirrors handleSSEToJSON while // preserving passthrough payloads, except compact-only model remapping may // rewrite model fields back to the original requested model. 
-func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*OpenAIUsage, error) { +func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*openaiNonStreamingResultPassthrough, error) { bodyText := string(body) finalResponse, ok := extractCodexFinalResponse(bodyText) @@ -3611,7 +3723,11 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c } c.Data(resp.StatusCode, contentType, body) - return usage, nil + return &openaiNonStreamingResultPassthrough{ + OpenAIUsage: usage, + usage: usage, + imageCount: countOpenAIImageOutputsFromSSEBody(bodyText), + }, nil } func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { @@ -4025,6 +4141,13 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse( type openaiStreamingResult struct { usage *OpenAIUsage firstTokenMs *int + imageCount int +} + +type openaiNonStreamingResult struct { + *OpenAIUsage + usage *OpenAIUsage + imageCount int } func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { @@ -4058,6 +4181,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } usage := &OpenAIUsage{} + imageCounter := newOpenAIImageOutputCounter() var firstTokenMs *int scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -4136,7 +4260,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp needModelReplace := originalModel != mappedModel resultWithUsage := func() *openaiStreamingResult { - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs} + return &openaiStreamingResult{usage: usage, 
firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()} } finalizeStream := func() (*openaiStreamingResult, error) { if !sawTerminalEvent { @@ -4231,6 +4355,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp forceFlushFailedEvent = true sawFailedEvent = true } + imageCounter.AddSSEData(dataBytes) // Correct Codex tool calls if needed (apply_patch -> edit, etc.) if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { @@ -4496,7 +4621,7 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { }, true } -func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { +func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*openaiNonStreamingResult, error) { body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { return nil, err @@ -4542,7 +4667,11 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r c.Data(resp.StatusCode, contentType, body) - return usage, nil + return &openaiNonStreamingResult{ + OpenAIUsage: usage, + usage: usage, + imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body), + }, nil } func isEventStreamResponse(header http.Header) bool { @@ -4550,7 +4679,7 @@ func isEventStreamResponse(header http.Header) bool { return strings.Contains(contentType, "text/event-stream") } -func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) { +func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*openaiNonStreamingResult, error) { bodyText := string(body) finalResponse, ok 
:= extractCodexFinalResponse(bodyText) @@ -4602,21 +4731,29 @@ func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Conte } c.Data(resp.StatusCode, contentType, body) - return usage, nil + return &openaiNonStreamingResult{ + OpenAIUsage: usage, + usage: usage, + imageCount: countOpenAIImageOutputsFromSSEBody(bodyText), + }, nil } func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) { - lines := strings.Split(body, "\n") - for _, line := range lines { - data, ok := extractOpenAISSEDataLine(line) - if !ok || data == "" || data == "[DONE]" { - continue + var terminalType string + var terminalPayload []byte + forEachOpenAISSEDataPayload(body, func(data []byte) { + if terminalPayload != nil { + return } - eventType := strings.TrimSpace(gjson.Get(data, "type").String()) + eventType := strings.TrimSpace(gjson.GetBytes(data, "type").String()) switch eventType { case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": - return eventType, []byte(data), true + terminalType = eventType + terminalPayload = append([]byte(nil), data...) 
} + }) + if terminalPayload != nil { + return terminalType, terminalPayload, true } return "", nil, false } @@ -4651,21 +4788,20 @@ func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.R } func extractCodexFinalResponse(body string) ([]byte, bool) { - lines := strings.Split(body, "\n") - for _, line := range lines { - data, ok := extractOpenAISSEDataLine(line) - if !ok { - continue + var finalResponse []byte + forEachOpenAISSEDataPayload(body, func(data []byte) { + if finalResponse != nil { + return } - if data == "" || data == "[DONE]" { - continue - } - eventType := gjson.Get(data, "type").String() + eventType := gjson.GetBytes(data, "type").String() if eventType == "response.done" || eventType == "response.completed" { - if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" { - return []byte(response.Raw), true + if response := gjson.GetBytes(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" { + finalResponse = []byte(response.Raw) } } + }) + if finalResponse != nil { + return finalResponse, true } return nil, false } @@ -4677,21 +4813,15 @@ func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) { acc := apicompat.NewBufferedResponseAccumulator() imageOutputs := make([]json.RawMessage, 0, 1) seenImages := make(map[string]struct{}) - lines := strings.Split(bodyText, "\n") - for _, line := range lines { - data, ok := extractOpenAISSEDataLine(line) - if !ok || data == "" || data == "[DONE]" { - continue - } - if imageOutput, ok := extractImageGenerationOutputFromSSEData([]byte(data), seenImages); ok { + forEachOpenAISSEDataPayload(bodyText, func(data []byte) { + if imageOutput, ok := extractImageGenerationOutputFromSSEData(data, seenImages); ok { imageOutputs = append(imageOutputs, imageOutput) } var event apicompat.ResponsesStreamEvent - if err := json.Unmarshal([]byte(data), &event); err != nil { - continue + if err := 
json.Unmarshal(data, &event); err == nil { + acc.ProcessEvent(&event) } - acc.ProcessEvent(&event) - } + }) if !acc.HasContent() && len(imageOutputs) == 0 { return nil, false } @@ -4744,17 +4874,9 @@ func extractImageGenerationOutputFromSSEData(data []byte, seen map[string]struct func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { usage := &OpenAIUsage{} - lines := strings.Split(body, "\n") - for _, line := range lines { - data, ok := extractOpenAISSEDataLine(line) - if !ok { - continue - } - if data == "" || data == "[DONE]" { - continue - } - s.parseSSEUsageBytes([]byte(data), usage) - } + forEachOpenAISSEDataPayload(body, func(data []byte) { + s.parseSSEUsageBytes(data, usage) + }) return usage } @@ -5036,8 +5158,14 @@ type OpenAIRecordUsageInput struct { // RecordUsage records usage and deducts balance func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { + if input == nil { + return errors.New("openai usage input is nil") + } result := input.Result - if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI { + if result == nil { + return errors.New("openai usage result is nil") + } + if s.rateLimitService != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI { s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID) } @@ -5074,6 +5202,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec } multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier) } + imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier) var cost *CostBreakdown var err error @@ -5087,13 +5216,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { billingModel = input.OriginalModel } + billingModels := usageBillingModelCandidates( + 
billingModel, + result.BillingModel, + input.ChannelMappedModel, + input.OriginalModel, + result.UpstreamModel, + result.Model, + ) serviceTier := "" if result.ServiceTier != nil { serviceTier = strings.TrimSpace(*result.ServiceTier) } - cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, tokens, serviceTier) + cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModels, multiplier, imageMultiplier, tokens, serviceTier) if err != nil { - cost = &CostBreakdown{ActualCost: 0} + return err } // Determine billing type @@ -5143,7 +5280,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec usageLog.TotalCost = cost.TotalCost usageLog.ActualCost = cost.ActualCost } - usageLog.RateMultiplier = multiplier + if result.ImageCount > 0 { + usageLog.RateMultiplier = imageMultiplier + } else { + usageLog.RateMultiplier = multiplier + } usageLog.AccountRateMultiplier = &accountRateMultiplier usageLog.BillingType = billingType usageLog.Stream = result.Stream @@ -5224,14 +5365,45 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost( ctx context.Context, result *OpenAIForwardResult, apiKey *APIKey, + billingModels []string, + multiplier float64, + imageMultiplier float64, + tokens UsageTokens, + serviceTier string, +) (*CostBreakdown, error) { + billingModel := firstUsageBillingModel(billingModels) + if result != nil && result.ImageCount > 0 { + return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, imageMultiplier), nil + } + if len(billingModels) == 0 || billingModel == "" { + return nil, errors.New("openai usage billing model is empty") + } + var lastErr error + for _, candidate := range billingModels { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + continue + } + cost, err := s.calculateOpenAIRecordUsageTokenCost(ctx, apiKey, candidate, multiplier, tokens, serviceTier) + if err == nil { + return cost, nil + } + lastErr = err + } + if lastErr 
== nil { + lastErr = errors.New("no non-empty billing model candidates") + } + return nil, fmt.Errorf("calculate OpenAI usage cost failed for billing models %s: %w", strings.Join(billingModels, ","), lastErr) +} + +func (s *OpenAIGatewayService) calculateOpenAIRecordUsageTokenCost( + ctx context.Context, + apiKey *APIKey, billingModel string, multiplier float64, tokens UsageTokens, serviceTier string, ) (*CostBreakdown, error) { - if result != nil && result.ImageCount > 0 { - return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil - } if s.resolver != nil && apiKey.Group != nil { gid := apiKey.Group.ID return s.billingService.CalculateCostUnified(CostInput{ @@ -5262,7 +5434,7 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost( Ctx: ctx, Model: billingModel, GroupID: &gid, - RequestCount: 1, + RequestCount: result.ImageCount, SizeTier: result.ImageSize, RateMultiplier: multiplier, Resolver: s.resolver, diff --git a/backend/internal/service/openai_image_generation_controls_test.go b/backend/internal/service/openai_image_generation_controls_test.go new file mode 100644 index 00000000..76dc8053 --- /dev/null +++ b/backend/internal/service/openai_image_generation_controls_test.go @@ -0,0 +1,215 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayServiceForward_RejectsDisabledImageGenerationIntents(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + body []byte + }{ + { + name: "image model", + body: []byte(`{"model":"gpt-image-2","input":"draw"}`), + }, + { + name: "image tool", + body: []byte(`{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`), + }, + { + name: "image tool choice", + body: 
[]byte(`{"model":"gpt-5.4","input":"draw","tool_choice":{"type":"image_generation"}}`), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + upstream := &httpUpstreamRecorder{} + svc := newOpenAIImageGenerationControlTestService(upstream) + c, recorder := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0") + account := newOpenAIImageGenerationControlTestAccount() + + result, err := svc.Forward(context.Background(), c, account, tt.body) + + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, http.StatusForbidden, recorder.Code) + require.Equal(t, "permission_error", gjson.GetBytes(recorder.Body.Bytes(), "error.type").String()) + require.Nil(t, upstream.lastReq, "disabled image request must not reach upstream") + }) + } +} + +func TestOpenAIGatewayServiceForward_DisabledGroupAllowsTextOnlyResponses(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_text","model":"gpt-5.4","usage":{"input_tokens":3,"output_tokens":2}}`)), + }, + } + svc := newOpenAIImageGenerationControlTestService(upstream) + c, recorder := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0") + account := newOpenAIImageGenerationControlTestAccount() + + result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`)) + + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, http.StatusOK, recorder.Code) + require.Equal(t, 3, result.Usage.InputTokens) + require.Equal(t, 2, result.Usage.OutputTokens) + require.Equal(t, 0, result.ImageCount) + require.NotNil(t, upstream.lastReq) +} + +func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name 
string + allowImages bool + wantInjected bool + }{ + {name: "disabled group skips injection", allowImages: false, wantInjected: false}, + {name: "enabled group injects image tool", allowImages: true, wantInjected: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_codex","model":"gpt-5.4","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + svc := newOpenAIImageGenerationControlTestService(upstream) + c, _ := newOpenAIImageGenerationControlTestContext(tt.allowImages, "codex_cli_rs/0.98.0") + account := newOpenAIImageGenerationControlTestAccount() + + result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`)) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, upstream.lastReq) + hasImageTool := gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists() + require.Equal(t, tt.wantInjected, hasImageTool) + instructions := gjson.GetBytes(upstream.lastBody, "instructions").String() + require.Equal(t, tt.wantInjected, strings.Contains(instructions, "image_generation")) + }) + } +} + +func TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{}) + c, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{ + "id":"resp_image_json", + "model":"gpt-5.4", + "output":[{"id":"ig_json_1","type":"image_generation_call","result":"final-image"}], + 
"usage":{"input_tokens":7,"output_tokens":3,"output_tokens_details":{"image_tokens":2}} + }`)), + } + + result, err := svc.handleNonStreamingResponse(context.Background(), resp, c, &Account{ID: 1, Type: AccountTypeAPIKey}, "gpt-5.4", "gpt-5.4") + + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.imageCount) + require.NotNil(t, result.usage) + require.Equal(t, 7, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + require.Equal(t, 2, result.usage.ImageOutputTokens) +} + +func TestOpenAIGatewayServiceHandleResponsesImageOutputs_Streaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{}) + c, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_image_stream\",\"model\":\"gpt-5.5\",\"output\":[{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}],\"usage\":{\"input_tokens\":11,\"output_tokens\":5,\"output_tokens_details\":{\"image_tokens\":4}}}}\n\n", + )), + } + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "gpt-5.5", "gpt-5.5") + + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.imageCount) + require.NotNil(t, result.usage) + require.Equal(t, 11, result.usage.InputTokens) + require.Equal(t, 5, result.usage.OutputTokens) + require.Equal(t, 4, result.usage.ImageOutputTokens) +} + +func newOpenAIImageGenerationControlTestService(upstream *httpUpstreamRecorder) *OpenAIGatewayService { + cfg := &config.Config{} + 
return &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } +} + +func newOpenAIImageGenerationControlTestContext(allowImages bool, userAgent string) (*gin.Context, *httptest.ResponseRecorder) { + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", userAgent) + groupID := int64(4242) + c.Set("api_key", &APIKey{ + ID: 2424, + GroupID: &groupID, + Group: &Group{ + ID: groupID, + AllowImageGeneration: allowImages, + RateMultiplier: 1, + ImageRateMultiplier: 1, + }, + }) + return c, recorder +} + +func newOpenAIImageGenerationControlTestAccount() *Account { + return &Account{ + ID: 5151, + Name: "openai-image-controls", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + } +} diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index 3da76525..04be5164 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -16,6 +16,7 @@ import ( "net/textproto" "strconv" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -468,14 +469,54 @@ func isOpenAINativeImageOption(name string) bool { } func normalizeOpenAIImageSizeTier(size string) string { - switch strings.ToLower(strings.TrimSpace(size)) { + trimmed := strings.TrimSpace(size) + normalized := strings.ToLower(trimmed) + switch normalized { + case "", "auto": + return "2K" case "1024x1024": return "1K" - case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "", "auto": + case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "2048x2048", "2048x1152", "1152x2048": return "2K" - default: + case 
"3840x2160", "2160x3840": + return "4K" + } + width, height, ok := parseOpenAIImageSizeDimensions(trimmed) + if !ok { return "2K" } + return classifyUnknownOpenAIImageSizeTier(width, height) +} + +const ( + openAIImage2KMaxPixels = 2560 * 1440 +) + +func parseOpenAIImageSizeDimensions(size string) (int, int, bool) { + trimmed := strings.TrimSpace(size) + parts := strings.Split(strings.ToLower(trimmed), "x") + if len(parts) != 2 { + return 0, 0, false + } + width, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err != nil { + return 0, 0, false + } + height, err := strconv.Atoi(strings.TrimSpace(parts[1])) + if err != nil { + return 0, 0, false + } + if width <= 0 || height <= 0 { + return 0, 0, false + } + return width, height, true +} + +func classifyUnknownOpenAIImageSizeTier(width int, height int) string { + if height > 0 && width > openAIImage2KMaxPixels/height { + return "4K" + } + return "2K" } func (s *OpenAIGatewayService) ForwardImages( @@ -535,11 +576,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( setOpsUpstreamRequestBody(c, forwardBody) } - token, _, err := s.GetAccessToken(ctx, account) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream) + defer releaseUpstreamCtx() + + token, _, err := s.GetAccessToken(upstreamCtx, account) if err != nil { return nil, err } - upstreamReq, err := s.buildOpenAIImagesRequest(ctx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint) + upstreamReq, err := s.buildOpenAIImagesRequest(upstreamCtx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint) if err != nil { return nil, err } @@ -582,14 +626,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( Kind: "failover", Message: upstreamMsg, }) - s.handleFailoverSideEffects(ctx, resp, account) + s.handleFailoverSideEffects(upstreamCtx, resp, account) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: 
account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), } } - return s.handleErrorResponse(ctx, resp, c, account, forwardBody) + return s.handleErrorResponse(upstreamCtx, resp, c, account, forwardBody) } defer func() { _ = resp.Body.Close() }() @@ -599,6 +643,20 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( if parsed.Stream && isEventStreamResponse(resp.Header) { streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime) if err != nil { + if streamCount > 0 { + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: streamUsage, + Model: requestModel, + UpstreamModel: upstreamModel, + Stream: parsed.Stream, + ResponseHeaders: resp.Header.Clone(), + Duration: time.Since(startTime), + FirstTokenMs: ttft, + ImageCount: streamCount, + ImageSize: parsed.SizeTier, + }, err + } return nil, err } usage = streamUsage @@ -807,66 +865,205 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer") } - reader := bufio.NewReader(resp.Body) usage := OpenAIUsage{} - imageCount := 0 + imageCounter := newOpenAIImageOutputCounter() var firstTokenMs *int + clientDisconnected := false + lastDownstreamWriteAt := time.Now() var fallbackBody bytes.Buffer fallbackBytes := int64(0) fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg) seenSSEData := false fallbackTooLarge := false + var sseData openAISSEDataAccumulator + + processSSEData := func(dataBytes []byte) { + seenSSEData = true + fallbackBody.Reset() + fallbackBytes = 0 + mergeOpenAIUsage(&usage, dataBytes) + imageCounter.AddSSEData(dataBytes) + } + + flushSSEEvent := func() { + sseData.Flush(processSSEData) + } + + processLine := func(line []byte) { + if len(line) == 0 { + return + } + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + if !clientDisconnected { + if _, writeErr := 
c.Writer.Write(line); writeErr != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing") + } else { + flusher.Flush() + lastDownstreamWriteAt = time.Now() + } + } + + trimmedLine := strings.TrimRight(string(line), "\r\n") + if _, ok := extractOpenAISSEDataLine(trimmedLine); ok || strings.TrimSpace(trimmedLine) == "" { + sseData.AddLine(trimmedLine, processSSEData) + return + } + if !seenSSEData && !fallbackTooLarge { + fallbackBytes += int64(len(line)) + if fallbackBytes <= fallbackLimit { + _, _ = fallbackBody.Write(line) + } else { + fallbackTooLarge = true + fallbackBody.Reset() + } + } + } + + finalizeFallbackBody := func() { + if seenSSEData || fallbackBody.Len() == 0 { + return + } + body := bytes.TrimSpace(fallbackBody.Bytes()) + if len(body) == 0 { + return + } + mergeOpenAIUsage(&usage, body) + imageCounter.AddJSONResponse(body) + } + + streamInterval := s.openAIImageStreamDataInterval() + keepaliveInterval := s.openAIImageStreamKeepaliveInterval() + if streamInterval <= 0 && keepaliveInterval <= 0 { + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadBytes('\n') + processLine(line) + if err == io.EOF { + break + } + if err != nil { + flushSSEEvent() + return usage, imageCounter.Count(), firstTokenMs, err + } + } + flushSSEEvent() + finalizeFallbackBody() + return usage, imageCounter.Count(), firstTokenMs, nil + } + + type readEvent struct { + line []byte + err error + } + events := make(chan readEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev readEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func() { + defer close(events) + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadBytes('\n') + if len(line) > 0 { + atomic.StoreInt64(&lastReadAt, 
time.Now().UnixNano()) + } + if len(line) > 0 && !sendEvent(readEvent{line: line}) { + return + } + if err == io.EOF { + return + } + if err != nil { + _ = sendEvent(readEvent{err: err}) + return + } + } + }() + defer close(done) + + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } for { - line, err := reader.ReadBytes('\n') - if len(line) > 0 { - if firstTokenMs == nil { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms + select { + case ev, ok := <-events: + if !ok { + flushSSEEvent() + finalizeFallbackBody() + return usage, imageCounter.Count(), firstTokenMs, nil } - if _, writeErr := c.Writer.Write(line); writeErr != nil { - return OpenAIUsage{}, 0, firstTokenMs, writeErr + if ev.err != nil { + flushSSEEvent() + return usage, imageCounter.Count(), firstTokenMs, ev.err + } + processLine(ev.line) + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream incomplete after timeout") + } + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream data interval timeout: interval=%s", streamInterval) + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval))) + return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream data interval timeout") + case <-keepaliveCh: + if clientDisconnected || 
time.Since(lastDownstreamWriteAt) < keepaliveInterval { + continue + } + if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected during keepalive, continue draining upstream for billing") + continue } flusher.Flush() + lastDownstreamWriteAt = time.Now() + } + } +} - if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok { - if data != "" && data != "[DONE]" { - seenSSEData = true - fallbackBody.Reset() - fallbackBytes = 0 - dataBytes := []byte(data) - mergeOpenAIUsage(&usage, dataBytes) - if count := extractOpenAIImagesBillableCountFromJSONBytes(dataBytes); count > imageCount { - imageCount = count - } - } - } else if !seenSSEData && !fallbackTooLarge { - fallbackBytes += int64(len(line)) - if fallbackBytes <= fallbackLimit { - _, _ = fallbackBody.Write(line) - } else { - fallbackTooLarge = true - fallbackBody.Reset() - } - } - } - if err == io.EOF { - break - } - if err != nil { - return OpenAIUsage{}, 0, firstTokenMs, err - } +func (s *OpenAIGatewayService) openAIImageStreamDataInterval() time.Duration { + if s == nil || s.cfg == nil || s.cfg.Gateway.ImageStreamDataIntervalTimeout <= 0 { + return 0 } - if !seenSSEData && fallbackBody.Len() > 0 { - body := bytes.TrimSpace(fallbackBody.Bytes()) - if len(body) > 0 { - mergeOpenAIUsage(&usage, body) - if count := extractOpenAIImagesBillableCountFromJSONBytes(body); count > imageCount { - imageCount = count - } - } + return time.Duration(s.cfg.Gateway.ImageStreamDataIntervalTimeout) * time.Second +} + +func (s *OpenAIGatewayService) openAIImageStreamKeepaliveInterval() time.Duration { + if s == nil || s.cfg == nil || s.cfg.Gateway.ImageStreamKeepaliveInterval <= 0 { + return 0 } - return usage, imageCount, firstTokenMs, nil + return time.Duration(s.cfg.Gateway.ImageStreamKeepaliveInterval) * time.Second } func 
extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int { @@ -913,14 +1110,7 @@ func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) { } func extractOpenAIImageCountFromJSONBytes(body []byte) int { - if len(body) == 0 || !gjson.ValidBytes(body) { - return 0 - } - data := gjson.GetBytes(body, "data") - if data.Exists() && data.IsArray() { - return len(data.Array()) - } - return 0 + return countOpenAIResponseImageOutputsFromJSONBytes(body) } type openAIImagePointerInfo struct { diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go index 64d995e1..25cd8228 100644 --- a/backend/internal/service/openai_images_responses.go +++ b/backend/internal/service/openai_images_responses.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -361,21 +362,21 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe var ( fallbackResults []openAIResponsesImageResult fallbackSeen = make(map[string]struct{}) + finalResults []openAIResponsesImageResult + finalMeta openAIResponsesImageResult + collectErr error createdAt int64 usageRaw []byte foundFinal bool responseMeta openAIResponsesImageResult ) - for _, line := range bytes.Split(body, []byte("\n")) { - line = bytes.TrimRight(line, "\r") - data, ok := extractOpenAISSEDataLine(string(line)) - if !ok || data == "" || data == "[DONE]" { - continue + forEachOpenAISSEDataPayload(string(body), func(payload []byte) { + if collectErr != nil || len(finalResults) > 0 { + return } - payload := []byte(data) if !gjson.ValidBytes(payload) { - continue + return } if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok { mergeOpenAIResponsesImageMeta(&responseMeta, meta) @@ -388,7 +389,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe case "response.output_item.done": result, itemID, ok, err := 
extractOpenAIImageFromResponsesOutputItemDone(payload) if err != nil { - return nil, 0, nil, openAIResponsesImageResult{}, false, err + collectErr = err + return } if ok { mergeOpenAIResponsesImageMeta(&result, responseMeta) @@ -397,7 +399,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe case "response.completed": results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload) if err != nil { - return nil, 0, nil, openAIResponsesImageResult{}, false, err + collectErr = err + return } foundFinal = true if completedAt > 0 { @@ -408,14 +411,24 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe } if len(results) > 0 { mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) - return results, createdAt, usageRaw, firstMeta, true, nil + finalResults = results + finalMeta = firstMeta + return } if len(fallbackResults) > 0 { firstMeta = fallbackResults[0] mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) - return fallbackResults, createdAt, usageRaw, firstMeta, true, nil + finalResults = fallbackResults + finalMeta = firstMeta + return } } + }) + if collectErr != nil { + return nil, 0, nil, openAIResponsesImageResult{}, false, collectErr + } + if len(finalResults) > 0 { + return finalResults, createdAt, usageRaw, finalMeta, true, nil } if len(fallbackResults) > 0 { @@ -505,6 +518,30 @@ func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flus return nil } +func (s *OpenAIGatewayService) tryWriteOpenAIImagesStreamEvent( + c *gin.Context, + flusher http.Flusher, + clientDisconnected *bool, + lastWriteAt *time.Time, + eventName string, + payload []byte, +) bool { + if clientDisconnected != nil && *clientDisconnected { + return false + } + if err := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); err != nil { + if clientDisconnected != nil { + *clientDisconnected = true + } + 
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing") + return false + } + if lastWriteAt != nil { + *lastWriteAt = time.Now() + } + return true +} + func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( resp *http.Response, c *gin.Context, @@ -517,15 +554,9 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( } var usage OpenAIUsage - for _, line := range bytes.Split(body, []byte("\n")) { - line = bytes.TrimRight(line, "\r") - data, ok := extractOpenAISSEDataLine(string(line)) - if !ok || data == "" || data == "[DONE]" { - continue - } - dataBytes := []byte(data) - s.parseSSEUsageBytes(dataBytes, &usage) - } + forEachOpenAISSEDataPayload(string(body), func(data []byte) { + s.parseSSEUsageBytes(data, &usage) + }) results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body) if err != nil { return OpenAIUsage{}, 0, err @@ -570,7 +601,6 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( format = "b64_json" } - reader := bufio.NewReader(resp.Body) usage := OpenAIUsage{} imageCount := 0 var firstTokenMs *int @@ -579,141 +609,307 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( pendingSeen := make(map[string]struct{}) streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)} var createdAt int64 + clientDisconnected := false + lastDownstreamWriteAt := time.Now() + var sseData openAISSEDataAccumulator + var processDataErr error + processDataDone := false + + processData := func(dataBytes []byte) { + if processDataDone || processDataErr != nil { + return + } + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsageBytes(dataBytes, &usage) + if !gjson.ValidBytes(dataBytes) { + return + } + if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok 
{ + mergeOpenAIResponsesImageMeta(&streamMeta, meta) + if eventCreatedAt > 0 { + createdAt = eventCreatedAt + } + } + switch gjson.GetBytes(dataBytes, "type").String() { + case "response.image_generation_call.partial_image": + b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String()) + if b64 == "" { + return + } + eventName := streamPrefix + ".partial_image" + partialMeta := streamMeta + mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{ + OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()), + Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()), + }) + payload := buildOpenAIImagesStreamPartialPayload( + eventName, + b64, + gjson.GetBytes(dataBytes, "partial_image_index").Int(), + format, + createdAt, + partialMeta, + ) + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload) + case "response.output_item.done": + img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes) + if extractErr != nil { + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error())) + processDataErr = extractErr + processDataDone = true + return + } + if !ok { + return + } + mergeOpenAIResponsesImageMeta(&streamMeta, img) + mergeOpenAIResponsesImageMeta(&img, streamMeta) + key := openAIResponsesImageResultKey(itemID, img) + if _, exists := emitted[key]; exists { + return + } + if _, exists := pendingSeen[key]; exists { + return + } + pendingSeen[key] = struct{}{} + pendingResults = append(pendingResults, img) + case "response.completed": + results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes) + if extractErr != nil { + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", 
buildOpenAIImagesStreamErrorBody(extractErr.Error())) + processDataErr = extractErr + processDataDone = true + return + } + mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta) + finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults)) + finalSeen := make(map[string]struct{}) + for _, img := range results { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) + } + for _, img := range pendingResults { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) + } + if len(finalResults) == 0 { + outputErr := fmt.Errorf("upstream did not return image output") + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(outputErr.Error())) + processDataErr = outputErr + processDataDone = true + return + } + eventName := streamPrefix + ".completed" + for _, img := range finalResults { + key := openAIResponsesImageResultKey("", img) + if _, exists := emitted[key]; exists { + continue + } + payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw) + emitted[key] = struct{}{} + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload) + } + imageCount = len(emitted) + processDataDone = true + } + } + + processLine := func(line []byte) (bool, error) { + if len(line) == 0 { + return false, nil + } + sseData.AddLine(string(line), processData) + if processDataErr != nil { + return true, processDataErr + } + return processDataDone, nil + } + + flushData := func() (bool, error) { + sseData.Flush(processData) + if processDataErr != nil { + return true, processDataErr + } + return processDataDone, nil + } + + finalizePending := func() error { + if imageCount > 0 { + return nil + } + if len(pendingResults) > 0 { + eventName := streamPrefix + 
".completed" + for _, img := range pendingResults { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + key := openAIResponsesImageResultKey("", img) + if _, exists := emitted[key]; exists { + continue + } + payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil) + emitted[key] = struct{}{} + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload) + } + imageCount = len(emitted) + return nil + } + + streamErr := fmt.Errorf("stream disconnected before image generation completed") + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error())) + return streamErr + } + + streamInterval := s.openAIImageStreamDataInterval() + keepaliveInterval := s.openAIImageStreamKeepaliveInterval() + if streamInterval <= 0 && keepaliveInterval <= 0 { + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadBytes('\n') + done, processErr := processLine(line) + if processErr != nil { + return usage, imageCount, firstTokenMs, processErr + } + if done { + return usage, imageCount, firstTokenMs, nil + } + if err == io.EOF { + break + } + if err != nil { + if done, processErr := flushData(); processErr != nil { + return usage, imageCount, firstTokenMs, processErr + } else if done { + return usage, imageCount, firstTokenMs, nil + } + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(err.Error())) + return usage, imageCount, firstTokenMs, err + } + } + if done, processErr := flushData(); processErr != nil { + return usage, imageCount, firstTokenMs, processErr + } else if done { + return usage, imageCount, firstTokenMs, nil + } + if err := finalizePending(); err != nil { + return usage, imageCount, firstTokenMs, err + } + return usage, imageCount, firstTokenMs, nil + } + + type readEvent struct { + line []byte + err 
error + } + events := make(chan readEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev readEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func() { + defer close(events) + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadBytes('\n') + if len(line) > 0 { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + } + if len(line) > 0 && !sendEvent(readEvent{line: line}) { + return + } + if err == io.EOF { + return + } + if err != nil { + _ = sendEvent(readEvent{err: err}) + return + } + } + }() + defer close(done) + + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + var keepaliveTicker *time.Ticker + if keepaliveInterval > 0 { + keepaliveTicker = time.NewTicker(keepaliveInterval) + defer keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } for { - line, err := reader.ReadBytes('\n') - if len(line) > 0 { - trimmedLine := strings.TrimRight(string(line), "\r\n") - data, ok := extractOpenAISSEDataLine(trimmedLine) - if ok && data != "" && data != "[DONE]" { - if firstTokenMs == nil { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms + select { + case ev, ok := <-events: + if !ok { + if done, processErr := flushData(); processErr != nil { + return usage, imageCount, firstTokenMs, processErr + } else if done { + return usage, imageCount, firstTokenMs, nil } - dataBytes := []byte(data) - s.parseSSEUsageBytes(dataBytes, &usage) - if gjson.ValidBytes(dataBytes) { - if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok { - mergeOpenAIResponsesImageMeta(&streamMeta, meta) - if 
eventCreatedAt > 0 { - createdAt = eventCreatedAt - } - } - switch gjson.GetBytes(dataBytes, "type").String() { - case "response.image_generation_call.partial_image": - b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String()) - if b64 != "" { - eventName := streamPrefix + ".partial_image" - partialMeta := streamMeta - mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{ - OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()), - Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()), - }) - payload := buildOpenAIImagesStreamPartialPayload( - eventName, - b64, - gjson.GetBytes(dataBytes, "partial_image_index").Int(), - format, - createdAt, - partialMeta, - ) - if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { - return OpenAIUsage{}, imageCount, firstTokenMs, writeErr - } - } - case "response.output_item.done": - img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes) - if extractErr != nil { - _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error())) - return OpenAIUsage{}, imageCount, firstTokenMs, extractErr - } - if !ok { - break - } - mergeOpenAIResponsesImageMeta(&streamMeta, img) - mergeOpenAIResponsesImageMeta(&img, streamMeta) - key := openAIResponsesImageResultKey(itemID, img) - if _, exists := emitted[key]; exists { - break - } - if _, exists := pendingSeen[key]; exists { - break - } - pendingSeen[key] = struct{}{} - pendingResults = append(pendingResults, img) - case "response.completed": - results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes) - if extractErr != nil { - _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error())) - return OpenAIUsage{}, imageCount, firstTokenMs, extractErr - } - 
mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta) - finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults)) - finalSeen := make(map[string]struct{}) - for _, img := range results { - mergeOpenAIResponsesImageMeta(&img, streamMeta) - appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) - } - for _, img := range pendingResults { - mergeOpenAIResponsesImageMeta(&img, streamMeta) - appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) - } - if len(finalResults) == 0 { - err = fmt.Errorf("upstream did not return image output") - _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error())) - return OpenAIUsage{}, imageCount, firstTokenMs, err - } - eventName := streamPrefix + ".completed" - for _, img := range finalResults { - key := openAIResponsesImageResultKey("", img) - if _, exists := emitted[key]; exists { - continue - } - payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw) - if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { - return OpenAIUsage{}, imageCount, firstTokenMs, writeErr - } - emitted[key] = struct{}{} - } - imageCount = len(emitted) - return usage, imageCount, firstTokenMs, nil - } + if err := finalizePending(); err != nil { + return usage, imageCount, firstTokenMs, err } + return usage, imageCount, firstTokenMs, nil } - } - if err == io.EOF { - break - } - if err != nil { - _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error())) - return OpenAIUsage{}, imageCount, firstTokenMs, err - } - } - - if imageCount > 0 { - return usage, imageCount, firstTokenMs, nil - } - if len(pendingResults) > 0 { - eventName := streamPrefix + ".completed" - for _, img := range pendingResults { - mergeOpenAIResponsesImageMeta(&img, streamMeta) - key := openAIResponsesImageResultKey("", img) - if _, exists := emitted[key]; 
exists { + if ev.err != nil { + if done, processErr := flushData(); processErr != nil { + return usage, imageCount, firstTokenMs, processErr + } else if done { + return usage, imageCount, firstTokenMs, nil + } + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(ev.err.Error())) + return usage, imageCount, firstTokenMs, ev.err + } + done, processErr := processLine(ev.line) + if processErr != nil { + return usage, imageCount, firstTokenMs, processErr + } + if done { + return usage, imageCount, firstTokenMs, nil + } + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { continue } - payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil) - if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { - return OpenAIUsage{}, imageCount, firstTokenMs, writeErr + if clientDisconnected { + return usage, imageCount, firstTokenMs, fmt.Errorf("image stream incomplete after timeout") } - emitted[key] = struct{}{} + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream data interval timeout: interval=%s", streamInterval) + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval))) + return usage, imageCount, firstTokenMs, fmt.Errorf("image stream data interval timeout") + case <-keepaliveCh: + if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval { + continue + } + if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream client disconnected during keepalive, continue draining upstream for billing") + continue + } + flusher.Flush() + 
lastDownstreamWriteAt = time.Now() } - imageCount = len(emitted) - return usage, imageCount, firstTokenMs, nil } - - streamErr := fmt.Errorf("stream disconnected before image generation completed") - _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error())) - return OpenAIUsage{}, imageCount, firstTokenMs, streamErr } func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( @@ -752,7 +948,10 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( ) } - token, _, err := s.GetAccessToken(ctx, account) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream) + defer releaseUpstreamCtx() + + token, _, err := s.GetAccessToken(upstreamCtx, account) if err != nil { return nil, err } @@ -763,7 +962,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( } setOpsUpstreamRequestBody(c, responsesBody) - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false) if err != nil { return nil, err } @@ -808,14 +1007,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( Kind: "failover", Message: upstreamMsg, }) - s.handleFailoverSideEffects(ctx, resp, account) + s.handleFailoverSideEffects(upstreamCtx, resp, account) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), } } - return s.handleErrorResponse(ctx, resp, c, account, responsesBody) + return s.handleErrorResponse(upstreamCtx, resp, c, account, responsesBody) } defer func() { _ = resp.Body.Close() }() @@ -827,6 +1026,20 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( if parsed.Stream { usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, 
parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel) if err != nil { + if imageCount > 0 { + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: usage, + Model: requestModel, + UpstreamModel: requestModel, + Stream: parsed.Stream, + ResponseHeaders: resp.Header.Clone(), + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: parsed.SizeTier, + }, err + } return nil, err } } else { diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 681e0e8e..fa4a4415 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -3,6 +3,7 @@ package service import ( "bytes" "context" + "errors" "io" "mime/multipart" "net/http" @@ -17,6 +18,20 @@ import ( "github.com/tidwall/gjson" ) +type failingOpenAIImageWriter struct { + gin.ResponseWriter + failAfter int + writes int +} + +func (w *failingOpenAIImageWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) { gin.SetMode(gin.TestMode) body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`) @@ -75,6 +90,100 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) } +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + size string + wantTier string + }{ + {size: "1024x1024", wantTier: "1K"}, + {size: "1536x1024", wantTier: "2K"}, + {size: "1024x1536", wantTier: "2K"}, + {size: "2048x2048", wantTier: "2K"}, + {size: "2048x1152", wantTier: "2K"}, + {size: 
"3840x2160", wantTier: "4K"}, + {size: "2160x3840", wantTier: "4K"}, + {size: "1024X768", wantTier: "2K"}, + {size: "1280x768", wantTier: "2K"}, + {size: "2560x1440", wantTier: "2K"}, + {size: "2560x1600", wantTier: "4K"}, + {size: "auto", wantTier: "2K"}, + } + + svc := &OpenAIGatewayService{} + for _, tt := range tests { + t.Run(tt.size, func(t *testing.T) { + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tt.size + `"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, tt.size, parsed.Size) + require.Equal(t, tt.wantTier, parsed.SizeTier) + }) + } +} + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_UnknownSizesDoNotBlockPassthrough(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + size string + wantTier string + }{ + {size: "2048x1153", wantTier: "2K"}, + {size: "4096x1024", wantTier: "4K"}, + {size: "3840x1024", wantTier: "4K"}, + {size: "512x512", wantTier: "2K"}, + {size: "invalid", wantTier: "2K"}, + {size: "999999999999999999999999999x2", wantTier: "2K"}, + } + + svc := &OpenAIGatewayService{} + for _, tt := range tests { + t.Run(tt.size, func(t *testing.T) { + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tt.size + `"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, tt.size, parsed.Size) + require.Equal(t, tt.wantTier, parsed.SizeTier) + }) + } +} + 
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_LegacyImageModelUnknownSizePassthrough(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-1.5","prompt":"draw a cat","size":"2048x1152"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, "2048x1152", parsed.Size) + require.Equal(t, "2K", parsed.SizeTier) +} + func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) { gin.SetMode(gin.TestMode) @@ -543,6 +652,57 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbac require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) } +func TestOpenAIGatewayServiceForwardImages_APIKeyStreamMultilineSSEDataBillsImage(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_multiline"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"image_generation.completed\",\n" + + "data: \"usage\":{\"input_tokens\":10,\"output_tokens\":18,\"output_tokens_details\":{\"image_tokens\":8}},\n" + + "data: 
\"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" + + "data: [DONE]\n\n", + )), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + account := &Account{ + ID: 8, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 10, result.Usage.InputTokens) + require.Equal(t, 18, result.Usage.OutputTokens) + require.Equal(t, 8, result.Usage.ImageOutputTokens) +} + func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) { body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`) @@ -686,6 +846,61 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *tes require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists()) } +func TestOpenAIGatewayServiceForwardImages_APIKeyStreamingDrainsAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + c.Writer = &failingOpenAIImageWriter{ResponseWriter: c.Writer, failAfter: 1} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + ImageStreamDataIntervalTimeout: 1, + ImageStreamKeepaliveInterval: 0, + }, + }, + httpUpstream: &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + 
"X-Request-Id": []string{"req_img_stream_disconnect_apikey"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"image_generation.partial_image\",\"b64_json\":\"cGFydGlhbA==\"}\n\n" + + "data: {\"type\":\"image_generation.completed\",\"usage\":{\"input_tokens\":3,\"output_tokens\":4,\"output_tokens_details\":{\"image_tokens\":2}},\"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" + + "data: [DONE]\n\n", + )), + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + account := &Account{ + ID: 8, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-api-key", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 3, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.Equal(t, 2, result.Usage.ImageOutputTokens) +} + func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) { gin.SetMode(gin.TestMode) @@ -901,6 +1116,23 @@ func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testi require.JSONEq(t, `{"images":1}`, string(usageRaw)) } +func TestCollectOpenAIImagesFromResponsesBody_MultilineSSE(t *testing.T) { + body := []byte( + "data: {\"type\":\"response.completed\",\n" + + "data: \"response\":{\"created_at\":1710000010,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" + + "data: [DONE]\n\n", + ) + + results, createdAt, usageRaw, firstMeta, foundFinal, err := collectOpenAIImagesFromResponsesBody(body) + require.NoError(t, err) + require.True(t, foundFinal) + require.Equal(t, 
int64(1710000010), createdAt) + require.Len(t, results, 1) + require.Equal(t, "ZmluYWw=", results[0].Result) + require.Equal(t, "png", firstMeta.OutputFormat) + require.JSONEq(t, `{"images":1}`, string(usageRaw)) +} + func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) { gin.SetMode(gin.TestMode) body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`) @@ -957,3 +1189,116 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFa require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) require.NotContains(t, rec.Body.String(), "event: error") } + +func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesMultilineSSE(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + svc.httpUpstream = &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_multiline_oauth"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.completed\",\n" + + "data: \"response\":{\"created_at\":1710000011,\"usage\":{\"input_tokens\":6,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"TXVsdGlsaW5l\",\"output_format\":\"png\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + + account := &Account{ + ID: 11, + Name: 
"openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 6, result.Usage.InputTokens) + require.Equal(t, 10, result.Usage.OutputTokens) + require.Equal(t, 5, result.Usage.ImageOutputTokens) + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed") + require.True(t, ok) + require.Equal(t, "TXVsdGlsaW5l", gjson.Get(completed.Data, "b64_json").String()) + require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) + require.NotContains(t, rec.Body.String(), "event: error") +} + +func TestOpenAIGatewayServiceForwardImages_OAuthStreamingDrainsAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + c.Writer = &failingOpenAIImageWriter{ResponseWriter: c.Writer, failAfter: 1} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + ImageStreamDataIntervalTimeout: 1, + ImageStreamKeepaliveInterval: 0, + }, + }, + } + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_disconnect_oauth"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: 
{\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000009,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 9, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 5, result.Usage.InputTokens) + require.Equal(t, 9, result.Usage.OutputTokens) + require.Equal(t, 4, result.Usage.ImageOutputTokens) +} diff --git a/backend/internal/service/openai_model_alias.go b/backend/internal/service/openai_model_alias.go new file mode 100644 index 00000000..2fa2c90e --- /dev/null +++ b/backend/internal/service/openai_model_alias.go @@ -0,0 +1,137 @@ +package service + +import "strings" + +func lastOpenAIModelSegment(model string) string { + model = strings.TrimSpace(model) + if model == "" { + return "" + } + if strings.Contains(model, "/") { + parts := strings.Split(model, "/") + model = parts[len(parts)-1] + } + return strings.TrimSpace(model) +} + +func canonicalizeOpenAIModelAliasSpelling(model string) string { + model = strings.ToLower(lastOpenAIModelSegment(model)) + if model == "" { + return "" + } + + normalized := strings.ReplaceAll(model, "_", "-") + normalized = strings.Join(strings.Fields(normalized), "-") + for strings.Contains(normalized, "--") { + normalized = 
strings.ReplaceAll(normalized, "--", "-") + } + + if strings.HasPrefix(normalized, "gpt5") { + normalized = "gpt-5" + strings.TrimPrefix(normalized, "gpt5") + } + if !strings.HasPrefix(normalized, "gpt-") && !strings.Contains(normalized, "codex") { + return "" + } + + replacements := []struct { + from string + to string + }{ + {"gpt-5.4mini", "gpt-5.4-mini"}, + {"gpt-5.4nano", "gpt-5.4-nano"}, + {"gpt-5.3-codexspark", "gpt-5.3-codex-spark"}, + {"gpt-5.3codexspark", "gpt-5.3-codex-spark"}, + {"gpt-5.3codex", "gpt-5.3-codex"}, + } + for _, replacement := range replacements { + normalized = strings.ReplaceAll(normalized, replacement.from, replacement.to) + } + return normalized +} + +func normalizeKnownOpenAICodexModel(model string) string { + normalized := canonicalizeOpenAIModelAliasSpelling(model) + if normalized == "" { + return "" + } + + if mapped := getNormalizedCodexModel(normalized); mapped != "" { + return mapped + } + if strings.HasSuffix(normalized, "-openai-compact") { + if mapped := getNormalizedCodexModel(strings.TrimSuffix(normalized, "-openai-compact")); mapped != "" { + return mapped + } + } + + switch { + case strings.Contains(normalized, "gpt-5.5"): + return "gpt-5.5" + case strings.Contains(normalized, "gpt-5.4-mini"): + return "gpt-5.4-mini" + case strings.Contains(normalized, "gpt-5.4-nano"): + return "gpt-5.4-nano" + case strings.Contains(normalized, "gpt-5.4"): + return "gpt-5.4" + case strings.Contains(normalized, "gpt-5.2"): + return "gpt-5.2" + case strings.Contains(normalized, "gpt-5.3-codex-spark"): + return "gpt-5.3-codex-spark" + case strings.Contains(normalized, "gpt-5.3-codex"): + return "gpt-5.3-codex" + case strings.Contains(normalized, "gpt-5.3"): + return "gpt-5.3-codex" + case strings.Contains(normalized, "codex"): + return "gpt-5.3-codex" + case strings.Contains(normalized, "gpt-5"): + return "gpt-5.4" + default: + return "" + } +} + +func appendUsageBillingModelCandidate(candidates []string, seen map[string]struct{}, model 
string) []string { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return candidates + } + add := func(candidate string) { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + return + } + key := strings.ToLower(candidate) + if _, ok := seen[key]; ok { + return + } + seen[key] = struct{}{} + candidates = append(candidates, candidate) + } + + add(trimmed) + if canonical := canonicalizeOpenAIModelAliasSpelling(trimmed); canonical != "" { + add(canonical) + } + if normalized := normalizeKnownOpenAICodexModel(trimmed); normalized != "" { + add(normalized) + } + return candidates +} + +func usageBillingModelCandidates(primary string, alternates ...string) []string { + seen := make(map[string]struct{}, 1+len(alternates)) + candidates := appendUsageBillingModelCandidate(nil, seen, primary) + for _, alternate := range alternates { + candidates = appendUsageBillingModelCandidate(candidates, seen, alternate) + } + return candidates +} + +func firstUsageBillingModel(candidates []string) string { + for _, candidate := range candidates { + if trimmed := strings.TrimSpace(candidate); trimmed != "" { + return trimmed + } + } + return "" +} diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index 5c3e1ae0..f087ac32 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -94,6 +94,15 @@ func TestResolveOpenAIForwardModel(t *testing.T) { defaultMappedModel: "gpt-5.4", expectedModel: "gpt-5.5", }, + { + name: "preserves compact-spelled gpt5.5 instead of group default", + account: &Account{ + Credentials: map[string]any{}, + }, + requestedModel: "gpt5.5", + defaultMappedModel: "gpt-5.4", + expectedModel: "gpt5.5", + }, { name: "preserves openai namespaced gpt-5.5 instead of group default", account: &Account{ diff --git a/backend/internal/service/openai_sse_data.go b/backend/internal/service/openai_sse_data.go new 
file mode 100644 index 00000000..61b813b6 --- /dev/null +++ b/backend/internal/service/openai_sse_data.go @@ -0,0 +1,70 @@ +package service + +import ( + "strings" + + "github.com/tidwall/gjson" +) + +type openAISSEDataAccumulator struct { + lines []string +} + +func (a *openAISSEDataAccumulator) AddLine(line string, fn func([]byte)) { + if fn == nil { + return + } + trimmedLine := strings.TrimRight(line, "\r\n") + if data, ok := extractOpenAISSEDataLine(trimmedLine); ok { + a.lines = append(a.lines, data) + return + } + if strings.TrimSpace(trimmedLine) == "" { + a.Flush(fn) + } +} + +func (a *openAISSEDataAccumulator) Flush(fn func([]byte)) { + if fn == nil || len(a.lines) == 0 { + return + } + emitOpenAISSEDataPayloads(a.lines, fn) + a.lines = a.lines[:0] +} + +func forEachOpenAISSEDataPayload(body string, fn func([]byte)) { + if fn == nil || strings.TrimSpace(body) == "" { + return + } + var acc openAISSEDataAccumulator + for _, line := range strings.Split(body, "\n") { + acc.AddLine(line, fn) + } + acc.Flush(fn) +} + +func emitOpenAISSEDataPayloads(lines []string, fn func([]byte)) { + if fn == nil || len(lines) == 0 { + return + } + if len(lines) == 1 { + emitOpenAISSEDataPayload(lines[0], fn) + return + } + joined := strings.Join(lines, "\n") + if gjson.Valid(joined) { + emitOpenAISSEDataPayload(joined, fn) + return + } + for _, line := range lines { + emitOpenAISSEDataPayload(line, fn) + } +} + +func emitOpenAISSEDataPayload(data string, fn func([]byte)) { + data = strings.TrimSpace(data) + if data == "" || data == "[DONE]" { + return + } + fn([]byte(data)) +} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 201073e0..cb045ae7 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -1990,6 +1990,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( } usage := &OpenAIUsage{} + imageCounter := newOpenAIImageOutputCounter() var 
firstTokenMs *int responseID := "" var finalResponse []byte @@ -2171,6 +2172,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( if openAIWSEventShouldParseUsage(eventType) { parseOpenAIWSResponseUsageFromCompletedEvent(message, usage) } + imageCounter.AddSSEData(message) if eventType == "error" { errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) @@ -2343,6 +2345,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( Usage: *usage, Model: originalModel, UpstreamModel: mappedModel, + ImageCount: imageCounter.Count(), ServiceTier: extractOpenAIServiceTier(reqBody), ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), Stream: reqStream, @@ -2449,6 +2452,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( promptCacheKey string previousResponseID string originalModel string + imageBillingModel string + imageSizeTier string payloadBytes int } @@ -2546,6 +2551,19 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } normalized = next } + imageIntent := IsImageGenerationIntent(openAIResponsesEndpoint, originalModel, normalized) + if imageIntent && !GroupAllowsImageGeneration(apiKeyGroup(getAPIKeyFromContext(c))) { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, ImageGenerationPermissionMessage(), nil) + } + imageBillingModel := "" + imageSizeTier := "" + if imageIntent { + var imageCfgErr error + imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(normalized, originalModel) + if imageCfgErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, imageCfgErr.Error(), imageCfgErr) + } + } // Apply OpenAI Fast Policy on the response.create frame using the same // evaluator/normalize/scope rules as the HTTP entrypoints. 
This is the @@ -2591,6 +2609,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( promptCacheKey: promptCacheKey, previousResponseID: previousResponseID, originalModel: originalModel, + imageBillingModel: imageBillingModel, + imageSizeTier: imageSizeTier, payloadBytes: len(normalized), }, nil } @@ -2792,7 +2812,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( return payload, nil } - sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) { + sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string, imageBillingModel string, imageSizeTier string) (*OpenAIForwardResult, error) { if lease == nil { return nil, errors.New("upstream websocket lease is nil") } @@ -2817,6 +2837,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( responseID := "" usage := OpenAIUsage{} + imageCounter := newOpenAIImageOutputCounter() var firstTokenMs *int reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true) turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") @@ -2938,6 +2959,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( if openAIWSEventShouldParseUsage(eventType) { parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage) } + imageCounter.AddSSEData(upstreamMessage) if !clientDisconnected { if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) { @@ -2997,7 +3019,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( clientDisconnected, ) } - return &OpenAIForwardResult{ + imageCount := imageCounter.Count() + result := &OpenAIForwardResult{ RequestID: responseID, Usage: usage, Model: originalModel, @@ -3009,13 +3032,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ResponseHeaders: 
lease.HandshakeHeaders(), Duration: time.Since(turnStart), FirstTokenMs: firstTokenMs, - }, nil + } + if imageCount > 0 { + result.ImageCount = imageCount + result.ImageSize = imageSizeTier + result.BillingModel = imageBillingModel + } + return result, nil } } } currentPayload := firstPayload.payloadRaw currentOriginalModel := firstPayload.originalModel + currentImageBillingModel := firstPayload.imageBillingModel + currentImageSizeTier := firstPayload.imageSizeTier currentPayloadBytes := firstPayload.payloadBytes isStrictAffinityTurn := func(payload []byte) bool { if !storeDisabled { @@ -3460,7 +3491,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ) } - result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel) + result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, currentImageBillingModel, currentImageSizeTier) if relayErr != nil { lastTurnClean = false if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { @@ -3582,6 +3613,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } currentPayload = nextPayload.payloadRaw currentOriginalModel = nextPayload.originalModel + currentImageBillingModel = nextPayload.imageBillingModel + currentImageSizeTier = nextPayload.imageSizeTier currentPayloadBytes = nextPayload.payloadBytes storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account) if !storeDisabled { diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go index 7a76c385..cd816533 100644 --- a/backend/internal/service/openai_ws_forwarder_success_test.go +++ b/backend/internal/service/openai_ws_forwarder_success_test.go @@ -171,6 +171,127 @@ func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) { require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String()) } +func 
TestOpenAIGatewayService_Forward_WSv2_ImageGenerationCountsOutputs(t *testing.T) { + gin.SetMode(gin.TestMode) + + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + + if err := conn.WriteJSON(map[string]any{ + "type": "response.output_item.done", + "item": map[string]any{ + "id": "ig_ws_1", + "type": "image_generation_call", + "result": "final-image", + }, + }); err != nil { + t.Errorf("write response.output_item.done failed: %v", err) + return + } + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": "resp_ws_image_1", + "model": "gpt-5.4", + "output": []any{ + map[string]any{ + "id": "ig_ws_1", + "type": "image_generation_call", + "result": "final-image", + }, + }, + "usage": map[string]any{ + "input_tokens": 9, + "output_tokens": 4, + }, + }, + }); err != nil { + t.Errorf("write response.completed failed: %v", err) + return + } + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + groupID := int64(1010) + c.Set("api_key", &APIKey{ + GroupID: &groupID, + Group: &Group{ + ID: groupID, + AllowImageGeneration: true, + }, + }) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + 
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 10, + Name: "openai-ws-image", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.4","stream":false,"input":"draw","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1024x1024"}],"tool_choice":{"type":"image_generation"}}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_ws_image_1", result.RequestID) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, "1K", result.ImageSize) + require.Equal(t, "gpt-image-2", result.BillingModel) + require.Equal(t, 9, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.True(t, result.OpenAIWSMode) + require.Equal(t, "resp_ws_image_1", gjson.GetBytes(rec.Body.Bytes(), "id").String()) +} + func requestToJSONString(payload map[string]any) string { if len(payload) == 0 { return "{}" diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 91a02901..8a033710 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -625,6 +625,9 @@ func 
normalizeModelNameForPricing(model string) string { } model = strings.TrimLeft(model, "/") + if canonical := canonicalizeOpenAIModelAliasSpelling(model); canonical != "" { + return canonical + } return model } diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go index e2bd7cf3..3c3e2c5b 100644 --- a/backend/internal/service/pricing_service_test.go +++ b/backend/internal/service/pricing_service_test.go @@ -98,6 +98,19 @@ func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T) require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12) } +func TestGetModelPricing_OpenAICompactAliasUsesStaticFallback(t *testing.T) { + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.1-codex": {InputCostPerToken: 1.25e-6}, + }, + } + + got := svc.GetModelPricing("openai/gpt5.5") + require.NotNil(t, got) + require.InDelta(t, 2.5e-6, got.InputCostPerToken, 1e-12) + require.InDelta(t, 1.5e-5, got.OutputCostPerToken, 1e-12) +} + func TestGetModelPricing_Gpt54MiniUsesDedicatedStaticFallbackWhenRemoteMissing(t *testing.T) { svc := &PricingService{ pricingData: map[string]*LiteLLMModelPricing{ diff --git a/backend/migrations/134_image_generation_group_controls.sql b/backend/migrations/134_image_generation_group_controls.sql new file mode 100644 index 00000000..37941c00 --- /dev/null +++ b/backend/migrations/134_image_generation_group_controls.sql @@ -0,0 +1,26 @@ +-- 生图能力与图片倍率模式控制 +-- 兼容性原则: +-- 1. 不改写现有 image_price_1k/2k/4k,避免改变已配置分组的最终图片价格。 +-- 2. 现有 openai/gemini/antigravity 分组默认保持可生图,避免升级后中断已有图片业务。 +-- 3. 
现有分组默认共享当前有效分组倍率,保持历史扣费公式。 + +ALTER TABLE groups + ADD COLUMN IF NOT EXISTS allow_image_generation BOOLEAN NOT NULL DEFAULT false; + +ALTER TABLE groups + ADD COLUMN IF NOT EXISTS image_rate_independent BOOLEAN NOT NULL DEFAULT false; + +ALTER TABLE groups + ADD COLUMN IF NOT EXISTS image_rate_multiplier DECIMAL(10,4) NOT NULL DEFAULT 1.0; + +UPDATE groups +SET allow_image_generation = true +WHERE platform IN ('openai', 'gemini', 'antigravity'); + +UPDATE groups +SET image_rate_independent = false, + image_rate_multiplier = 1.0; + +COMMENT ON COLUMN groups.allow_image_generation IS '是否允许该分组使用图片生成能力'; +COMMENT ON COLUMN groups.image_rate_independent IS '图片生成是否使用独立倍率;false 表示共享分组有效倍率'; +COMMENT ON COLUMN groups.image_rate_multiplier IS '图片生成独立倍率,仅 image_rate_independent=true 时生效'; diff --git a/deploy/.env.example b/deploy/.env.example index e1eb8256..28205f7c 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -285,6 +285,25 @@ GATEWAY_SCHEDULING_OUTBOX_BACKLOG_REBUILD_ROWS=10000 # 全量重建周期(秒) GATEWAY_SCHEDULING_FULL_REBUILD_INTERVAL_SECONDS=300 +# ----------------------------------------------------------------------------- +# Image Generation Stream & Concurrency (Optional) +# 图片生成流式与并发隔离配置(可选) +# ----------------------------------------------------------------------------- +# 图片流式上游数据间隔超时(秒)。0 表示禁用;非 0 时必须为 60-1800。 +GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=900 +# 图片流式 keepalive 间隔(秒)。0 表示禁用;非 0 时必须为 5-60。 +GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=10 +# 是否启用进程级图片生成并发限制。默认 false,保持历史行为。 +GATEWAY_IMAGE_CONCURRENCY_ENABLED=false +# 当前进程允许同时处理的图片生成请求数。0 表示不限制。 +GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=0 +# 图片并发超限策略:reject 直接返回 429;wait 等待空闲槽位。 +GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=reject +# wait 模式下等待空闲图片槽位的最长时间(秒)。 +GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=30 +# wait 模式下当前进程允许排队等待的最大图片请求数。0 表示不允许等待队列。 +GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=100 + # ----------------------------------------------------------------------------- # 
Dashboard Aggregation (Optional) # ----------------------------------------------------------------------------- diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index dfc363b5..1670699b 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -340,6 +340,30 @@ gateway: # Stream keepalive interval (seconds), 0=disable # 流式 keepalive 间隔(秒),0=禁用 stream_keepalive_interval: 10 + # Image stream data interval timeout (seconds), 0=disable, non-zero values must be 60-1800; independent from ordinary text streams + # 图片流数据间隔超时(秒),0=禁用,非 0 时必须为 60-1800;独立于普通文本流式 + image_stream_data_interval_timeout: 900 + # Image stream keepalive interval (seconds), 0=disable, non-zero values must be 5-60; independent from ordinary text streams + # 图片流式 keepalive 间隔(秒),0=禁用,非 0 时必须为 5-60;独立于普通文本流式 + image_stream_keepalive_interval: 10 + # Image generation independent concurrency limiter (process-local, default disabled) + # 图片生成独立并发限制(进程级,默认关闭;多实例总上限约为实例数×该值) + image_concurrency: + # Enable image-only concurrency protection; false keeps existing behavior unchanged + # 是否启用图片独立并发保护;false 保持现有行为不变 + enabled: false + # Max concurrent image generation requests in this process, 0=unlimited + # 当前进程允许同时处理的图片生成请求数,0=不限制 + max_concurrent_requests: 0 + # Overflow mode when the image concurrency limit is full: reject/wait + # 图片并发满时的处理方式:reject=立即拒绝,wait=等待槽位 + overflow_mode: "reject" + # Wait timeout for overflow_mode=wait (seconds), 0=do not wait + # wait 模式等待图片并发槽位的超时时间(秒),0=不等待 + wait_timeout_seconds: 30 + # Max image requests waiting in this process when overflow_mode=wait, 0=unlimited + # wait 模式当前进程允许排队等待的图片请求数,0=不限制 + max_waiting_requests: 100 # SSE max line size in bytes (default: 40MB) # SSE 单行最大字节数(默认 40MB) max_line_size: 41943040 diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml index 7793e424..b7a805b5 100644 --- a/deploy/docker-compose.dev.yml +++ b/deploy/docker-compose.dev.yml @@ -40,6 +40,13 @@ services: - JWT_SECRET=${JWT_SECRET:-} - TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-} - TZ=${TZ:-Asia/Shanghai} + - 
GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900} + - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10} + - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false} + - GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0} + - GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject} + - GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30} + - GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100} depends_on: postgres: condition: service_healthy diff --git a/deploy/docker-compose.local.yml b/deploy/docker-compose.local.yml index 5aea78fb..51a80227 100644 --- a/deploy/docker-compose.local.yml +++ b/deploy/docker-compose.local.yml @@ -146,6 +146,17 @@ services: # Proxy for accessing GitHub (online updates + pricing data) # Examples: http://host:port, socks5://host:port - UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-} + + # ======================================================================= + # Image Generation Stream & Concurrency + # ======================================================================= + - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900} + - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10} + - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false} + - GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0} + - GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject} + - GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30} + - GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100} depends_on: 
postgres: condition: service_healthy diff --git a/deploy/docker-compose.standalone.yml b/deploy/docker-compose.standalone.yml index df0ccfcc..438d0a8a 100644 --- a/deploy/docker-compose.standalone.yml +++ b/deploy/docker-compose.standalone.yml @@ -93,6 +93,17 @@ services: # SECURITY: This repo does not embed third-party client_secret. - GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-} - ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-} + + # ======================================================================= + # Image Generation Stream & Concurrency + # ======================================================================= + - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900} + - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10} + - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false} + - GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0} + - GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject} + - GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30} + - GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100} healthcheck: test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"] interval: 30s diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 3a714260..1d639ea4 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -142,6 +142,17 @@ services: # Proxy for accessing GitHub (online updates + pricing data) # Examples: http://host:port, socks5://host:port - UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-} + + # ======================================================================= + # Image Generation Stream & Concurrency + # 
======================================================================= + - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900} + - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10} + - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false} + - GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0} + - GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject} + - GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30} + - GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100} depends_on: postgres: condition: service_healthy diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index adcb3cc6..629e6aa2 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -291,9 +291,23 @@ +
- {{ tooltipData.billing_mode === BILLING_MODE_IMAGE ? t('usage.imageUnitPrice') : t('usage.unitPrice') }} - ${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }} + {{ t('usage.unitPrice') }} + ${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}
{{ t('admin.usage.cacheCreationCost') }} @@ -360,6 +374,13 @@ function accountBilled(row: { total_cost?: number | null; account_stats_cost?: n return Number.isNaN(result) ? 0 : result } +function imageUnitPrice(row: AdminUsageLog | null): number { + if (!row || row.image_count <= 0) return 0 + const total = row.total_cost ?? 0 + const price = total / row.image_count + return Number.isFinite(price) ? price : 0 +} + import DataTable from '@/components/common/DataTable.vue' import EmptyState from '@/components/common/EmptyState.vue' import Icon from '@/components/icons/Icon.vue' diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 195d0237..2dca418c 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -844,6 +844,8 @@ export default { perMillionTokens: '/ 1M tokens', unitPrice: 'Per-request price', imageUnitPrice: 'Per-image price', + imageTotalPrice: 'Image total price', + imageCount: 'Image count', cacheRead: 'Read', cacheWrite: 'Write', serviceTier: 'Service tier', @@ -2050,7 +2052,13 @@ export default { }, imagePricing: { title: 'Image Generation Pricing', - description: 'Configure pricing for image generation models. Leave empty to use default prices.' + description: 'Configure image generation access and base image prices. Leave empty to use default prices.', + allowImageGeneration: 'Allow image generation for this group', + independentMultiplier: 'Use independent image multiplier', + imageMultiplier: 'Image multiplier', + modeHint: 'By default, image billing uses image price × current effective group multiplier. 
Independent mode uses image price × image multiplier.', + finalPricePreview: 'Final per-image price preview', + notConfigured: 'Not configured' }, claudeCode: { title: 'Claude Code Client Restriction', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 0f95d652..b1217793 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -848,6 +848,8 @@ export default { perMillionTokens: '/ 1M Token', unitPrice: '单次价格', imageUnitPrice: '单张价格', + imageTotalPrice: '图片总价', + imageCount: '图片张数', cacheRead: '读取', cacheWrite: '写入', serviceTier: '服务档位', @@ -2133,7 +2135,13 @@ export default { }, imagePricing: { title: '图片生成计费', - description: '配置图片生成模型的图片生成价格,留空则使用默认价格' + description: '配置图片生成能力和图片基础单价,留空则使用默认价格', + allowImageGeneration: '允许当前分组生图', + independentMultiplier: '生图倍率独立', + imageMultiplier: '生图独立倍率', + modeHint: '默认关闭独立倍率时,图片费用 = 图片价格 × 当前分组有效倍率;开启独立倍率后,图片费用 = 图片价格 × 生图独立倍率。', + finalPricePreview: '最终单张价格预览', + notConfigured: '未配置' }, claudeCode: { title: 'Claude Code 客户端限制', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 479a8d95..79530c99 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -492,7 +492,10 @@ export interface Group { daily_limit_usd: number | null weekly_limit_usd: number | null monthly_limit_usd: number | null - // 图片生成计费配置(仅 antigravity 平台使用) + // 图片生成计费配置 + allow_image_generation: boolean + image_rate_independent: boolean + image_rate_multiplier: number image_price_1k: number | null image_price_2k: number | null image_price_4k: number | null @@ -602,6 +605,9 @@ export interface CreateGroupRequest { daily_limit_usd?: number | null weekly_limit_usd?: number | null monthly_limit_usd?: number | null + allow_image_generation?: boolean + image_rate_independent?: boolean + image_rate_multiplier?: number image_price_1k?: number | null image_price_2k?: number | null image_price_4k?: number | null @@ -627,6 +633,9 @@ export interface 
UpdateGroupRequest { daily_limit_usd?: number | null weekly_limit_usd?: number | null monthly_limit_usd?: number | null + allow_image_generation?: boolean + image_rate_independent?: boolean + image_rate_multiplier?: number image_price_1k?: number | null image_price_2k?: number | null image_price_4k?: number | null diff --git a/frontend/src/views/admin/GroupsView.vue b/frontend/src/views/admin/GroupsView.vue index d24a3a11..753d52dd 100644 --- a/frontend/src/views/admin/GroupsView.vue +++ b/frontend/src/views/admin/GroupsView.vue @@ -666,6 +666,40 @@

{{ t("admin.groups.imagePricing.description") }}

+
+ + +
+
+ + +
@@ -701,6 +735,22 @@ />
+

+ {{ t("admin.groups.imagePricing.modeHint") }} +

+
+
+ {{ t("admin.groups.imagePricing.finalPricePreview") }} +
+
+
+ {{ item.label }}: {{ item.value }} +
+
+
@@ -1801,6 +1851,40 @@

{{ t("admin.groups.imagePricing.description") }}

+
+ + +
+
+ + +
@@ -1836,6 +1920,22 @@ />
+

+ {{ t("admin.groups.imagePricing.modeHint") }} +

+
+
+ {{ t("admin.groups.imagePricing.finalPricePreview") }} +
+
+
+ {{ item.label }}: {{ item.value }} +
+
+
@@ -3009,7 +3109,10 @@ const createForm = reactive({ daily_limit_usd: null as number | null, weekly_limit_usd: null as number | null, monthly_limit_usd: null as number | null, - // 图片生成计费配置(仅 antigravity 平台使用) + // 图片生成计费配置 + allow_image_generation: false, + image_rate_independent: false, + image_rate_multiplier: 1, image_price_1k: null as number | null, image_price_2k: null as number | null, image_price_4k: null as number | null, @@ -3291,7 +3394,10 @@ const editForm = reactive({ daily_limit_usd: null as number | null, weekly_limit_usd: null as number | null, monthly_limit_usd: null as number | null, - // 图片生成计费配置(仅 antigravity 平台使用) + // 图片生成计费配置 + allow_image_generation: false, + image_rate_independent: false, + image_rate_multiplier: 1, image_price_1k: null as number | null, image_price_2k: null as number | null, image_price_4k: null as number | null, @@ -3321,6 +3427,62 @@ const editForm = reactive({ rpm_limit: 0 as number, }); +type ImagePricingFormState = { + rate_multiplier: number; + image_rate_independent: boolean; + image_rate_multiplier: number; + image_price_1k: number | string | null; + image_price_2k: number | string | null; + image_price_4k: number | string | null; +}; + +const imagePricingTiers = [ + { key: "image_price_1k", label: "1K" }, + { key: "image_price_2k", label: "2K" }, + { key: "image_price_4k", label: "4K" }, +] as const; + +const normalizePreviewNumber = (value: number | string | null | undefined, fallback = 0) => { + if (value === null || value === undefined || value === "") { + return fallback; + } + const parsed = Number(value); + return Number.isFinite(parsed) ? 
parsed : fallback; +}; + +const formatImagePricePreview = (value: number | string | null | undefined) => { + if (value === null || value === undefined || value === "") { + return t("admin.groups.imagePricing.notConfigured"); + } + const price = Number(value); + if (!Number.isFinite(price) || price < 0) { + return t("admin.groups.imagePricing.notConfigured"); + } + return `$${price.toFixed(6).replace(/0+$/, "").replace(/\.$/, "")}`; +}; + +const buildImageFinalPricePreview = (form: ImagePricingFormState) => { + const multiplier = form.image_rate_independent + ? normalizePreviewNumber(form.image_rate_multiplier, 1) + : normalizePreviewNumber(form.rate_multiplier, 1); + return imagePricingTiers.map((tier) => { + const basePrice = normalizePreviewNumber(form[tier.key]); + return { + label: tier.label, + value: basePrice > 0 + ? formatImagePricePreview(basePrice * multiplier) + : t("admin.groups.imagePricing.notConfigured"), + }; + }); +}; + +const createImageFinalPricePreview = computed(() => + buildImageFinalPricePreview(createForm), +); +const editImageFinalPricePreview = computed(() => + buildImageFinalPricePreview(editForm), +); + // 根据分组类型返回不同的删除确认消息 const deleteConfirmMessage = computed(() => { if (!deletingGroup.value) { @@ -3479,6 +3641,9 @@ const closeCreateModal = () => { createForm.daily_limit_usd = null; createForm.weekly_limit_usd = null; createForm.monthly_limit_usd = null; + createForm.allow_image_generation = false; + createForm.image_rate_independent = false; + createForm.image_rate_multiplier = 1; createForm.image_price_1k = null; createForm.image_price_2k = null; createForm.image_price_4k = null; @@ -3513,6 +3678,16 @@ const normalizeOptionalLimit = ( return Number.isFinite(value) && value > 0 ? 
value : null; }; +const normalizeImageRateMultiplier = ( + value: number | string | null | undefined, +): number => { + if (value === null || value === undefined || value === "") { + return 1; + } + const parsed = Number(value); + return Number.isFinite(parsed) && parsed >= 0 ? parsed : 1; +}; + const handleCreateGroup = async () => { if (!createForm.name.trim()) { appStore.showError(t("admin.groups.nameRequired")); @@ -3551,6 +3726,9 @@ const handleCreateGroup = async () => { requestData.daily_limit_usd = emptyToNull(requestData.daily_limit_usd); requestData.weekly_limit_usd = emptyToNull(requestData.weekly_limit_usd); requestData.monthly_limit_usd = emptyToNull(requestData.monthly_limit_usd); + requestData.image_rate_multiplier = normalizeImageRateMultiplier( + requestData.image_rate_multiplier, + ); await adminAPI.groups.create(requestData); appStore.showSuccess(t("admin.groups.groupCreated")); closeCreateModal(); @@ -3582,6 +3760,9 @@ const handleEdit = async (group: AdminGroup) => { editForm.daily_limit_usd = group.daily_limit_usd; editForm.weekly_limit_usd = group.weekly_limit_usd; editForm.monthly_limit_usd = group.monthly_limit_usd; + editForm.allow_image_generation = group.allow_image_generation ?? false; + editForm.image_rate_independent = group.image_rate_independent ?? false; + editForm.image_rate_multiplier = group.image_rate_multiplier ?? 
1; editForm.image_price_1k = group.image_price_1k; editForm.image_price_2k = group.image_price_2k; editForm.image_price_4k = group.image_price_4k; @@ -3676,6 +3857,9 @@ const handleUpdateGroup = async () => { payload.daily_limit_usd = emptyToNull(payload.daily_limit_usd); payload.weekly_limit_usd = emptyToNull(payload.weekly_limit_usd); payload.monthly_limit_usd = emptyToNull(payload.monthly_limit_usd); + payload.image_rate_multiplier = normalizeImageRateMultiplier( + payload.image_rate_multiplier, + ); await adminAPI.groups.update(editingGroup.value.id, payload); appStore.showSuccess(t("admin.groups.groupUpdated")); closeEditModal(); diff --git a/frontend/src/views/user/UsageView.vue b/frontend/src/views/user/UsageView.vue index 08298d89..0eb8d455 100644 --- a/frontend/src/views/user/UsageView.vue +++ b/frontend/src/views/user/UsageView.vue @@ -459,9 +459,23 @@ +
- {{ tooltipData.billing_mode === 'image' ? t('usage.imageUnitPrice') : t('usage.unitPrice') }} - ${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }} + {{ t('usage.unitPrice') }} + ${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}
{{ t('admin.usage.cacheCreationCost') }} @@ -625,6 +639,13 @@ const formatDuration = (ms: number): string => { return `${(ms / 1000).toFixed(2)}s` } +const imageUnitPrice = (row: UsageLog | null): number => { + if (!row || row.image_count <= 0) return 0 + const total = row.total_cost ?? 0 + const price = total / row.image_count + return Number.isFinite(price) ? price : 0 +} + const formatUserAgent = (ua: string): string => { return ua } diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index b71f9d58..38770704 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -44,7 +44,6 @@ export default defineConfig(({ mode }) => { plugins: [ vue(), checker({ - typescript: true, vueTsc: true }), injectPublicSettings(backendUrl) diff --git a/openspec/changes/add-image-generation-billing-controls/.openspec.yaml b/openspec/changes/add-image-generation-billing-controls/.openspec.yaml new file mode 100644 index 00000000..5f23b852 --- /dev/null +++ b/openspec/changes/add-image-generation-billing-controls/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-04-29 diff --git a/openspec/changes/add-image-generation-billing-controls/design.md b/openspec/changes/add-image-generation-billing-controls/design.md new file mode 100644 index 00000000..21ec1c10 --- /dev/null +++ b/openspec/changes/add-image-generation-billing-controls/design.md @@ -0,0 +1,227 @@ +## Context + +当前代码已经具备图片价格字段和部分图片转发能力,但边界不完整: + +- `backend/ent/schema/group.go` 只有 `rate_multiplier` 和 `image_price_1k/2k/4k`,没有分组级生图能力开关,也没有“图片是否共享分组倍率”的开关。 +- `backend/internal/handler/openai_images.go` 在解析 `/v1/images/*` 后只做通用余额/订阅资格检查,没有检查分组是否允许生图。 +- `backend/internal/service/openai_gateway_service.go` 对 Codex CLI 会自动注入 `image_generation` tool;通用 `/v1/responses` 只记录日志,没有把图片工具产物数量写入 `OpenAIForwardResult.ImageCount`。 +- `backend/internal/service/billing_service.go` 的 `CalculateImageCost` 当前使用 `image_price_* * image_count * rate_multiplier`。这个行为本身可以作为默认兼容模式,但普通编码分组 `rate_multiplier=0.15` 
且希望图片最终价为 `0.2/张` 时,管理员必须填写 `image_price=0.2/0.15`,不可读且不适合长期运营。 +- `backend/internal/service/openai_gateway_service.go` 和 `backend/internal/service/gateway_service.go` 的渠道图片计费路径当前传 `RequestCount: 1`,多图请求会按 1 次收费。 +- `backend/internal/service/openai_images.go` 的 OpenAI 图片尺寸分层此前只覆盖少量固定尺寸;`gpt-image-2` 官方文档已经支持满足约束的自定义 `size`,因此本地计费必须能够对未知尺寸做稳定分档,同时不能因为本地映射不认识就提前拦截请求。 + +用户澄清后的业务要求是:普通编码分组可以关闭生图,也可以开启生图;开启后默认继续共享现有分组倍率以保持兼容,但管理员可以打开“生图倍率独立”开关,改用单独的图片倍率输入框。图片分组是推荐的运营隔离方式,但不是唯一承载方式。 + +## Goals / Non-Goals + +**Goals:** +- 分组具备明确的 `allow_image_generation` 开关,所有已知生图入口在调度上游前执行同一个权限判断。 +- 分组具备“生图倍率是否独立”的开关;默认 `false`,即共享当前代码里的有效分组倍率。 +- 生图倍率独立开关打开后,图片费用使用单独的 `image_rate_multiplier`,不再使用普通编码分组的倍率。 +- 保留现有 `image_price_1k/2k/4k` 字段作为图片单价配置,不强制把它们迁移成新的语义。 +- 普通编码分组在 `allow_image_generation=false` 时仍可正常使用 `gpt-5.4` / `gpt-5.5` 文本能力,但不能使用图片工具。 +- 普通编码分组在 `allow_image_generation=true` 时可使用 `gpt-5.4` / `gpt-5.5 + image_generation`,且按实际图片数量收费。 +- 通用 `/v1/responses`、OpenAI Images API、流式、非流式、透传路径全部把成功产出的图片数量写入 `ImageCount`。 +- 渠道 `billing_mode=image` 使用真实 `ImageCount`,不再固定按 1 次收费。 + +**Non-Goals:** +- 不引入新的第三方依赖。 +- 不改变 OpenAI 上游协议;只在现有请求转发、响应解析和计费归因层补齐控制。 +- 不把“图片分组”做成唯一安全边界;分组开关和图片计费逻辑必须适用于任意开启生图的分组。 +- 不在本变更中实现预扣费/资金冻结;失败请求仍不收费,成功请求按实际产物后扣费。 +- 不改变默认历史图片价格行为;默认共享现有有效倍率,历史 `图片价格 * 分组/用户有效倍率` 的扣费行为保持。 +- 不在本变更中新增用户级图片独立倍率覆盖;用户专属普通倍率只在共享倍率模式下继续影响图片。 + +## Decisions + +### 0. 兼容性优先原则 + +本变更的默认行为必须以“不改变现有已配置分组的最终扣费”为优先级: + +- 迁移不修改现有 `image_price_1k/2k/4k`。 +- 迁移把所有现有分组设置为 `image_rate_independent=false`,因此现有图片路径继续使用当前有效分组倍率。 +- 管理员不传新字段更新分组时,不得覆盖已保存的 `allow_image_generation`、`image_rate_independent`、`image_rate_multiplier`。 +- 前端编辑旧分组时必须回显服务端值;不能因为表单默认值把旧分组从共享倍率误改成独立倍率,或把允许生图误改成禁止生图。 +- 只有管理员显式打开 `image_rate_independent` 后,图片扣费才从共享倍率切换到图片独立倍率。 + +### 1. 
分组字段与迁移策略 + +新增三个分组字段,对应“2 个开关 + 1 个输入框”: + +- `allow_image_generation BOOLEAN NOT NULL DEFAULT false` +- `image_rate_independent BOOLEAN NOT NULL DEFAULT false` +- `image_rate_multiplier DECIMAL(10,4) NOT NULL DEFAULT 1.0` + +字段语义: + +- `allow_image_generation`:是否支持当前分组生图。 +- `image_rate_independent=false`:图片计费共享当前普通计费链路里的有效倍率,即当前 `userGroupRateResolver.Resolve(ctx, user.ID, groupID, group.RateMultiplier)` 得到的倍率;这保持现有行为。 +- `image_rate_independent=true`:图片计费使用 `group.image_rate_multiplier`;普通编码的 `rate_multiplier` 和用户专属普通倍率不参与图片扣费。 +- `image_price_1k/2k/4k`:继续表示图片基础单价,由选中的图片倍率模式继续相乘。 + +新建分组默认 `allow_image_generation=false`,避免新普通编码分组意外获得生图能力。为避免升级后立即打断已有图片业务,迁移对现有 `openai`、`gemini`、`antigravity` 分组回填 `allow_image_generation=true`,`anthropic` 分组保持 `false`。该回填只是兼容现状;上线后管理员必须按业务策略关闭不允许生图的普通编码分组。 + +迁移不改写已有 `image_price_1k/2k/4k`,并将所有现有分组设为 `image_rate_independent=false`、`image_rate_multiplier=1`。这样现有最终扣费公式保持不变: + +```text +历史/默认模式图片最终扣费 = image_price_* * image_count * 当前有效分组倍率 +``` + +普通编码分组 `rate_multiplier=0.15` 且希望图片 1K 最终扣费 `0.2/张` 时,管理员不再需要填写 `0.2/0.15`,而是设置: + +```text +image_rate_independent = true +image_rate_multiplier = 1 +image_price_1k = 0.2 +``` + +如果希望图片也打折,例如图片标价 `0.2/张`、图片折扣 `0.8`,则设置: + +```text +image_rate_independent = true +image_rate_multiplier = 0.8 +image_price_1k = 0.2 +``` + +### 2. 
生图意图统一识别 + +新增一个服务层 helper,输入至少包含 endpoint、请求模型、请求体,输出是否为生图意图: + +```text +isImageGenerationIntent = + endpoint 是 /v1/images/generations 或 /v1/images/edits + OR requested model 以 gpt-image- 开头 + OR tools[] 存在 type == image_generation + OR tool_choice 显式指向 image_generation +``` + +生图意图判断必须在请求体被 Codex 注入、模型改写、渠道映射改写之前执行一次,并在这些改写之后再对最终请求体执行一次。原因是当前代码会在 `backend/internal/service/openai_gateway_service.go` 中注入 `image_generation` tool,也会在 `normalizeOpenAIResponsesImageOnlyModel` 中把 `gpt-image-*` 改写为文本模型 + 图片工具;只检查改写前或只检查改写后都可能漏掉场景。 + +`tool_choice` 判断只把明确指向 `image_generation` 的值视为生图意图;`auto`、`none`、`required` 本身不构成生图意图,但如果 `tools[]` 中存在 `image_generation`,仍由 `tools[]` 规则命中。 + +该判断必须在以下位置使用: + +- `/v1/images/*` handler 解析请求后、账号调度前。 +- `/v1/responses` 解析 body 后、Codex 自动注入 `image_generation` tool 前。 +- `normalizeOpenAIResponsesImageOnlyModel` 把 `gpt-image-*` 改写为 Responses 文本模型前。 +- OpenAI 高级 scheduler 入口保留现有账号能力检查,同时补齐渠道 restriction 检查,避免启用高级调度时绕过渠道模型限制。 + +当 `allow_image_generation=false` 时: + +- 显式生图意图返回 HTTP 403,错误类型使用现有 `permission_error` 风格。 +- Codex CLI 请求不自动注入 `image_generation` tool,也不追加图片桥接指令;如果请求没有显式生图意图,则继续按普通文本请求处理。 + +### 3. gpt-5.4 / gpt-5.5 生图承载方式 + +`gpt-5.4` / `gpt-5.5` 生图通过现有 OpenAI Responses API 的 `image_generation` tool 承载,不新增专用 endpoint: + +```json +{ + "model": "gpt-5.4", + "input": "生成一张图片", + "tools": [ + { + "type": "image_generation", + "model": "gpt-image-2", + "size": "1024x1024", + "output_format": "png" + } + ], + "tool_choice": { "type": "image_generation" } +} +``` + +`model=gpt-image-*` 发到 `/v1/responses` 时保留现有改写方向:主模型改为 Responses 文本模型,图片模型放入 `image_generation` tool。计费时如果能从工具配置得到 `gpt-image-*`,图片默认价格按该图片模型解析;如果工具未指定图片模型,则使用当前转发结果的 billing model,并优先使用分组/渠道配置价格。 + +### 4. 
图片数量归因 + +新增统一图片输出解析 helper,返回去重后的图片数量和可用图片元信息。必须覆盖以下已有或可借鉴的事件形态: + +- 非流式 Responses JSON:`output[]` 中 `type == image_generation_call` 且 `result` 非空。 +- Responses SSE:`response.output_item.done` 中 `item.type == image_generation_call` 且 `item.result` 非空。 +- Responses SSE 完成事件:`response.completed.response.output[]` 中图片工具结果。 +- Images API 非流式:顶层 `data[]`。 +- Images API 流式:顶层 `data[]`、`image_generation.completed`、`response.output_item.done`、`response.completed`。 + +去重键按优先级使用 `item.id`、`call_id`、`result` 内容 hash。只统计最终图片,不统计 `partial_image`。 + +`openaiStreamingResult` 增加 `imageCount`、`imageSize`、`imageBillingModel`。`handleStreamingResponse`、`handleStreamingResponsePassthrough`、`handleNonStreamingResponse`、`handleNonStreamingResponsePassthrough` 都必须把解析结果带回 `OpenAIForwardResult`。当 `ImageCount > 0` 时,即使上游 usage 为 0,也必须写 usage log 并进入图片计费。 + +### 5. 图片价格公式 + +图片计费先确定单价,再确定倍率: + +```text +unit_price = 渠道 image 模式价格 或 分组 image_price_* 或 默认图片价格 +image_multiplier = + 如果 group.image_rate_independent == true: group.image_rate_multiplier + 否则: 当前有效分组倍率 +total_cost = unit_price * image_count +actual_cost = total_cost * image_multiplier +``` + +“当前有效分组倍率”必须沿用当前代码的倍率解析方式:默认配置倍率 → 分组 `rate_multiplier` → 用户专属分组倍率覆盖。这样 `image_rate_independent=false` 时完全保留当前行为。 + +`billing_mode=image` 的渠道价格是图片单价来源之一,仍优先于分组图片价格。图片渠道价格也必须按 `ImageCount` 计数,并使用同一套 `image_multiplier` 选择逻辑。 + +`billing_mode=per_request` 的非图片请求保持当前普通按次语义,继续使用普通 token 倍率;只有已经识别为图片请求且 `ImageCount > 0` 的路径使用图片计费逻辑。 + +`usage_logs.rate_multiplier` 继续表示“本次扣费实际使用的倍率”。因此: + +- token 日志记录普通 token 有效倍率。 +- image 日志在共享模式记录普通有效倍率。 +- image 日志在独立模式记录 `image_rate_multiplier`。 + +专用 `/v1/images/*` 仍按图片请求语义计费:当 `ImageCount > 0` 时,图片价格决定费用,伴随的上游 token usage 只记录不额外计 token 费用。这保持当前 Images API 的行为。 + +通用 `/v1/responses + image_generation` 的混合文本+图片输出存在一个明确取舍:如果继续沿用“`ImageCount > 0` 时只按图片计费”的当前计费分支,用户可以在一次图片请求中夹带大量文本输出而只付图片费用;如果改成“图片费用 + 非图片 token 费用”,会改变当前 `billing_mode=image` 的单一计费语义,并可能让渠道图片单价不再是全包价格。本变更为最大兼容性不引入混合计费模式,但必须在 usage log 中完整记录 token 
与 image_count,便于后续按数据决定是否新增 `image_plus_token` 计费模式。 + +### 6. 尺寸档位与参数透传 + +OpenAI 图片请求的 `size` 参数必须透传给上游;本地只做计费分档,不做 OpenAI 尺寸合法性校验。无论尺寸是否满足官方约束,本地都不能因为未知尺寸或 provider-invalid 尺寸返回 400;如果上游不接受该尺寸,由上游响应错误。 + +官方 `gpt-image-2` 文档给出的常用尺寸与约束是本地计费分档的依据: + +- 常用尺寸:`1024x1024`、`1536x1024`、`1024x1536`、`2048x2048`、`2048x1152`、`3840x2160`、`2160x3840`、`auto`。 +- 自定义尺寸:官方支持满足约束的任意 `size`,包括边长、16 像素倍数、长短边比例、总像素范围等约束。 +- `2560x1440` 是 2K/QHD 参考边界;超过 `2560x1440` 总像素的输出进入更高档位风险区。 + +OpenAI 图片尺寸分层必须按以下规则: + +```text +empty, auto => 2K +1024x1024 => 1K +1536x1024, 1024x1536 => 2K +1792x1024, 1024x1792 => 2K +2048x2048, 2048x1152, 1152x2048 => 2K +3840x2160, 2160x3840 => 4K +未知且无法解析为正整数 WIDTHxHEIGHT => 2K +未知且 WIDTH * HEIGHT <= 2560*1440 => 2K +未知且 WIDTH * HEIGHT > 2560*1440 => 4K +``` + +这个规则只决定 `ImageSize` 和扣费档位,不修改请求体,不删除未知参数,不把未知尺寸改写成预设尺寸。 + +## Risks / Trade-offs + +- 历史普通编码分组迁移后仍默认允许生图 → 通过管理员可见开关、上线核对清单和新建分组默认关闭来控制;代码无法可靠判断“普通编码分组”和“图片分组”的业务意图。 +- 默认共享现有有效倍率仍保留“图片最终价不直观”的问题 → 这是兼容性选择;需要直观设置图片最终价的分组必须打开 `image_rate_independent`。 +- 独立图片倍率不会读取用户专属普通倍率 → 这是目标行为;如需要用户级图片独立倍率,应作为后续独立需求实现。 +- 通用 Responses 图片工具可能同时输出文本和图片 → 本变更默认仍按图片请求语义计费并完整记录 token;若业务要求文本也收费,应新增独立的混合计费模式,不能混入本次兼容性变更。 +- 本地不再拦截未知或 provider-invalid OpenAI 尺寸 → 非法尺寸会消耗一次上游请求失败成本和用户体验往返,但这是为了保证参数透传、兼容官方新增尺寸和第三方兼容提供商;计费只在成功产出最终图片后发生。 +- Responses 流式解析需要在客户端断开后继续 drain 上游以完成计费 → 沿用当前流式处理“客户端断开后继续读取上游用于计费”的模式,并只新增轻量 JSON 路径提取。 +- 预扣费不在本变更中实现 → 继续使用现有成功后扣费模型,避免失败请求退款、流式中断退款和图片数量未确定时预估错误。 + +## Migration Plan + +1. 新增数据库迁移,添加 `groups.allow_image_generation`、`groups.image_rate_independent` 和 `groups.image_rate_multiplier`。 +2. 回填现有分组:`openai`、`gemini`、`antigravity` 的 `allow_image_generation=true`,`anthropic=false`;所有现有分组 `image_rate_independent=false`、`image_rate_multiplier=1`。 +3. 不改写现有 `image_price_1k/2k/4k`,保持默认共享倍率模式下的历史扣费结果。 +4. 更新 Ent schema 与生成代码,更新后端 service/handler DTO 和前端类型。 +5. 先接入权限判断,确保未开启生图的分组不会到达上游。 +6. 再接入图片数量解析和图片计费倍率选择,确保开启生图的分组按图片数量收费。 +7. 最后更新前端管理界面、i18n、文档和测试。 +8. 
回滚时只能通过新迁移回滚字段行为;不能修改已应用迁移文件。 + +## Open Questions + +无。当前方案不依赖未确认的上游新尺寸、新模型或新 endpoint。 diff --git a/openspec/changes/add-image-generation-billing-controls/proposal.md b/openspec/changes/add-image-generation-billing-controls/proposal.md new file mode 100644 index 00000000..1c19753f --- /dev/null +++ b/openspec/changes/add-image-generation-billing-controls/proposal.md @@ -0,0 +1,29 @@ +## Why + +当前代码把“能否生图”和“如何按图片收费”混在模型、分组倍率、渠道定价与 Responses 工具调用里,导致 OpenAI 普通编码分组在允许 `gpt-5.4` / `gpt-5.5` 时也能通过 `image_generation` tool 产图,并且通用 `/v1/responses` 产图不会稳定写入 `ImageCount`。需要把生图能力、图片倍率模式、图片产出数量归因拆成独立能力,保证普通编码分组可按业务开关生图,开启后既能沿用现有倍率行为,也能按需切换到图片独立倍率。 + +## What Changes + +- 新增分组级生图能力开关,明确控制 `/v1/images/*`、`gpt-image-*`、显式 `image_generation` tool、Codex 自动注入图片工具等所有生图入口。 +- 新增分组级图片倍率模式开关,默认继续共享现有分组有效倍率;打开独立模式后使用图片独立倍率输入框。 +- 保留现有 `image_price_1k/2k/4k` 图片价格配置;图片最终扣费由“图片价格 × 当前倍率模式选出的倍率 × 图片数量”决定。 +- 统一统计 OpenAI Responses 图片工具产物数量,使 `gpt-5.4` / `gpt-5.5` 通过 `image_generation` tool 产图时进入图片计费,而不是退化成普通 token 计费或无 usage 时不计费。 +- 修正专用 Images API 与渠道图片计费场景,按实际图片数量和明确尺寸档位计费,避免固定 `RequestCount=1`,也避免未知尺寸不区分像素量一律静默落到 `2K`(总像素超过 `2560x1440` 的未知尺寸按 `4K` 计费,其余未知尺寸仍按 `2K` 计费)。 +- 更新后台分组配置、前端类型、使用说明和测试,覆盖普通编码分组关闭生图、普通编码分组开启生图、独立图片分组承载、生图流式/非流式等场景。 + +## Capabilities + +### New Capabilities +- `image-generation-access-control`: 定义分组级生图能力开关、所有生图意图识别规则、拒绝行为与 Codex 自动注入规则。 +- `image-generation-billing-accounting`: 定义图片倍率模式、图片数量归因、尺寸档位、渠道图片价格和用量日志要求。 + +### Modified Capabilities +- 无。 + +## Impact + +- Backend schema/API: `backend/ent/schema/group.go`、Ent 生成代码、数据库迁移、管理员分组 create/update/list DTO、分组缓存/序列化。 +- Backend request gates: `backend/internal/handler/openai_images.go`、`backend/internal/service/openai_gateway_service.go`、`backend/internal/service/openai_codex_transform.go`、OpenAI account scheduler 相关模型/图片能力调度入口。 +- Backend billing: `backend/internal/service/billing_service.go`、`backend/internal/service/openai_gateway_service.go`、`backend/internal/service/gateway_service.go`、usage log 与 account stats 成本计算路径。 +- Frontend admin: 
`frontend/src/types/index.ts`、`frontend/src/views/admin/GroupsView.vue`、相关 i18n 文案与图片计费展示。 +- Tests: OpenAI Images API、OpenAI Responses stream/non-stream/passthrough、分组开关、图片倍率模式、渠道图片计数、尺寸档位与 usage log 断言。 diff --git a/openspec/changes/add-image-generation-billing-controls/specs/image-generation-access-control/spec.md b/openspec/changes/add-image-generation-billing-controls/specs/image-generation-access-control/spec.md new file mode 100644 index 00000000..828b63b3 --- /dev/null +++ b/openspec/changes/add-image-generation-billing-controls/specs/image-generation-access-control/spec.md @@ -0,0 +1,118 @@ +## ADDED Requirements + +### Requirement: Group image generation capability +The system SHALL store a group-level `allow_image_generation` capability flag and SHALL expose it through admin group create, update, list, and detail APIs. + +#### Scenario: New group defaults to image generation disabled +- **WHEN** an admin creates a group without providing `allow_image_generation` +- **THEN** the persisted group has `allow_image_generation=false` + +#### Scenario: Existing image-capable platform groups are backfilled +- **WHEN** the migration is applied to existing groups +- **THEN** existing `openai`, `gemini`, and `antigravity` groups have `allow_image_generation=true` +- **AND** existing `anthropic` groups have `allow_image_generation=false` + +#### Scenario: Admin enables image generation on an ordinary coding group +- **WHEN** an admin updates an `openai` group with `allow_image_generation=true` +- **THEN** the group can use image generation paths subject to the billing requirements + +#### Scenario: Admin disables image generation on an ordinary coding group +- **WHEN** an admin updates an `openai` group with `allow_image_generation=false` +- **THEN** the group can still use non-image text model requests +- **AND** image generation intents are denied before upstream dispatch + +### Requirement: Image generation intent detection +The system SHALL classify a request as 
an image generation intent before upstream account scheduling when the endpoint or request body can produce generated images. + +#### Scenario: Images endpoint is an image generation intent +- **WHEN** a request targets `/v1/images/generations`, `/v1/images/edits`, `/images/generations`, or `/images/edits` +- **THEN** the request is classified as an image generation intent + +#### Scenario: Responses request with image-only model is an image generation intent +- **WHEN** a `/v1/responses` request has a requested model whose normalized name starts with `gpt-image-` +- **THEN** the request is classified as an image generation intent before any model rewrite + +#### Scenario: Responses request with image_generation tool is an image generation intent +- **WHEN** a `/v1/responses` request contains any `tools[]` entry with `type == "image_generation"` +- **THEN** the request is classified as an image generation intent + +#### Scenario: Responses request with image_generation tool_choice is an image generation intent +- **WHEN** a `/v1/responses` request contains `tool_choice` that explicitly selects `image_generation` +- **THEN** the request is classified as an image generation intent even if `tools[]` is malformed or absent + +#### Scenario: Generic tool_choice required is not sufficient by itself +- **WHEN** a `/v1/responses` request contains `tool_choice="required"` +- **AND** the request does not contain an `image_generation` tool +- **THEN** the request is not classified as an image generation intent because of `tool_choice` alone + +#### Scenario: Text-only gpt-5.4 request is not an image generation intent +- **WHEN** a `/v1/responses` request uses `model="gpt-5.4"` or `model="gpt-5.5"` without `image_generation` tool and without image `tool_choice` +- **THEN** the request is not classified as an image generation intent + +#### Scenario: Intent is checked before and after service-side mutation +- **WHEN** the service mutates a `/v1/responses` request by injecting 
`image_generation` or rewriting `gpt-image-*` to a Responses text model plus image tool +- **THEN** the final mutated request is checked against the same image generation intent rules before upstream dispatch + +### Requirement: Disabled groups reject explicit image generation +The system SHALL reject explicit image generation intents for groups with `allow_image_generation=false` before selecting or calling an upstream account. + +#### Scenario: Disabled group rejects Images API +- **WHEN** a group has `allow_image_generation=false` +- **AND** a user calls `/v1/images/generations` +- **THEN** the system returns HTTP 403 with error type `permission_error` +- **AND** no upstream account is selected +- **AND** no usage log is written + +#### Scenario: Disabled group rejects Responses image tool +- **WHEN** a group has `allow_image_generation=false` +- **AND** a user calls `/v1/responses` with `tools:[{"type":"image_generation"}]` +- **THEN** the system returns HTTP 403 with error type `permission_error` +- **AND** no upstream account is selected +- **AND** no usage log is written + +#### Scenario: Disabled group rejects Responses image-only model rewrite +- **WHEN** a group has `allow_image_generation=false` +- **AND** a user calls `/v1/responses` with `model` starting with `gpt-image-` +- **THEN** the system returns HTTP 403 with error type `permission_error` +- **AND** the request is not rewritten to a text Responses model + +#### Scenario: Disabled group permits normal coding request +- **WHEN** a group has `allow_image_generation=false` +- **AND** a user calls `/v1/responses` with `model="gpt-5.4"` and no image generation intent +- **THEN** the request proceeds through the normal text forwarding path + +### Requirement: Codex image tool injection respects group capability +The system SHALL only inject the OpenAI Responses `image_generation` tool and bridge instructions for Codex clients when the request group has `allow_image_generation=true`. 
+ +#### Scenario: Codex request in enabled group receives image tool +- **WHEN** a Codex CLI `/v1/responses` request belongs to a group with `allow_image_generation=true` +- **AND** the request has no `image_generation` tool +- **THEN** the system injects the existing `image_generation` tool payload +- **AND** the system appends the existing Codex image bridge instructions + +#### Scenario: Codex request in disabled group does not receive image tool +- **WHEN** a Codex CLI `/v1/responses` request belongs to a group with `allow_image_generation=false` +- **AND** the request has no explicit image generation intent +- **THEN** the system does not inject `image_generation` +- **AND** the system does not append image bridge instructions +- **AND** the request proceeds as a text request + +#### Scenario: Codex explicit image request in disabled group is denied +- **WHEN** a Codex CLI `/v1/responses` request belongs to a group with `allow_image_generation=false` +- **AND** the request explicitly contains `image_generation` +- **THEN** the system returns HTTP 403 with error type `permission_error` + +### Requirement: Channel model restrictions remain enforced +The system SHALL keep existing channel model restriction behavior for image and non-image OpenAI requests, including when the advanced OpenAI account scheduler is enabled. 
+ +#### Scenario: Advanced scheduler blocks restricted requested model +- **WHEN** a channel has `restrict_models=true` +- **AND** the requested model is not allowed by channel pricing or mapping rules +- **AND** the OpenAI advanced scheduler path is used +- **THEN** the request is rejected before upstream account selection succeeds + +#### Scenario: Image generation flag does not bypass channel restrictions +- **WHEN** a group has `allow_image_generation=true` +- **AND** the channel restriction rejects the requested or billing model +- **THEN** the image generation request is rejected +- **AND** no upstream image request is sent diff --git a/openspec/changes/add-image-generation-billing-controls/specs/image-generation-billing-accounting/spec.md b/openspec/changes/add-image-generation-billing-controls/specs/image-generation-billing-accounting/spec.md new file mode 100644 index 00000000..90176e33 --- /dev/null +++ b/openspec/changes/add-image-generation-billing-controls/specs/image-generation-billing-accounting/spec.md @@ -0,0 +1,225 @@ +## ADDED Requirements + +### Requirement: Image multiplier mode +The system SHALL calculate image generation cost with group image prices and a selectable image multiplier mode. By default image billing SHALL share the existing effective group multiplier; when `image_rate_independent=true`, image billing SHALL use `image_rate_multiplier`. 
+ +#### Scenario: Default image billing shares current effective group multiplier +- **WHEN** a group has `rate_multiplier=0.15` +- **AND** `image_rate_independent=false` +- **AND** `image_price_1k=0.2` +- **AND** a successful image request produces one `1K` image +- **THEN** `actual_cost` is `0.03` +- **AND** the calculation matches current default behavior + +#### Scenario: User-specific token multiplier still applies in shared mode +- **WHEN** a user has a user-group token multiplier override of `0.2` +- **AND** the group has `image_rate_independent=false` +- **AND** `image_price_1k=0.5` +- **AND** a successful image request produces one `1K` image +- **THEN** `actual_cost` is `0.1` +- **AND** the applied image multiplier is the same effective multiplier used by token billing + +#### Scenario: Independent image multiplier allows direct final price +- **WHEN** a group has `rate_multiplier=0.15` +- **AND** `image_rate_independent=true` +- **AND** `image_rate_multiplier=1` +- **AND** `image_price_1k=0.2` +- **AND** a successful image request produces one `1K` image +- **THEN** `actual_cost` is `0.2` +- **AND** ordinary `rate_multiplier=0.15` is not applied to the image cost + +#### Scenario: Independent image multiplier supports image discounts +- **WHEN** a group has `image_rate_independent=true` +- **AND** `image_rate_multiplier=0.5` +- **AND** `image_price_1k=0.2` +- **AND** a successful image request produces two `1K` images +- **THEN** `total_cost` is `0.4` +- **AND** `actual_cost` is `0.2` + +#### Scenario: Migration preserves existing image price behavior +- **WHEN** an existing group has `rate_multiplier=0.15` and `image_price_1k=1.3333333333` +- **AND** the migration is applied +- **THEN** the stored `image_price_1k` remains `1.3333333333` +- **AND** the stored `image_rate_independent` is `false` +- **AND** the stored `image_rate_multiplier` is `1` +- **AND** default-mode image billing still produces the historical final price within decimal precision + 
+#### Scenario: Omitted update fields preserve existing multiplier mode +- **WHEN** an admin updates a group without sending `image_rate_independent` +- **AND** without sending `image_rate_multiplier` +- **THEN** the stored image multiplier mode and image multiplier value remain unchanged + +#### Scenario: Image multiplier can be zero only by explicit independent mode configuration +- **WHEN** a group has `image_rate_independent=true` +- **AND** `image_rate_multiplier=0` +- **AND** a successful image request produces one image +- **THEN** the image request is free +- **AND** this free-image behavior does not occur unless the group explicitly enables independent image multiplier mode with zero multiplier + +### Requirement: Responses image output accounting +The system SHALL count generated image outputs from OpenAI Responses stream, non-stream, and passthrough paths and SHALL return the count in `OpenAIForwardResult.ImageCount`. + +#### Scenario: Non-stream Responses image tool output is counted +- **WHEN** a non-stream `/v1/responses` upstream response contains `output[]` item with `type == "image_generation_call"` and non-empty `result` +- **THEN** `OpenAIForwardResult.ImageCount` equals the number of unique final image outputs +- **AND** `OpenAIForwardResult.ImageSize` is the normalized image size tier + +#### Scenario: Stream Responses output item is counted +- **WHEN** a stream `/v1/responses` upstream SSE event has `type == "response.output_item.done"` +- **AND** the event item has `type == "image_generation_call"` and non-empty `result` +- **THEN** the streaming result increments the unique final image output count + +#### Scenario: Stream Responses completed output is counted +- **WHEN** a stream `/v1/responses` upstream SSE event has `type == "response.completed"` +- **AND** `response.output[]` contains final image generation outputs +- **THEN** the streaming result counts those images without double-counting images already seen in 
`response.output_item.done` + +#### Scenario: Partial image events are not billed as completed images +- **WHEN** a stream response contains `partial_image` events +- **THEN** those partial events do not increment `ImageCount` +- **AND** only final image generation outputs increment `ImageCount` + +#### Scenario: gpt-5.4 image tool request is billed as image +- **WHEN** a `/v1/responses` request uses `model="gpt-5.4"` or `model="gpt-5.5"` +- **AND** the request includes an `image_generation` tool +- **AND** the upstream response contains one final image output +- **THEN** the usage log has `image_count=1` +- **AND** the usage log has `billing_mode="image"` +- **AND** image pricing, not token pricing, determines `actual_cost` + +#### Scenario: Image output with zero usage is still billed +- **WHEN** an upstream Responses result contains final image output +- **AND** the upstream result has zero or missing token usage +- **THEN** the system writes a usage log +- **AND** the system bills using image pricing + +#### Scenario: Responses image request records accompanying token usage +- **WHEN** a `/v1/responses` image tool request returns final images and token usage +- **THEN** the usage log records input tokens, output tokens, image output tokens, and image count +- **AND** the applied billing mode remains `image` + +#### Scenario: Responses image request does not introduce hybrid billing by default +- **WHEN** a `/v1/responses` image tool request returns final images and text tokens +- **THEN** the request is billed by image pricing under this change +- **AND** non-image token charges are not added unless a future explicit hybrid billing mode is implemented + +### Requirement: OpenAI Images API output accounting +The system SHALL count generated images from dedicated OpenAI Images API stream and non-stream paths and SHALL set `ImageCount` for successful image responses. 
+ +#### Scenario: Images non-stream data array is counted +- **WHEN** `/v1/images/generations` returns a non-stream JSON response with top-level `data[]` +- **THEN** `ImageCount` equals the length of `data[]` + +#### Scenario: Images stream data array is counted +- **WHEN** `/v1/images/generations` stream response emits SSE data containing top-level `data[]` +- **THEN** `ImageCount` equals the maximum final data array count observed for the request + +#### Scenario: Images stream completed event is counted +- **WHEN** `/v1/images/generations` stream response emits `image_generation.completed` with a final image payload +- **THEN** the stream result counts one final image output + +#### Scenario: Images stream Responses-form event is counted +- **WHEN** an Images API upstream path emits Responses-form `response.output_item.done` or `response.completed` events with final image outputs +- **THEN** the stream result counts final image outputs using the same de-duplication rules as Responses + +### Requirement: Channel image billing uses actual image count +The system SHALL use actual generated image count for channel `billing_mode=image` pricing and SHALL NOT bill multi-image requests as a single request. 
+ +#### Scenario: OpenAI channel image billing counts multiple images +- **WHEN** a channel image pricing entry resolves to unit price `0.25` +- **AND** an OpenAI image request produces three images +- **THEN** `total_cost` is `0.75` before the selected image multiplier is applied +- **AND** `RequestCount` passed into unified pricing is `3` + +#### Scenario: Gateway channel image billing counts multiple images +- **WHEN** a non-OpenAI gateway image path produces two images +- **AND** channel image pricing resolves for the billing model +- **THEN** `RequestCount` passed into unified pricing is `2` + +#### Scenario: Channel image pricing uses shared multiplier by default +- **WHEN** a channel image pricing entry resolves to unit price `0.25` +- **AND** the group has ordinary effective multiplier `0.15` +- **AND** the group has `image_rate_independent=false` +- **AND** the image request produces one image +- **THEN** `actual_cost` is `0.0375` + +#### Scenario: Channel image pricing uses independent image multiplier when enabled +- **WHEN** a channel image pricing entry resolves to unit price `0.25` +- **AND** the group has ordinary effective multiplier `0.15` +- **AND** the group has `image_rate_independent=true` +- **AND** the group has `image_rate_multiplier=1` +- **AND** the image request produces one image +- **THEN** `actual_cost` is `0.25` +- **AND** ordinary effective multiplier `0.15` is not applied + +#### Scenario: Account stats image pricing receives image count +- **WHEN** account stats pricing uses `billing_mode=image` +- **AND** the request produces multiple images +- **THEN** account stats cost is calculated with the actual image count + +### Requirement: Image size tier normalization +The system SHALL normalize OpenAI image sizes to explicit billing tiers for billing only. 
The system SHALL NOT reject requests locally because of an unknown or provider-invalid `size`; it SHALL forward the original size parameter upstream and let the official upstream API decide whether the request is valid. + +#### Scenario: OpenAI 1024 square maps to 1K +- **WHEN** an OpenAI image request specifies `size="1024x1024"` +- **THEN** `ImageSize` is `1K` + +#### Scenario: OpenAI landscape and portrait large sizes map to 2K +- **WHEN** an OpenAI image request specifies `1536x1024`, `1024x1536`, `1792x1024`, `1024x1792`, `2048x2048`, `2048x1152`, or `1152x2048` +- **THEN** `ImageSize` is `2K` + +#### Scenario: OpenAI gpt-image-2 4K presets map to 4K +- **WHEN** an OpenAI `gpt-image-2` image request specifies `3840x2160` or `2160x3840` +- **THEN** `ImageSize` is `4K` + +#### Scenario: OpenAI auto size maps to 2K +- **WHEN** an OpenAI image request omits size or specifies `size="auto"` +- **THEN** `ImageSize` is `2K` + +#### Scenario: Custom OpenAI size is forwarded without local validation +- **WHEN** an OpenAI image request specifies a custom explicit `WIDTHxHEIGHT` size +- **THEN** the system forwards the request upstream +- **AND** `ImageSize` is normalized to `2K` or `4K` for billing + +#### Scenario: Responses image tool without model uses default image billing model +- **WHEN** a `/v1/responses` request uses an `image_generation` tool without `tool.model` +- **THEN** image size tier normalization and image billing use `gpt-image-2` as the image billing model + +#### Scenario: Invalid OpenAI size constraints are delegated upstream +- **WHEN** an OpenAI image request specifies an explicit size that fails OpenAI size constraints +- **THEN** the system forwards the request upstream +- **AND** any invalid-size error comes from the upstream provider response + +#### Scenario: Custom OpenAI size tier mapping +- **WHEN** a custom size cannot be parsed as positive `WIDTHxHEIGHT` +- **THEN** `ImageSize` is `2K` +- **WHEN** a custom size parses as positive `WIDTHxHEIGHT` 
+- **AND** `WIDTH * HEIGHT` is no more than `2560 * 1440` (3,686,400 pixels) +- **THEN** `ImageSize` is `2K` +- **WHEN** a custom size parses as positive `WIDTHxHEIGHT` +- **AND** `WIDTH * HEIGHT` exceeds `2560 * 1440` (3,686,400 pixels) +- **THEN** `ImageSize` is `4K` + +### Requirement: Image usage log semantics +The system SHALL write usage logs for successful image generation with image billing metadata that matches the applied image pricing path. + +#### Scenario: Image usage log records image billing mode +- **WHEN** a successful request has `ImageCount > 0` +- **THEN** the usage log has `billing_mode="image"` +- **AND** the usage log records `image_count` +- **AND** the usage log records `image_size` when a normalized size tier is available + +#### Scenario: Shared mode image usage log records shared multiplier +- **WHEN** a successful image request is billed with `image_rate_independent=false` +- **AND** the effective ordinary multiplier is `0.15` +- **THEN** `usage_logs.rate_multiplier` is `0.15` + +#### Scenario: Independent mode image usage log records image multiplier +- **WHEN** a successful image request is billed with `image_rate_independent=true` +- **AND** `image_rate_multiplier=0.5` +- **THEN** `usage_logs.rate_multiplier` is `0.5` + +#### Scenario: Token request usage log is unchanged +- **WHEN** a successful non-image token request is billed +- **THEN** `usage_logs.rate_multiplier` continues to record the ordinary token multiplier +- **AND** `image_count` is `0` diff --git a/openspec/changes/add-image-generation-billing-controls/tasks.md b/openspec/changes/add-image-generation-billing-controls/tasks.md new file mode 100644 index 00000000..16a35654 --- /dev/null +++ b/openspec/changes/add-image-generation-billing-controls/tasks.md @@ -0,0 +1,72 @@ +## 1. Data Model And Migration + +- [x] 1.1 Add `allow_image_generation`, `image_rate_independent`, and `image_rate_multiplier` to `backend/ent/schema/group.go`. 
+- [x] 1.2 Create a new idempotent SQL migration after `133_affiliate_rebate_freeze.sql` for the three group columns. +- [x] 1.3 Backfill existing `openai`, `gemini`, and `antigravity` groups to `allow_image_generation=true` and `anthropic` groups to `false`. +- [x] 1.4 Backfill all existing groups to `image_rate_independent=false` and `image_rate_multiplier=1` without changing existing `image_price_1k/2k/4k`. +- [x] 1.5 Regenerate or update Ent generated group fields, predicates, create/update setters, and query projections. +- [x] 1.6 Add the new fields to backend group domain/service structs, admin create/update inputs, admin responses, and group serialization. + +## 2. Admin API And Frontend + +- [x] 2.1 Add `allow_image_generation`, `image_rate_independent`, and `image_rate_multiplier` to `CreateGroupRequest` and `UpdateGroupRequest`. +- [x] 2.2 Validate `image_rate_multiplier >= 0` and keep negative image prices using the existing clear-price behavior only for `image_price_*`. +- [x] 2.3 Add the new fields to `frontend/src/types/index.ts` group, create, and update interfaces. +- [x] 2.4 Ensure omitted update fields do not overwrite existing image generation and multiplier mode settings. +- [x] 2.5 Update `frontend/src/views/admin/GroupsView.vue` create/edit forms with a 生图开关, 生图倍率是否独立开关, and conditional image multiplier input. +- [x] 2.6 Add a live final-price preview for `image_price_1k/2k/4k` under shared and independent multiplier modes. +- [x] 2.7 Update group form help text to state that default image billing shares the existing group effective multiplier and independent mode uses the image multiplier input. +- [x] 2.8 Update i18n strings for the new controls and image multiplier mode explanation. + +## 3. Image Generation Access Control + +- [x] 3.1 Implement a shared helper that detects image generation intent from endpoint, requested model, `tools[]`, and `tool_choice`. 
+- [x] 3.2 Gate `/v1/images/generations` and `/v1/images/edits` in `backend/internal/handler/openai_images.go` after request parsing and before billing eligibility/account scheduling. +- [x] 3.3 Gate `/v1/responses` explicit `image_generation` tool requests in `backend/internal/service/openai_gateway_service.go` before upstream account scheduling. +- [x] 3.4 Prevent `normalizeOpenAIResponsesImageOnlyModel` from rewriting `gpt-image-*` Responses requests when the group does not allow image generation. +- [x] 3.5 Skip Codex `image_generation` auto-injection and image bridge instructions when the group does not allow image generation. +- [x] 3.6 Re-run image intent detection after service-side request mutation and before upstream dispatch. +- [x] 3.7 Ensure OpenAI advanced scheduler paths apply the same channel `RestrictModels` checks as the load-aware path. + +## 4. Responses Image Output Accounting + +- [x] 4.1 Add shared parsers for final `image_generation_call.result` outputs in non-stream JSON and SSE payloads. +- [x] 4.2 Extend `openaiStreamingResult` with image count, image size tier, and image billing model fields. +- [x] 4.3 Update `handleStreamingResponse` to count final image outputs while preserving existing stream forwarding and usage parsing. +- [x] 4.4 Update `handleStreamingResponsePassthrough` with the same image output counting. +- [x] 4.5 Update `handleNonStreamingResponse` to count final image outputs from `output[]`. +- [x] 4.6 Update `handleNonStreamingResponsePassthrough` with the same non-stream image output counting. +- [x] 4.7 Populate `OpenAIForwardResult.ImageCount`, `ImageSize`, and image billing model for `gpt-5.4` / `gpt-5.5 + image_generation` requests. + +## 5. Images API Accounting And Size Tiers + +- [x] 5.1 Extend OpenAI Images API-key stream counting to handle `image_generation.completed`, `response.output_item.done`, and `response.completed`. 
+- [x] 5.2 Reuse the same final-image de-duplication rules across Images API and Responses API paths. +- [x] 5.3 Keep unknown explicit OpenAI image sizes pass-through and delegate invalid-size errors to upstream. +- [x] 5.4 Map documented OpenAI image sizes to `1K`/`2K`/`4K` billing tiers without rewriting request parameters. +- [x] 5.5 Classify custom OpenAI `WIDTHxHEIGHT` sizes by `2560x1440` total-pixel boundary, falling back to `2K` when unparseable. + +## 6. Billing And Usage Logs + +- [x] 6.1 Add an image multiplier resolver: shared mode uses the current effective group multiplier, independent mode uses `apiKey.Group.ImageRateMultiplier`. +- [x] 6.2 Update `CalculateImageCost` or its caller contract so image costs use the resolved image multiplier. +- [x] 6.3 Set image usage log `RateMultiplier` to the applied image multiplier; keep token logs unchanged. +- [x] 6.4 Change OpenAI channel image billing `RequestCount` from `1` to `result.ImageCount`. +- [x] 6.5 Change non-OpenAI gateway channel image billing `RequestCount` from `1` to `result.ImageCount`. +- [x] 6.6 Pass actual image count into account stats pricing for `billing_mode=image`. +- [x] 6.7 Ensure `ImageCount > 0` writes a usage log and bills even when upstream token usage is zero. +- [x] 6.8 Record accompanying token usage for Responses image tool requests while keeping default billing mode as `image`. + +## 7. Tests And Documentation + +- [x] 7.1 Add backend tests for disabled group rejecting `/v1/images/*`, `gpt-image-*` Responses, explicit `image_generation`, and image `tool_choice`. +- [x] 7.2 Add backend tests proving disabled Codex groups do not receive injected image tools while enabled Codex groups still do. +- [x] 7.3 Add backend tests proving omitted group update fields preserve existing image generation and multiplier mode settings. +- [x] 7.4 Add Responses stream and non-stream tests for `gpt-5.4` / `gpt-5.5 + image_generation` image counting and image billing. 
+- [x] 7.5 Add Images API stream tests for `image_generation.completed`, `response.output_item.done`, and `response.completed` counting. +- [x] 7.6 Add billing tests for shared mode `rate_multiplier=0.15`, `image_price_1k=0.2`, final `actual_cost=0.03`. +- [x] 7.7 Add billing tests for independent mode `rate_multiplier=0.15`, `image_rate_multiplier=1`, `image_price_1k=0.2`, final `actual_cost=0.2`. +- [x] 7.8 Add channel image billing tests proving multi-image requests use `RequestCount=ImageCount` in both shared and independent multiplier modes. +- [x] 7.9 Add size-tier tests for known OpenAI sizes and unknown explicit size pass-through. +- [x] 7.10 Add Responses image tool tests proving token usage is recorded but default billing remains image-mode only. +- [x] 7.11 Update `2ue/image-billing-risk-analysis.md` or add a linked follow-up note that points to this OpenSpec change as the normalized solution. diff --git a/openspec/changes/image-generation-concurrency-isolation/.openspec.yaml b/openspec/changes/image-generation-concurrency-isolation/.openspec.yaml new file mode 100644 index 00000000..e5764a1d --- /dev/null +++ b/openspec/changes/image-generation-concurrency-isolation/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-05-03 diff --git a/openspec/changes/image-generation-concurrency-isolation/design.md b/openspec/changes/image-generation-concurrency-isolation/design.md new file mode 100644 index 00000000..182df519 --- /dev/null +++ b/openspec/changes/image-generation-concurrency-isolation/design.md @@ -0,0 +1,70 @@ +## Overview + +本次只实现“图片独立并发开关”,不实现外部图片网关的运行时代码。目标是在最大程度不改变现有行为的前提下,为图片流式长连接提供服务级资源保护。 + +## Current Constraints + +- 当前 Redis 并发槽位只有用户和账号维度,键语义是 `concurrency:user:*` 与 `concurrency:account:*`。 +- 图片接口和普通 Responses 在同一个 Go 服务内运行,共享进程、HTTP 上游连接池和账号调度。 +- Codex OAuth 路径会自动注入 `image_generation` tool;这个注入表示“模型具备工具能力”,不等价于当前请求一定会生图。 +- `/v1/responses` 在 handler 入口只能可靠识别显式图片意图:image 模型、请求体已有 image tool、或 tool_choice 明确选择 
image_generation。 +- 图片实际产物计数与计费仍以 service 层的最终输出解析为准。 + +## Decisions + +### 1. 默认关闭,保持兼容 + +新增配置: + +- `gateway.image_concurrency.enabled`,默认 `false`。 +- `gateway.image_concurrency.max_concurrent_requests`,默认 `0`,表示不限制。 +- `gateway.image_concurrency.overflow_mode`,默认 `reject`,可选 `reject` / `wait`。 +- `gateway.image_concurrency.wait_timeout_seconds`,默认 `30`,仅 `overflow_mode=wait` 生效。 +- `gateway.image_concurrency.max_waiting_requests`,默认 `100`,仅 `overflow_mode=wait` 生效,限制当前进程内图片等待队列。 + +只有当 `enabled=true` 且 `max_concurrent_requests>0` 时才启用图片独立并发限制。默认配置不改变任何现有流量行为。 + +### 2. 进程级信号量作为第一阶段隔离 + +本次使用进程内有界信号量做服务级图片并发限制。原因: + +- 不扩展现有 Redis `ConcurrencyCache` 接口,避免影响用户/账号并发的既有语义。 +- 不新增迁移,不改变分组已有字段。 +- 单实例部署可立即保护进程资源。 +- 多实例部署时该限制按实例生效;文档必须明确总图片并发约等于 `实例数 × max_concurrent_requests`。 + +### 3. 限制对象只包含明确图片意图 + +纳入限制: + +- `/v1/images/generations` +- `/v1/images/edits` +- `/v1/responses` 中入口请求已明确包含图片意图:image 模型、`tools[].type=image_generation`、`tool_choice` 明确选择 image_generation。 + +暂不纳入限制: + +- 普通 Codex 请求因为服务端自动注入 image tool 而具备生图能力,但入口请求本身未明确要求生图。 + +这样避免把普通编码请求错误算作图片并发。后续若要对“模型运行中动态调用 image tool”做更细粒度隔离,需要在工具调用实际发生时获得可阻塞的事件,当前代码没有这种入口级阻塞点。 + +### 4. 限流行为 + +- `overflow_mode=reject` 时,未开始流式响应直接返回 HTTP `429`,错误类型 `rate_limit_error`。 +- `overflow_mode=wait` 时,请求在当前进程内等待图片并发槽位,超过 `wait_timeout_seconds` 或超过 `max_waiting_requests` 后返回 HTTP `429`。 +- 已开始流式响应时,使用现有 `handleStreamingAwareError` 写 SSE 错误事件。 +- 图片并发限制命中或等待超时不触发账号 failover,不记录为上游账号失败。 +- `gateway.image_stream_data_interval_timeout` 是上游图片流数据空闲超时,不用于图片排队等待。 + +### 5. 
与外部图片网关的关系 + +本次不实现外部图片网关代码。外部网关方案沉淀到 `2ue` 文档: + +- 推荐由 Caddy/Nginx/API Gateway 按 `/v1/images/*` 分流。 +- `/v1/responses` 的图片 tool 请求不能仅靠 path 分流,必须在前置层读取 body 或保留主服务兜底。 +- 即使未来拆出图片网关,主网关仍保留图片 intent 检测、开关和计费兜底,避免直连或漏判绕过。 + +## Risks And Mitigations + +- 风险:进程级限制在多实例部署下不是全局严格限制。缓解:文档明确容量计算,后续可基于 Redis 扩展为集群级图片并发。 +- 风险:Codex 自动注入 image tool 后,普通编码请求未被图片限流。缓解:这是有意选择,避免误伤普通请求;实际输出图片仍按图片计费。 +- 风险:图片请求在账号槽位前被拒绝可能改变排队体验。缓解:仅当独立开关启用时生效,默认关闭;429 明确提示图片并发达到上限。 diff --git a/openspec/changes/image-generation-concurrency-isolation/proposal.md b/openspec/changes/image-generation-concurrency-isolation/proposal.md new file mode 100644 index 00000000..d27b3caf --- /dev/null +++ b/openspec/changes/image-generation-concurrency-isolation/proposal.md @@ -0,0 +1,28 @@ +## Why + +图片生成流式请求会比普通文本流式请求占用更长的连接、goroutine、HTTP 上游连接和账号/用户槽位。当前图片能力已经具备独立计费与更长流式超时,但仍缺少默认关闭的图片专属并发隔离开关,图片高并发时仍可能挤压普通文本流式接口。 + +## What Changes + +- 新增服务级图片独立并发开关,默认关闭,不改变现有已部署分组和普通文本请求行为。 +- 新增图片全局并发上限配置;开启后仅限制已明确是图片生成意图的请求。 +- 新增图片并发满载后的溢出策略配置:默认立即拒绝,也可配置等待槽位和等待超时。 +- 将图片并发限制覆盖 `/v1/images/generations`、`/v1/images/edits` 和 `/v1/responses` 显式图片生成请求。 +- 保留当前图片生成开关、图片计费、图片流式续读与超时语义。 +- 不在本次代码实现外部独立图片网关;只把外部网关拆分方案沉淀到本地文档。 + +## Capabilities + +### New Capabilities +- `image-generation-concurrency-isolation`: 图片生成请求的独立并发开关、并发上限、429 行为和外部网关落地建议。 + +### Modified Capabilities +- `image-stream-resilience`: 图片流式续读能力在独立并发开启时受到图片专属并发上限保护,但流式续读与计费契约不变。 + +## Impact + +- 影响 `backend/internal/config/config.go` 的 gateway 配置字段、默认值和校验。 +- 影响 `backend/internal/handler/openai_images.go` 与 `backend/internal/handler/openai_gateway_handler.go` 的图片请求入口限流。 +- 影响 `deploy/config.example.yaml` 的示例配置与说明。 +- 影响后端测试:配置默认值/校验、图片接口限流、Responses 显式 image tool 限流。 +- 新增或更新 `2ue` 本地分析文档,记录外部独立图片网关只作为后续部署方案,不在本次代码落地。 diff --git a/openspec/changes/image-generation-concurrency-isolation/specs/image-generation-concurrency-isolation/spec.md b/openspec/changes/image-generation-concurrency-isolation/specs/image-generation-concurrency-isolation/spec.md new file 
mode 100644 index 00000000..7b55e979 --- /dev/null +++ b/openspec/changes/image-generation-concurrency-isolation/specs/image-generation-concurrency-isolation/spec.md @@ -0,0 +1,82 @@ +# image-generation-concurrency-isolation Specification + +## ADDED Requirements + +### Requirement: Image concurrency isolation is opt-in + +The system SHALL keep image concurrency isolation disabled by default. + +#### Scenario: default config keeps existing behavior +- **GIVEN** the deployment does not set `gateway.image_concurrency.enabled` +- **WHEN** image generation requests are received +- **THEN** no new image-specific concurrency limit is applied +- **AND** existing user/account concurrency and billing behavior remains unchanged + +### Requirement: Dedicated image concurrency limit + +The system SHALL provide an opt-in service-level image concurrency limit controlled by gateway configuration. + +#### Scenario: explicit image endpoint is limited +- **GIVEN** `gateway.image_concurrency.enabled=true` +- **AND** `gateway.image_concurrency.max_concurrent_requests=1` +- **AND** one image generation request is already active +- **WHEN** another `/v1/images/generations` or `/v1/images/edits` request arrives +- **THEN** the second request is rejected with HTTP `429` +- **AND** the error type is `rate_limit_error` + +#### Scenario: explicit Responses image generation request is limited +- **GIVEN** `gateway.image_concurrency.enabled=true` +- **AND** `gateway.image_concurrency.max_concurrent_requests=1` +- **AND** `gateway.image_concurrency.overflow_mode=reject` +- **AND** one image generation request is already active +- **WHEN** a `/v1/responses` request explicitly contains `tools[].type=image_generation`, an image model, or `tool_choice` selecting `image_generation` +- **THEN** the request is rejected with HTTP `429` +- **AND** it is not retried through account failover + +#### Scenario: image request waits for a slot +- **GIVEN** `gateway.image_concurrency.enabled=true` +- **AND** 
`gateway.image_concurrency.max_concurrent_requests=1` +- **AND** `gateway.image_concurrency.overflow_mode=wait` +- **AND** `gateway.image_concurrency.wait_timeout_seconds` is greater than zero +- **AND** one image generation request is already active +- **WHEN** another explicit image generation request arrives +- **AND** the active image generation request releases its slot before the wait timeout +- **THEN** the waiting image generation request acquires the slot and continues + +#### Scenario: image wait times out +- **GIVEN** `gateway.image_concurrency.enabled=true` +- **AND** `gateway.image_concurrency.max_concurrent_requests=1` +- **AND** `gateway.image_concurrency.overflow_mode=wait` +- **AND** one image generation request is already active +- **WHEN** another explicit image generation request waits longer than `gateway.image_concurrency.wait_timeout_seconds` +- **THEN** the waiting request is rejected with HTTP `429` +- **AND** the error type is `rate_limit_error` + +#### Scenario: image waiting queue is full +- **GIVEN** `gateway.image_concurrency.enabled=true` +- **AND** `gateway.image_concurrency.overflow_mode=wait` +- **AND** `gateway.image_concurrency.max_waiting_requests` is already reached +- **WHEN** another explicit image generation request arrives +- **THEN** the request is rejected with HTTP `429` +- **AND** it does not wait for account scheduling + +### Requirement: Text requests are not image-limited + +The system SHALL NOT apply the image concurrency limit to requests without explicit image generation intent. 
+ +#### Scenario: normal coding request bypasses image limit +- **GIVEN** `gateway.image_concurrency.enabled=true` +- **AND** the image concurrency limit is full +- **WHEN** a `/v1/responses` request uses a text model and does not explicitly contain image generation intent +- **THEN** the image concurrency limiter does not reject it +- **AND** normal user/account concurrency handling continues + +### Requirement: External image gateway remains a deployment pattern + +The system SHALL document external image gateway routing as a deployment option without adding runtime forwarding code in this change. + +#### Scenario: operator reads local design note +- **GIVEN** the repository documentation is available +- **WHEN** an operator evaluates isolating image traffic into a separate service +- **THEN** local `2ue` documentation describes which paths are safe to route by path +- **AND** explains why `/v1/responses` image tool requests require body-aware routing or main-gateway fallback diff --git a/openspec/changes/image-generation-concurrency-isolation/tasks.md b/openspec/changes/image-generation-concurrency-isolation/tasks.md new file mode 100644 index 00000000..697c2853 --- /dev/null +++ b/openspec/changes/image-generation-concurrency-isolation/tasks.md @@ -0,0 +1,28 @@ +## 1. Spec and documentation + +- [x] 1.1 Create OpenSpec proposal, design, tasks, and capability spec for image concurrency isolation. +- [x] 1.2 Add a local `2ue` note for the external image gateway deployment pattern and current non-goals. + +## 2. Config + +- [x] 2.1 Add `gateway.image_concurrency.enabled` and `gateway.image_concurrency.max_concurrent_requests` config fields. +- [x] 2.2 Register defaults that keep existing behavior unchanged. +- [x] 2.3 Validate max concurrent requests as non-negative. +- [x] 2.4 Update `deploy/config.example.yaml` with safe usage notes. +- [x] 2.5 Add image concurrency overflow mode, wait timeout, and max waiting request config. + +## 3. 
Runtime limiter + +- [x] 3.1 Implement a process-level image concurrency limiter with resize-on-config-read behavior. +- [x] 3.2 Acquire/release the limiter around `/v1/images/generations` and `/v1/images/edits` before account scheduling. +- [x] 3.3 Acquire/release the limiter around explicit `/v1/responses` image generation intent before account scheduling. +- [x] 3.4 Ensure limiter rejections return `429 rate_limit_error` and do not trigger account failover. +- [x] 3.5 Support `reject` and `wait` overflow modes with bounded wait timeout and waiting queue size. + +## 4. Tests and verification + +- [x] 4.1 Add config default and validation tests. +- [x] 4.2 Add handler tests for image endpoint limiter rejection. +- [x] 4.3 Add handler tests proving text-only Responses requests are not rejected by the image limiter. +- [x] 4.4 Run focused Go tests for config and OpenAI handler/service paths. +- [x] 4.5 Add limiter tests for wait success, wait timeout, and waiting queue overflow. diff --git a/openspec/changes/image-stream-resilience/.openspec.yaml b/openspec/changes/image-stream-resilience/.openspec.yaml new file mode 100644 index 00000000..2988acfa --- /dev/null +++ b/openspec/changes/image-stream-resilience/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-05-02 diff --git a/openspec/changes/image-stream-resilience/design.md b/openspec/changes/image-stream-resilience/design.md new file mode 100644 index 00000000..4ad05512 --- /dev/null +++ b/openspec/changes/image-stream-resilience/design.md @@ -0,0 +1,46 @@ +## Context + +现有普通 Responses 流式已经具备较完整的断连续读能力:客户端写失败后可继续 drain 上游,并且具备数据间隔超时和 keepalive。图片流式路径目前仍然采用更直接的读写方式,客户端写失败会立即返回,上游读取也更容易跟随客户端取消而结束。 + +本次变更只针对图片流式路径,不改变普通文本流式路径的配置和行为。系统已经存在普通流式的后端超时配置,因此这里不引入页面级超时设置;图片流式只需要独立的后端默认值,让图片生成有更长的容忍窗口。 + +## Goals / Non-Goals + +**Goals:** +- 图片流式在客户端断开后继续读取上游,尽量保留最终图片结果与计费结果。 +- 图片流式使用独立于普通流式的超时与 keepalive 默认值。 +- 不修改现有普通流式配置项的含义,不要求管理员新增页面配置。 +- 维持图片计费与图片结果计数的一致性。 + +**Non-Goals:** +- 不设计新的前端配置页面。 +- 
不修改普通文本流式的超时策略。 +- 不改变图片计费公式或分组倍率语义。 + +## Decisions + +1. **使用独立的图片流式配置键** + - 选择:在后端配置中增加图片流式专用 `image_stream_data_interval_timeout` / `image_stream_keepalive_interval`。 + - 原因:图片流式耗时显著更长,复用普通流式默认值会过早触发超时;独立键能避免影响现有文本流式。 + - 备选方案:直接复用普通流式配置并在代码里按路径放大倍数。这个方案会让普通流式和图片流式共享语义,后续难以维护。 + +2. **继续使用上下文 detach,而不是依赖客户端上下文** + - 选择:图片流式请求向上游发起时使用 `context.WithoutCancel` 派生的上下文。 + - 原因:客户端断开时不应自动取消上游请求,否则无法收集最终图片结果,也无法完成图片计费。 + - 备选方案:仍使用 `c.Request.Context()` 并只在写失败后继续 drain。这个方案在客户端取消场景下无法保证上游读取继续进行。 + +3. **只改图片流式路径,不改普通流式路径** + - 选择:`/v1/images/*` 与 `Responses + image_generation` 两条图片流式链路单独处理。 + - 原因:风险最小,避免回归普通文本流式和现有超时配置。 + - 备选方案:统一重构所有流式处理。这个方案范围更大,验证成本更高,不符合本次“尽量少改现有行为”的目标。 + +4. **不新增页面配置** + - 选择:图片流式独立超时默认值写入后端配置,沿用当前配置加载方式。 + - 原因:用户明确要求和当前设置行为统一,不需要额外页面输入项。 + - 备选方案:前端增加图片超时配置项。这个方案会改变现有运维方式,也容易引入误配。 + +## Risks / Trade-offs + +- [Risk] 图片流式继续 drain 上游后,客户端已经断开但服务端仍占用连接与协程资源。→ [Mitigation] 只对图片流式启用更长但仍有限的专用超时,并保持与普通流式同样的 keepalive/超时退出机制。 +- [Risk] 图片流式与普通流式的默认超时不同,运维如果只关注通用配置可能忽略图片专用值。→ [Mitigation] 在配置示例中明确标注图片流式专用默认值和用途。 +- [Risk] 断连后继续读取可能导致日志中出现“客户端断开但最终成功”的状态。→ [Mitigation] 保留现有图片计费结果返回语义,同时让调用方在结果与错误并存时优先使用结果对象。 diff --git a/openspec/changes/image-stream-resilience/proposal.md b/openspec/changes/image-stream-resilience/proposal.md new file mode 100644 index 00000000..2d182c54 --- /dev/null +++ b/openspec/changes/image-stream-resilience/proposal.md @@ -0,0 +1,25 @@ +## Why + +图片流式路径目前没有和普通 Responses 流式一致的断连续读策略,也没有独立于普通流式的超时控制。由于图片生成耗时更长,如果继续沿用普通流式处理方式,客户端断开时容易中断上游读取,影响图片产物收集与按图计费的准确性。 + +## What Changes + +- 为 OpenAI Images API 和 `Responses + image_generation` 流式路径补充独立的上游续读策略,客户端断开后继续 drain 上游,尽量保留最终图片结果和计费结果。 +- 为图片流式路径使用独立的流数据间隔超时与 keepalive 策略,默认比普通流式更长,不新增页面配置项。 +- 保持现有普通流式配置与行为不变,避免影响已经配置好的普通文本分组。 +- 让图片流式路径在超时、断连、写入失败等场景下保持图片计费语义一致。 + +## Capabilities + +### New Capabilities +- `image-stream-resilience`: 图片流式路径的断连续读、独立超时和计费保留能力。 + +### Modified Capabilities +- `image-generation-billing-accounting`: 图片流式结果计数和计费结果的稳定性行为发生改变,但计费契约不变。 + +## Impact +
+- 影响 `backend/internal/service/openai_images.go` 和 `backend/internal/service/openai_images_responses.go` 的流式实现。 +- 影响 `backend/internal/config/config.go` 与 `deploy/config.example.yaml` 中图片流式默认值和校验逻辑。 +- 影响 `backend/internal/service/openai_images_test.go`、`backend/internal/config/config_test.go` 以及新增的图片流式稳定性测试。 +- 不新增前端页面设置,不改变普通流式配置项名称和语义。 diff --git a/openspec/changes/image-stream-resilience/specs/image-stream-resilience/spec.md b/openspec/changes/image-stream-resilience/specs/image-stream-resilience/spec.md new file mode 100644 index 00000000..9407ba26 --- /dev/null +++ b/openspec/changes/image-stream-resilience/specs/image-stream-resilience/spec.md @@ -0,0 +1,53 @@ +## ADDED Requirements + +### Requirement: Image stream resilience +The system SHALL keep image generation stream processing active after downstream client disconnects so long as upstream reading can continue, in order to preserve final image outputs and billing results. + +#### Scenario: Images API stream survives downstream disconnect +- **WHEN** `/v1/images/generations` is streamed to a client +- **AND** the downstream writer returns an error before the upstream stream completes +- **THEN** the service continues draining the upstream stream +- **AND** it still counts final image outputs if the upstream later emits them +- **AND** the request can still complete with image billing metadata + +#### Scenario: Responses image tool stream survives downstream disconnect +- **WHEN** a `/v1/responses` request uses `image_generation` and is streamed to a client +- **AND** the downstream writer returns an error before the upstream stream completes +- **THEN** the service continues draining the upstream stream +- **AND** it still counts final image outputs if the upstream later emits them +- **AND** the request can still complete with image billing metadata + +#### Scenario: Client disconnect does not force image stream to downgrade to text billing +- **WHEN** a successful image stream request has already 
produced final image outputs +- **AND** the downstream client disconnects before the final flush +- **THEN** the request remains billed as an image request +- **AND** the image count is preserved in the forward result + +### Requirement: Image stream timeout isolation +The system SHALL use image-specific streaming timeout settings for image generation stream paths, and these settings SHALL be independent from the ordinary text streaming timeout values. + +#### Scenario: Image stream uses dedicated timeout defaults +- **WHEN** an image generation stream path is executed +- **THEN** it uses the image-specific data interval timeout and keepalive interval defaults +- **AND** it does not rely on the ordinary text stream timeout defaults + +#### Scenario: Ordinary stream settings remain unchanged +- **WHEN** a normal non-image streaming request is executed +- **THEN** the existing ordinary stream timeout configuration and behavior remain unchanged + +#### Scenario: Image stream timeout is longer than ordinary stream timeout +- **WHEN** the image streaming timeout defaults are compared with the ordinary streaming defaults +- **THEN** the image streaming timeout is configured to allow a longer wait window than ordinary text streaming + +### Requirement: Image stream billing consistency +The system SHALL keep the image billing result consistent even when image stream handling uses retries, keepalive writes, or downstream disconnect recovery. 
+ +#### Scenario: Final image count is preserved after reconnect-unsafe downstream failure +- **WHEN** the downstream client disconnects after at least one final image output has been observed upstream +- **THEN** the forward result retains the final image count +- **AND** usage recording can still proceed with image billing metadata + +#### Scenario: Image stream timeout does not silently switch billing mode +- **WHEN** an image stream times out before any final image output is observed +- **THEN** the request is handled as a failed image stream +- **AND** it does not fall back to ordinary text billing semantics diff --git a/openspec/changes/image-stream-resilience/tasks.md b/openspec/changes/image-stream-resilience/tasks.md new file mode 100644 index 00000000..3f401f33 --- /dev/null +++ b/openspec/changes/image-stream-resilience/tasks.md @@ -0,0 +1,20 @@ +## 1. Config and defaults + +- [x] 1.1 Add image-specific stream timeout fields to gateway config. +- [x] 1.2 Register image stream timeout defaults in the config loader. +- [x] 1.3 Add config validation for image stream timeout ranges. +- [x] 1.4 Expose image stream timeout defaults in `deploy/config.example.yaml`. + +## 2. Image stream runtime behavior + +- [x] 2.1 Detach image stream upstream contexts from client cancellation. +- [x] 2.2 Add image-specific data interval timeout handling to `/v1/images/*` streaming. +- [x] 2.3 Add image-specific data interval timeout handling to `Responses + image_generation` streaming. +- [x] 2.4 Preserve upstream draining after downstream write failures in both image stream paths. + +## 3. Tests and verification + +- [x] 3.1 Add config tests for image stream timeout defaults and validation. +- [x] 3.2 Add image streaming disconnect tests for the Images API path. +- [x] 3.3 Add image streaming disconnect tests for the Responses image tool path. +- [x] 3.4 Run focused Go tests for the touched config and image service paths. 
diff --git a/openspec/config.yaml b/openspec/config.yaml new file mode 100644 index 00000000..392946c6 --- /dev/null +++ b/openspec/config.yaml @@ -0,0 +1,20 @@ +schema: spec-driven + +# Project context (optional) +# This is shown to AI when creating artifacts. +# Add your tech stack, conventions, style guides, domain knowledge, etc. +# Example: +# context: | +# Tech stack: TypeScript, React, Node.js +# We use conventional commits +# Domain: e-commerce platform + +# Per-artifact rules (optional) +# Add custom rules for specific artifacts. +# Example: +# rules: +# proposal: +# - Keep proposals under 500 words +# - Always include a "Non-goals" section +# tasks: +# - Break tasks into chunks of max 2 hours