diff --git a/backend/ent/group.go b/backend/ent/group.go
index 5d9ae2ed..a4f52c73 100644
--- a/backend/ent/group.go
+++ b/backend/ent/group.go
@@ -47,6 +47,12 @@ type Group struct {
MonthlyLimitUsd *float64 `json:"monthly_limit_usd,omitempty"`
// DefaultValidityDays holds the value of the "default_validity_days" field.
DefaultValidityDays int `json:"default_validity_days,omitempty"`
+ // AllowImageGeneration reports whether this group is allowed to use image generation.
+ AllowImageGeneration bool `json:"allow_image_generation,omitempty"`
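+ // ImageRateIndependent reports whether image generation uses an independent rate multiplier; false means it shares the group's effective rate multiplier.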
+ ImageRateIndependent bool `json:"image_rate_independent,omitempty"`
+ // ImageRateMultiplier is the independent rate multiplier for image generation; it only takes effect when image_rate_independent=true.
+ ImageRateMultiplier float64 `json:"image_rate_multiplier,omitempty"`
// ImagePrice1k holds the value of the "image_price_1k" field.
ImagePrice1k *float64 `json:"image_price_1k,omitempty"`
// ImagePrice2k holds the value of the "image_price_2k" field.
@@ -189,9 +195,9 @@ func (*Group) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig:
values[i] = new([]byte)
- case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
+ case group.FieldIsExclusive, group.FieldAllowImageGeneration, group.FieldImageRateIndependent, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
values[i] = new(sql.NullBool)
- case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
+ case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImageRateMultiplier, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64)
case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit:
values[i] = new(sql.NullInt64)
@@ -309,6 +315,24 @@ func (_m *Group) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.DefaultValidityDays = int(value.Int64)
}
+ case group.FieldAllowImageGeneration:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field allow_image_generation", values[i])
+ } else if value.Valid {
+ _m.AllowImageGeneration = value.Bool
+ }
+ case group.FieldImageRateIndependent:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field image_rate_independent", values[i])
+ } else if value.Valid {
+ _m.ImageRateIndependent = value.Bool
+ }
+ case group.FieldImageRateMultiplier:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field image_rate_multiplier", values[i])
+ } else if value.Valid {
+ _m.ImageRateMultiplier = value.Float64
+ }
case group.FieldImagePrice1k:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field image_price_1k", values[i])
@@ -550,6 +574,15 @@ func (_m *Group) String() string {
builder.WriteString("default_validity_days=")
builder.WriteString(fmt.Sprintf("%v", _m.DefaultValidityDays))
builder.WriteString(", ")
+ builder.WriteString("allow_image_generation=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AllowImageGeneration))
+ builder.WriteString(", ")
+ builder.WriteString("image_rate_independent=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ImageRateIndependent))
+ builder.WriteString(", ")
+ builder.WriteString("image_rate_multiplier=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ImageRateMultiplier))
+ builder.WriteString(", ")
if v := _m.ImagePrice1k; v != nil {
builder.WriteString("image_price_1k=")
builder.WriteString(fmt.Sprintf("%v", *v))
diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go
index 24bd9c13..4e9ba6b6 100644
--- a/backend/ent/group/group.go
+++ b/backend/ent/group/group.go
@@ -44,6 +44,12 @@ const (
FieldMonthlyLimitUsd = "monthly_limit_usd"
// FieldDefaultValidityDays holds the string denoting the default_validity_days field in the database.
FieldDefaultValidityDays = "default_validity_days"
+ // FieldAllowImageGeneration holds the string denoting the allow_image_generation field in the database.
+ FieldAllowImageGeneration = "allow_image_generation"
+ // FieldImageRateIndependent holds the string denoting the image_rate_independent field in the database.
+ FieldImageRateIndependent = "image_rate_independent"
+ // FieldImageRateMultiplier holds the string denoting the image_rate_multiplier field in the database.
+ FieldImageRateMultiplier = "image_rate_multiplier"
// FieldImagePrice1k holds the string denoting the image_price_1k field in the database.
FieldImagePrice1k = "image_price_1k"
// FieldImagePrice2k holds the string denoting the image_price_2k field in the database.
@@ -167,6 +173,9 @@ var Columns = []string{
FieldWeeklyLimitUsd,
FieldMonthlyLimitUsd,
FieldDefaultValidityDays,
+ FieldAllowImageGeneration,
+ FieldImageRateIndependent,
+ FieldImageRateMultiplier,
FieldImagePrice1k,
FieldImagePrice2k,
FieldImagePrice4k,
@@ -239,6 +248,12 @@ var (
SubscriptionTypeValidator func(string) error
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
DefaultDefaultValidityDays int
+ // DefaultAllowImageGeneration holds the default value on creation for the "allow_image_generation" field.
+ DefaultAllowImageGeneration bool
+ // DefaultImageRateIndependent holds the default value on creation for the "image_rate_independent" field.
+ DefaultImageRateIndependent bool
+ // DefaultImageRateMultiplier holds the default value on creation for the "image_rate_multiplier" field.
+ DefaultImageRateMultiplier float64
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
DefaultClaudeCodeOnly bool
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
@@ -343,6 +358,21 @@ func ByDefaultValidityDays(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDefaultValidityDays, opts...).ToFunc()
}
+// ByAllowImageGeneration orders the results by the allow_image_generation field.
+func ByAllowImageGeneration(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAllowImageGeneration, opts...).ToFunc()
+}
+
+// ByImageRateIndependent orders the results by the image_rate_independent field.
+func ByImageRateIndependent(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldImageRateIndependent, opts...).ToFunc()
+}
+
+// ByImageRateMultiplier orders the results by the image_rate_multiplier field.
+func ByImageRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldImageRateMultiplier, opts...).ToFunc()
+}
+
// ByImagePrice1k orders the results by the image_price_1k field.
func ByImagePrice1k(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImagePrice1k, opts...).ToFunc()
diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go
index 2814d130..d3223a92 100644
--- a/backend/ent/group/where.go
+++ b/backend/ent/group/where.go
@@ -125,6 +125,21 @@ func DefaultValidityDays(v int) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldDefaultValidityDays, v))
}
+// AllowImageGeneration applies equality check predicate on the "allow_image_generation" field. It's identical to AllowImageGenerationEQ.
+func AllowImageGeneration(v bool) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldAllowImageGeneration, v))
+}
+
+// ImageRateIndependent applies equality check predicate on the "image_rate_independent" field. It's identical to ImageRateIndependentEQ.
+func ImageRateIndependent(v bool) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldImageRateIndependent, v))
+}
+
+// ImageRateMultiplier applies equality check predicate on the "image_rate_multiplier" field. It's identical to ImageRateMultiplierEQ.
+func ImageRateMultiplier(v float64) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldImageRateMultiplier, v))
+}
+
// ImagePrice1k applies equality check predicate on the "image_price_1k" field. It's identical to ImagePrice1kEQ.
func ImagePrice1k(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v))
@@ -900,6 +915,66 @@ func DefaultValidityDaysLTE(v int) predicate.Group {
return predicate.Group(sql.FieldLTE(FieldDefaultValidityDays, v))
}
+// AllowImageGenerationEQ applies the EQ predicate on the "allow_image_generation" field.
+func AllowImageGenerationEQ(v bool) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldAllowImageGeneration, v))
+}
+
+// AllowImageGenerationNEQ applies the NEQ predicate on the "allow_image_generation" field.
+func AllowImageGenerationNEQ(v bool) predicate.Group {
+ return predicate.Group(sql.FieldNEQ(FieldAllowImageGeneration, v))
+}
+
+// ImageRateIndependentEQ applies the EQ predicate on the "image_rate_independent" field.
+func ImageRateIndependentEQ(v bool) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldImageRateIndependent, v))
+}
+
+// ImageRateIndependentNEQ applies the NEQ predicate on the "image_rate_independent" field.
+func ImageRateIndependentNEQ(v bool) predicate.Group {
+ return predicate.Group(sql.FieldNEQ(FieldImageRateIndependent, v))
+}
+
+// ImageRateMultiplierEQ applies the EQ predicate on the "image_rate_multiplier" field.
+func ImageRateMultiplierEQ(v float64) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldImageRateMultiplier, v))
+}
+
+// ImageRateMultiplierNEQ applies the NEQ predicate on the "image_rate_multiplier" field.
+func ImageRateMultiplierNEQ(v float64) predicate.Group {
+ return predicate.Group(sql.FieldNEQ(FieldImageRateMultiplier, v))
+}
+
+// ImageRateMultiplierIn applies the In predicate on the "image_rate_multiplier" field.
+func ImageRateMultiplierIn(vs ...float64) predicate.Group {
+ return predicate.Group(sql.FieldIn(FieldImageRateMultiplier, vs...))
+}
+
+// ImageRateMultiplierNotIn applies the NotIn predicate on the "image_rate_multiplier" field.
+func ImageRateMultiplierNotIn(vs ...float64) predicate.Group {
+ return predicate.Group(sql.FieldNotIn(FieldImageRateMultiplier, vs...))
+}
+
+// ImageRateMultiplierGT applies the GT predicate on the "image_rate_multiplier" field.
+func ImageRateMultiplierGT(v float64) predicate.Group {
+ return predicate.Group(sql.FieldGT(FieldImageRateMultiplier, v))
+}
+
+// ImageRateMultiplierGTE applies the GTE predicate on the "image_rate_multiplier" field.
+func ImageRateMultiplierGTE(v float64) predicate.Group {
+ return predicate.Group(sql.FieldGTE(FieldImageRateMultiplier, v))
+}
+
+// ImageRateMultiplierLT applies the LT predicate on the "image_rate_multiplier" field.
+func ImageRateMultiplierLT(v float64) predicate.Group {
+ return predicate.Group(sql.FieldLT(FieldImageRateMultiplier, v))
+}
+
+// ImageRateMultiplierLTE applies the LTE predicate on the "image_rate_multiplier" field.
+func ImageRateMultiplierLTE(v float64) predicate.Group {
+ return predicate.Group(sql.FieldLTE(FieldImageRateMultiplier, v))
+}
+
// ImagePrice1kEQ applies the EQ predicate on the "image_price_1k" field.
func ImagePrice1kEQ(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice1k, v))
diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go
index 20ea0a0f..44b905bd 100644
--- a/backend/ent/group_create.go
+++ b/backend/ent/group_create.go
@@ -217,6 +217,48 @@ func (_c *GroupCreate) SetNillableDefaultValidityDays(v *int) *GroupCreate {
return _c
}
+// SetAllowImageGeneration sets the "allow_image_generation" field.
+func (_c *GroupCreate) SetAllowImageGeneration(v bool) *GroupCreate {
+ _c.mutation.SetAllowImageGeneration(v)
+ return _c
+}
+
+// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil.
+func (_c *GroupCreate) SetNillableAllowImageGeneration(v *bool) *GroupCreate {
+ if v != nil {
+ _c.SetAllowImageGeneration(*v)
+ }
+ return _c
+}
+
+// SetImageRateIndependent sets the "image_rate_independent" field.
+func (_c *GroupCreate) SetImageRateIndependent(v bool) *GroupCreate {
+ _c.mutation.SetImageRateIndependent(v)
+ return _c
+}
+
+// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil.
+func (_c *GroupCreate) SetNillableImageRateIndependent(v *bool) *GroupCreate {
+ if v != nil {
+ _c.SetImageRateIndependent(*v)
+ }
+ return _c
+}
+
+// SetImageRateMultiplier sets the "image_rate_multiplier" field.
+func (_c *GroupCreate) SetImageRateMultiplier(v float64) *GroupCreate {
+ _c.mutation.SetImageRateMultiplier(v)
+ return _c
+}
+
+// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil.
+func (_c *GroupCreate) SetNillableImageRateMultiplier(v *float64) *GroupCreate {
+ if v != nil {
+ _c.SetImageRateMultiplier(*v)
+ }
+ return _c
+}
+
// SetImagePrice1k sets the "image_price_1k" field.
func (_c *GroupCreate) SetImagePrice1k(v float64) *GroupCreate {
_c.mutation.SetImagePrice1k(v)
@@ -604,6 +646,18 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultValidityDays
_c.mutation.SetDefaultValidityDays(v)
}
+ if _, ok := _c.mutation.AllowImageGeneration(); !ok {
+ v := group.DefaultAllowImageGeneration
+ _c.mutation.SetAllowImageGeneration(v)
+ }
+ if _, ok := _c.mutation.ImageRateIndependent(); !ok {
+ v := group.DefaultImageRateIndependent
+ _c.mutation.SetImageRateIndependent(v)
+ }
+ if _, ok := _c.mutation.ImageRateMultiplier(); !ok {
+ v := group.DefaultImageRateMultiplier
+ _c.mutation.SetImageRateMultiplier(v)
+ }
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
v := group.DefaultClaudeCodeOnly
_c.mutation.SetClaudeCodeOnly(v)
@@ -700,6 +754,15 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
}
+ if _, ok := _c.mutation.AllowImageGeneration(); !ok {
+ return &ValidationError{Name: "allow_image_generation", err: errors.New(`ent: missing required field "Group.allow_image_generation"`)}
+ }
+ if _, ok := _c.mutation.ImageRateIndependent(); !ok {
+ return &ValidationError{Name: "image_rate_independent", err: errors.New(`ent: missing required field "Group.image_rate_independent"`)}
+ }
+ if _, ok := _c.mutation.ImageRateMultiplier(); !ok {
+ return &ValidationError{Name: "image_rate_multiplier", err: errors.New(`ent: missing required field "Group.image_rate_multiplier"`)}
+ }
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
}
@@ -821,6 +884,18 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldDefaultValidityDays, field.TypeInt, value)
_node.DefaultValidityDays = value
}
+ if value, ok := _c.mutation.AllowImageGeneration(); ok {
+ _spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
+ _node.AllowImageGeneration = value
+ }
+ if value, ok := _c.mutation.ImageRateIndependent(); ok {
+ _spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
+ _node.ImageRateIndependent = value
+ }
+ if value, ok := _c.mutation.ImageRateMultiplier(); ok {
+ _spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
+ _node.ImageRateMultiplier = value
+ }
if value, ok := _c.mutation.ImagePrice1k(); ok {
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
_node.ImagePrice1k = &value
@@ -1261,6 +1336,48 @@ func (u *GroupUpsert) AddDefaultValidityDays(v int) *GroupUpsert {
return u
}
+// SetAllowImageGeneration sets the "allow_image_generation" field.
+func (u *GroupUpsert) SetAllowImageGeneration(v bool) *GroupUpsert {
+ u.Set(group.FieldAllowImageGeneration, v)
+ return u
+}
+
+// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create.
+func (u *GroupUpsert) UpdateAllowImageGeneration() *GroupUpsert {
+ u.SetExcluded(group.FieldAllowImageGeneration)
+ return u
+}
+
+// SetImageRateIndependent sets the "image_rate_independent" field.
+func (u *GroupUpsert) SetImageRateIndependent(v bool) *GroupUpsert {
+ u.Set(group.FieldImageRateIndependent, v)
+ return u
+}
+
+// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create.
+func (u *GroupUpsert) UpdateImageRateIndependent() *GroupUpsert {
+ u.SetExcluded(group.FieldImageRateIndependent)
+ return u
+}
+
+// SetImageRateMultiplier sets the "image_rate_multiplier" field.
+func (u *GroupUpsert) SetImageRateMultiplier(v float64) *GroupUpsert {
+ u.Set(group.FieldImageRateMultiplier, v)
+ return u
+}
+
+// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create.
+func (u *GroupUpsert) UpdateImageRateMultiplier() *GroupUpsert {
+ u.SetExcluded(group.FieldImageRateMultiplier)
+ return u
+}
+
+// AddImageRateMultiplier adds v to the "image_rate_multiplier" field.
+func (u *GroupUpsert) AddImageRateMultiplier(v float64) *GroupUpsert {
+ u.Add(group.FieldImageRateMultiplier, v)
+ return u
+}
+
// SetImagePrice1k sets the "image_price_1k" field.
func (u *GroupUpsert) SetImagePrice1k(v float64) *GroupUpsert {
u.Set(group.FieldImagePrice1k, v)
@@ -1840,6 +1957,55 @@ func (u *GroupUpsertOne) UpdateDefaultValidityDays() *GroupUpsertOne {
})
}
+// SetAllowImageGeneration sets the "allow_image_generation" field.
+func (u *GroupUpsertOne) SetAllowImageGeneration(v bool) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetAllowImageGeneration(v)
+ })
+}
+
+// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create.
+func (u *GroupUpsertOne) UpdateAllowImageGeneration() *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateAllowImageGeneration()
+ })
+}
+
+// SetImageRateIndependent sets the "image_rate_independent" field.
+func (u *GroupUpsertOne) SetImageRateIndependent(v bool) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetImageRateIndependent(v)
+ })
+}
+
+// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create.
+func (u *GroupUpsertOne) UpdateImageRateIndependent() *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateImageRateIndependent()
+ })
+}
+
+// SetImageRateMultiplier sets the "image_rate_multiplier" field.
+func (u *GroupUpsertOne) SetImageRateMultiplier(v float64) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetImageRateMultiplier(v)
+ })
+}
+
+// AddImageRateMultiplier adds v to the "image_rate_multiplier" field.
+func (u *GroupUpsertOne) AddImageRateMultiplier(v float64) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.AddImageRateMultiplier(v)
+ })
+}
+
+// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create.
+func (u *GroupUpsertOne) UpdateImageRateMultiplier() *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateImageRateMultiplier()
+ })
+}
+
// SetImagePrice1k sets the "image_price_1k" field.
func (u *GroupUpsertOne) SetImagePrice1k(v float64) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
@@ -2632,6 +2798,55 @@ func (u *GroupUpsertBulk) UpdateDefaultValidityDays() *GroupUpsertBulk {
})
}
+// SetAllowImageGeneration sets the "allow_image_generation" field.
+func (u *GroupUpsertBulk) SetAllowImageGeneration(v bool) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetAllowImageGeneration(v)
+ })
+}
+
+// UpdateAllowImageGeneration sets the "allow_image_generation" field to the value that was provided on create.
+func (u *GroupUpsertBulk) UpdateAllowImageGeneration() *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateAllowImageGeneration()
+ })
+}
+
+// SetImageRateIndependent sets the "image_rate_independent" field.
+func (u *GroupUpsertBulk) SetImageRateIndependent(v bool) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetImageRateIndependent(v)
+ })
+}
+
+// UpdateImageRateIndependent sets the "image_rate_independent" field to the value that was provided on create.
+func (u *GroupUpsertBulk) UpdateImageRateIndependent() *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateImageRateIndependent()
+ })
+}
+
+// SetImageRateMultiplier sets the "image_rate_multiplier" field.
+func (u *GroupUpsertBulk) SetImageRateMultiplier(v float64) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetImageRateMultiplier(v)
+ })
+}
+
+// AddImageRateMultiplier adds v to the "image_rate_multiplier" field.
+func (u *GroupUpsertBulk) AddImageRateMultiplier(v float64) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.AddImageRateMultiplier(v)
+ })
+}
+
+// UpdateImageRateMultiplier sets the "image_rate_multiplier" field to the value that was provided on create.
+func (u *GroupUpsertBulk) UpdateImageRateMultiplier() *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateImageRateMultiplier()
+ })
+}
+
// SetImagePrice1k sets the "image_price_1k" field.
func (u *GroupUpsertBulk) SetImagePrice1k(v float64) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go
index cc14f897..fe55982c 100644
--- a/backend/ent/group_update.go
+++ b/backend/ent/group_update.go
@@ -275,6 +275,55 @@ func (_u *GroupUpdate) AddDefaultValidityDays(v int) *GroupUpdate {
return _u
}
+// SetAllowImageGeneration sets the "allow_image_generation" field.
+func (_u *GroupUpdate) SetAllowImageGeneration(v bool) *GroupUpdate {
+ _u.mutation.SetAllowImageGeneration(v)
+ return _u
+}
+
+// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil.
+func (_u *GroupUpdate) SetNillableAllowImageGeneration(v *bool) *GroupUpdate {
+ if v != nil {
+ _u.SetAllowImageGeneration(*v)
+ }
+ return _u
+}
+
+// SetImageRateIndependent sets the "image_rate_independent" field.
+func (_u *GroupUpdate) SetImageRateIndependent(v bool) *GroupUpdate {
+ _u.mutation.SetImageRateIndependent(v)
+ return _u
+}
+
+// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil.
+func (_u *GroupUpdate) SetNillableImageRateIndependent(v *bool) *GroupUpdate {
+ if v != nil {
+ _u.SetImageRateIndependent(*v)
+ }
+ return _u
+}
+
+// SetImageRateMultiplier sets the "image_rate_multiplier" field.
+func (_u *GroupUpdate) SetImageRateMultiplier(v float64) *GroupUpdate {
+ _u.mutation.ResetImageRateMultiplier()
+ _u.mutation.SetImageRateMultiplier(v)
+ return _u
+}
+
+// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil.
+func (_u *GroupUpdate) SetNillableImageRateMultiplier(v *float64) *GroupUpdate {
+ if v != nil {
+ _u.SetImageRateMultiplier(*v)
+ }
+ return _u
+}
+
+// AddImageRateMultiplier adds value to the "image_rate_multiplier" field.
+func (_u *GroupUpdate) AddImageRateMultiplier(v float64) *GroupUpdate {
+ _u.mutation.AddImageRateMultiplier(v)
+ return _u
+}
+
// SetImagePrice1k sets the "image_price_1k" field.
func (_u *GroupUpdate) SetImagePrice1k(v float64) *GroupUpdate {
_u.mutation.ResetImagePrice1k()
@@ -962,6 +1011,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
}
+ if value, ok := _u.mutation.AllowImageGeneration(); ok {
+ _spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.ImageRateIndependent(); ok {
+ _spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.ImageRateMultiplier(); ok {
+ _spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedImageRateMultiplier(); ok {
+ _spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
+ }
if value, ok := _u.mutation.ImagePrice1k(); ok {
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
}
@@ -1610,6 +1671,55 @@ func (_u *GroupUpdateOne) AddDefaultValidityDays(v int) *GroupUpdateOne {
return _u
}
+// SetAllowImageGeneration sets the "allow_image_generation" field.
+func (_u *GroupUpdateOne) SetAllowImageGeneration(v bool) *GroupUpdateOne {
+ _u.mutation.SetAllowImageGeneration(v)
+ return _u
+}
+
+// SetNillableAllowImageGeneration sets the "allow_image_generation" field if the given value is not nil.
+func (_u *GroupUpdateOne) SetNillableAllowImageGeneration(v *bool) *GroupUpdateOne {
+ if v != nil {
+ _u.SetAllowImageGeneration(*v)
+ }
+ return _u
+}
+
+// SetImageRateIndependent sets the "image_rate_independent" field.
+func (_u *GroupUpdateOne) SetImageRateIndependent(v bool) *GroupUpdateOne {
+ _u.mutation.SetImageRateIndependent(v)
+ return _u
+}
+
+// SetNillableImageRateIndependent sets the "image_rate_independent" field if the given value is not nil.
+func (_u *GroupUpdateOne) SetNillableImageRateIndependent(v *bool) *GroupUpdateOne {
+ if v != nil {
+ _u.SetImageRateIndependent(*v)
+ }
+ return _u
+}
+
+// SetImageRateMultiplier sets the "image_rate_multiplier" field.
+func (_u *GroupUpdateOne) SetImageRateMultiplier(v float64) *GroupUpdateOne {
+ _u.mutation.ResetImageRateMultiplier()
+ _u.mutation.SetImageRateMultiplier(v)
+ return _u
+}
+
+// SetNillableImageRateMultiplier sets the "image_rate_multiplier" field if the given value is not nil.
+func (_u *GroupUpdateOne) SetNillableImageRateMultiplier(v *float64) *GroupUpdateOne {
+ if v != nil {
+ _u.SetImageRateMultiplier(*v)
+ }
+ return _u
+}
+
+// AddImageRateMultiplier adds value to the "image_rate_multiplier" field.
+func (_u *GroupUpdateOne) AddImageRateMultiplier(v float64) *GroupUpdateOne {
+ _u.mutation.AddImageRateMultiplier(v)
+ return _u
+}
+
// SetImagePrice1k sets the "image_price_1k" field.
func (_u *GroupUpdateOne) SetImagePrice1k(v float64) *GroupUpdateOne {
_u.mutation.ResetImagePrice1k()
@@ -2327,6 +2437,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.AddedDefaultValidityDays(); ok {
_spec.AddField(group.FieldDefaultValidityDays, field.TypeInt, value)
}
+ if value, ok := _u.mutation.AllowImageGeneration(); ok {
+ _spec.SetField(group.FieldAllowImageGeneration, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.ImageRateIndependent(); ok {
+ _spec.SetField(group.FieldImageRateIndependent, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.ImageRateMultiplier(); ok {
+ _spec.SetField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedImageRateMultiplier(); ok {
+ _spec.AddField(group.FieldImageRateMultiplier, field.TypeFloat64, value)
+ }
if value, ok := _u.mutation.ImagePrice1k(); ok {
_spec.SetField(group.FieldImagePrice1k, field.TypeFloat64, value)
}
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 178ae170..525ff092 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -638,6 +638,9 @@ var (
{Name: "weekly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "monthly_limit_usd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "default_validity_days", Type: field.TypeInt, Default: 30},
+ {Name: "allow_image_generation", Type: field.TypeBool, Default: false},
+ {Name: "image_rate_independent", Type: field.TypeBool, Default: false},
+ {Name: "image_rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
@@ -690,7 +693,7 @@ var (
{
Name: "group_sort_order",
Unique: false,
- Columns: []*schema.Column{GroupsColumns[25]},
+ Columns: []*schema.Column{GroupsColumns[28]},
},
},
}
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index d616e4ae..13f6193d 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -14764,6 +14764,10 @@ type GroupMutation struct {
addmonthly_limit_usd *float64
default_validity_days *int
adddefault_validity_days *int
+ allow_image_generation *bool
+ image_rate_independent *bool
+ image_rate_multiplier *float64
+ addimage_rate_multiplier *float64
image_price_1k *float64
addimage_price_1k *float64
image_price_2k *float64
@@ -15583,6 +15587,134 @@ func (m *GroupMutation) ResetDefaultValidityDays() {
m.adddefault_validity_days = nil
}
+// SetAllowImageGeneration sets the "allow_image_generation" field.
+func (m *GroupMutation) SetAllowImageGeneration(b bool) {
+ m.allow_image_generation = &b
+}
+
+// AllowImageGeneration returns the value of the "allow_image_generation" field in the mutation.
+func (m *GroupMutation) AllowImageGeneration() (r bool, exists bool) {
+ v := m.allow_image_generation
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAllowImageGeneration returns the old "allow_image_generation" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldAllowImageGeneration(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAllowImageGeneration is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAllowImageGeneration requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAllowImageGeneration: %w", err)
+ }
+ return oldValue.AllowImageGeneration, nil
+}
+
+// ResetAllowImageGeneration resets all changes to the "allow_image_generation" field.
+func (m *GroupMutation) ResetAllowImageGeneration() {
+ m.allow_image_generation = nil
+}
+
+// SetImageRateIndependent sets the "image_rate_independent" field.
+func (m *GroupMutation) SetImageRateIndependent(b bool) {
+ m.image_rate_independent = &b
+}
+
+// ImageRateIndependent returns the value of the "image_rate_independent" field in the mutation.
+func (m *GroupMutation) ImageRateIndependent() (r bool, exists bool) {
+ v := m.image_rate_independent
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldImageRateIndependent returns the old "image_rate_independent" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldImageRateIndependent(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldImageRateIndependent is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldImageRateIndependent requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldImageRateIndependent: %w", err)
+ }
+ return oldValue.ImageRateIndependent, nil
+}
+
+// ResetImageRateIndependent resets all changes to the "image_rate_independent" field.
+func (m *GroupMutation) ResetImageRateIndependent() {
+ m.image_rate_independent = nil
+}
+
+// SetImageRateMultiplier sets the "image_rate_multiplier" field.
+func (m *GroupMutation) SetImageRateMultiplier(f float64) {
+ m.image_rate_multiplier = &f
+ m.addimage_rate_multiplier = nil
+}
+
+// ImageRateMultiplier returns the value of the "image_rate_multiplier" field in the mutation.
+func (m *GroupMutation) ImageRateMultiplier() (r float64, exists bool) {
+ v := m.image_rate_multiplier
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldImageRateMultiplier returns the old "image_rate_multiplier" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldImageRateMultiplier(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldImageRateMultiplier is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldImageRateMultiplier requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldImageRateMultiplier: %w", err)
+ }
+ return oldValue.ImageRateMultiplier, nil
+}
+
+// AddImageRateMultiplier adds f to the "image_rate_multiplier" field.
+func (m *GroupMutation) AddImageRateMultiplier(f float64) {
+ if m.addimage_rate_multiplier != nil {
+ *m.addimage_rate_multiplier += f
+ } else {
+ m.addimage_rate_multiplier = &f
+ }
+}
+
+// AddedImageRateMultiplier returns the value that was added to the "image_rate_multiplier" field in this mutation.
+func (m *GroupMutation) AddedImageRateMultiplier() (r float64, exists bool) {
+ v := m.addimage_rate_multiplier
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetImageRateMultiplier resets all changes to the "image_rate_multiplier" field.
+func (m *GroupMutation) ResetImageRateMultiplier() {
+ m.image_rate_multiplier = nil
+ m.addimage_rate_multiplier = nil
+}
+
// SetImagePrice1k sets the "image_price_1k" field.
func (m *GroupMutation) SetImagePrice1k(f float64) {
m.image_price_1k = &f
@@ -16791,7 +16923,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
- fields := make([]string, 0, 31)
+ fields := make([]string, 0, 34)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@@ -16834,6 +16966,15 @@ func (m *GroupMutation) Fields() []string {
if m.default_validity_days != nil {
fields = append(fields, group.FieldDefaultValidityDays)
}
+ if m.allow_image_generation != nil {
+ fields = append(fields, group.FieldAllowImageGeneration)
+ }
+ if m.image_rate_independent != nil {
+ fields = append(fields, group.FieldImageRateIndependent)
+ }
+ if m.image_rate_multiplier != nil {
+ fields = append(fields, group.FieldImageRateMultiplier)
+ }
if m.image_price_1k != nil {
fields = append(fields, group.FieldImagePrice1k)
}
@@ -16921,6 +17062,12 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.MonthlyLimitUsd()
case group.FieldDefaultValidityDays:
return m.DefaultValidityDays()
+ case group.FieldAllowImageGeneration:
+ return m.AllowImageGeneration()
+ case group.FieldImageRateIndependent:
+ return m.ImageRateIndependent()
+ case group.FieldImageRateMultiplier:
+ return m.ImageRateMultiplier()
case group.FieldImagePrice1k:
return m.ImagePrice1k()
case group.FieldImagePrice2k:
@@ -16992,6 +17139,12 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldMonthlyLimitUsd(ctx)
case group.FieldDefaultValidityDays:
return m.OldDefaultValidityDays(ctx)
+ case group.FieldAllowImageGeneration:
+ return m.OldAllowImageGeneration(ctx)
+ case group.FieldImageRateIndependent:
+ return m.OldImageRateIndependent(ctx)
+ case group.FieldImageRateMultiplier:
+ return m.OldImageRateMultiplier(ctx)
case group.FieldImagePrice1k:
return m.OldImagePrice1k(ctx)
case group.FieldImagePrice2k:
@@ -17133,6 +17286,27 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetDefaultValidityDays(v)
return nil
+ case group.FieldAllowImageGeneration:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAllowImageGeneration(v)
+ return nil
+ case group.FieldImageRateIndependent:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetImageRateIndependent(v)
+ return nil
+ case group.FieldImageRateMultiplier:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetImageRateMultiplier(v)
+ return nil
case group.FieldImagePrice1k:
v, ok := value.(float64)
if !ok {
@@ -17275,6 +17449,9 @@ func (m *GroupMutation) AddedFields() []string {
if m.adddefault_validity_days != nil {
fields = append(fields, group.FieldDefaultValidityDays)
}
+ if m.addimage_rate_multiplier != nil {
+ fields = append(fields, group.FieldImageRateMultiplier)
+ }
if m.addimage_price_1k != nil {
fields = append(fields, group.FieldImagePrice1k)
}
@@ -17314,6 +17491,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedMonthlyLimitUsd()
case group.FieldDefaultValidityDays:
return m.AddedDefaultValidityDays()
+ case group.FieldImageRateMultiplier:
+ return m.AddedImageRateMultiplier()
case group.FieldImagePrice1k:
return m.AddedImagePrice1k()
case group.FieldImagePrice2k:
@@ -17372,6 +17551,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
}
m.AddDefaultValidityDays(v)
return nil
+ case group.FieldImageRateMultiplier:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddImageRateMultiplier(v)
+ return nil
case group.FieldImagePrice1k:
v, ok := value.(float64)
if !ok {
@@ -17559,6 +17745,15 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldDefaultValidityDays:
m.ResetDefaultValidityDays()
return nil
+ case group.FieldAllowImageGeneration:
+ m.ResetAllowImageGeneration()
+ return nil
+ case group.FieldImageRateIndependent:
+ m.ResetImageRateIndependent()
+ return nil
+ case group.FieldImageRateMultiplier:
+ m.ResetImageRateMultiplier()
+ return nil
case group.FieldImagePrice1k:
m.ResetImagePrice1k()
return nil
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index 6b344a55..a282d9ba 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -803,50 +803,62 @@ func init() {
groupDescDefaultValidityDays := groupFields[10].Descriptor()
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
+ // groupDescAllowImageGeneration is the schema descriptor for allow_image_generation field.
+ groupDescAllowImageGeneration := groupFields[11].Descriptor()
+ // group.DefaultAllowImageGeneration holds the default value on creation for the allow_image_generation field.
+ group.DefaultAllowImageGeneration = groupDescAllowImageGeneration.Default.(bool)
+ // groupDescImageRateIndependent is the schema descriptor for image_rate_independent field.
+ groupDescImageRateIndependent := groupFields[12].Descriptor()
+ // group.DefaultImageRateIndependent holds the default value on creation for the image_rate_independent field.
+ group.DefaultImageRateIndependent = groupDescImageRateIndependent.Default.(bool)
+ // groupDescImageRateMultiplier is the schema descriptor for image_rate_multiplier field.
+ groupDescImageRateMultiplier := groupFields[13].Descriptor()
+ // group.DefaultImageRateMultiplier holds the default value on creation for the image_rate_multiplier field.
+ group.DefaultImageRateMultiplier = groupDescImageRateMultiplier.Default.(float64)
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
- groupDescClaudeCodeOnly := groupFields[14].Descriptor()
+ groupDescClaudeCodeOnly := groupFields[17].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
- groupDescModelRoutingEnabled := groupFields[18].Descriptor()
+ groupDescModelRoutingEnabled := groupFields[21].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
- groupDescMcpXMLInject := groupFields[19].Descriptor()
+ groupDescMcpXMLInject := groupFields[22].Descriptor()
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
- groupDescSupportedModelScopes := groupFields[20].Descriptor()
+ groupDescSupportedModelScopes := groupFields[23].Descriptor()
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
// groupDescSortOrder is the schema descriptor for sort_order field.
- groupDescSortOrder := groupFields[21].Descriptor()
+ groupDescSortOrder := groupFields[24].Descriptor()
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
- groupDescAllowMessagesDispatch := groupFields[22].Descriptor()
+ groupDescAllowMessagesDispatch := groupFields[25].Descriptor()
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
// groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field.
- groupDescRequireOauthOnly := groupFields[23].Descriptor()
+ groupDescRequireOauthOnly := groupFields[26].Descriptor()
// group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field.
group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool)
// groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field.
- groupDescRequirePrivacySet := groupFields[24].Descriptor()
+ groupDescRequirePrivacySet := groupFields[27].Descriptor()
// group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field.
group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool)
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
- groupDescDefaultMappedModel := groupFields[25].Descriptor()
+ groupDescDefaultMappedModel := groupFields[28].Descriptor()
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
// groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field.
- groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor()
+ groupDescMessagesDispatchModelConfig := groupFields[29].Descriptor()
// group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
// groupDescRpmLimit is the schema descriptor for rpm_limit field.
- groupDescRpmLimit := groupFields[27].Descriptor()
+ groupDescRpmLimit := groupFields[30].Descriptor()
// group.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
group.DefaultRpmLimit = groupDescRpmLimit.Default.(int)
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go
index 11f38d66..d47e8710 100644
--- a/backend/ent/schema/group.go
+++ b/backend/ent/schema/group.go
@@ -74,6 +74,16 @@ func (Group) Fields() []ent.Field {
Default(30),
// 图片生成计费配置(antigravity 和 gemini 平台使用)
+ field.Bool("allow_image_generation").
+ Default(false).
+ Comment("是否允许该分组使用图片生成能力"),
+ field.Bool("image_rate_independent").
+ Default(false).
+ Comment("图片生成是否使用独立倍率;false 表示共享分组有效倍率"),
+ field.Float("image_rate_multiplier").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}).
+ Default(1.0).
+ Comment("图片生成独立倍率,仅 image_rate_independent=true 时生效"),
field.Float("image_price_1k").
Optional().
Nillable().
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 87263db0..e3dc2109 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -575,6 +575,24 @@ type ConcurrencyConfig struct {
PingInterval int `mapstructure:"ping_interval"`
}
+type ImageConcurrencyConfig struct {
+	// Enabled toggles the dedicated image generation concurrency limit; it defaults to off to keep existing behavior.
+	Enabled bool `mapstructure:"enabled"`
+	// MaxConcurrentRequests is the number of image generation requests this process may handle at once; 0 means unlimited.
+	MaxConcurrentRequests int `mapstructure:"max_concurrent_requests"`
+	// OverflowMode selects what happens once the image concurrency limit is reached: reject or wait.
+	OverflowMode string `mapstructure:"overflow_mode"`
+	// WaitTimeoutSeconds is how long, in seconds, a request waits for an image slot when overflow_mode=wait.
+	WaitTimeoutSeconds int `mapstructure:"wait_timeout_seconds"`
+	// MaxWaitingRequests is how many image requests this process lets queue for a slot when overflow_mode=wait.
+	MaxWaitingRequests int `mapstructure:"max_waiting_requests"`
+}
+
+const (
+	ImageConcurrencyOverflowModeReject = "reject" // overflow_mode value: refuse requests that arrive at the limit
+	ImageConcurrencyOverflowModeWait   = "wait"   // overflow_mode value: queue requests that arrive at the limit
+)
+
// GatewayConfig API网关相关配置
type GatewayConfig struct {
// 等待上游响应头的超时时间(秒),0表示无超时
@@ -604,6 +622,8 @@ type GatewayConfig struct {
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
// OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
+ // ImageConcurrency: 图片生成独立并发限制配置(默认关闭)
+ ImageConcurrency ImageConcurrencyConfig `mapstructure:"image_concurrency"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
@@ -635,6 +655,10 @@ type GatewayConfig struct {
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
// StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用
StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"`
+ // ImageStreamDataIntervalTimeout: 图片流数据间隔超时(秒),0表示禁用
+ ImageStreamDataIntervalTimeout int `mapstructure:"image_stream_data_interval_timeout"`
+ // ImageStreamKeepaliveInterval: 图片流式 keepalive 间隔(秒),0表示禁用
+ ImageStreamKeepaliveInterval int `mapstructure:"image_stream_keepalive_interval"`
// MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值)
MaxLineSize int `mapstructure:"max_line_size"`
@@ -1672,6 +1696,11 @@ func setDefaults() {
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
+ viper.SetDefault("gateway.image_concurrency.enabled", false)
+ viper.SetDefault("gateway.image_concurrency.max_concurrent_requests", 0)
+ viper.SetDefault("gateway.image_concurrency.overflow_mode", ImageConcurrencyOverflowModeReject)
+ viper.SetDefault("gateway.image_concurrency.wait_timeout_seconds", 30)
+ viper.SetDefault("gateway.image_concurrency.max_waiting_requests", 100)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
@@ -1689,6 +1718,8 @@ func setDefaults() {
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求)
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10)
+ viper.SetDefault("gateway.image_stream_data_interval_timeout", 900)
+ viper.SetDefault("gateway.image_stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
@@ -2239,6 +2270,21 @@ func (c *Config) Validate() error {
ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy)
}
}
+ if c.Gateway.ImageConcurrency.MaxConcurrentRequests < 0 {
+ return fmt.Errorf("gateway.image_concurrency.max_concurrent_requests must be non-negative")
+ }
+ switch strings.TrimSpace(c.Gateway.ImageConcurrency.OverflowMode) {
+ case "", ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait:
+ default:
+ return fmt.Errorf("gateway.image_concurrency.overflow_mode must be one of: %s/%s",
+ ImageConcurrencyOverflowModeReject, ImageConcurrencyOverflowModeWait)
+ }
+ if c.Gateway.ImageConcurrency.WaitTimeoutSeconds < 0 {
+ return fmt.Errorf("gateway.image_concurrency.wait_timeout_seconds must be non-negative")
+ }
+ if c.Gateway.ImageConcurrency.MaxWaitingRequests < 0 {
+ return fmt.Errorf("gateway.image_concurrency.max_waiting_requests must be non-negative")
+ }
if c.Gateway.MaxIdleConns <= 0 {
return fmt.Errorf("gateway.max_idle_conns must be positive")
}
@@ -2277,6 +2323,20 @@ func (c *Config) Validate() error {
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
}
+ if c.Gateway.ImageStreamDataIntervalTimeout < 0 {
+ return fmt.Errorf("gateway.image_stream_data_interval_timeout must be non-negative")
+ }
+ if c.Gateway.ImageStreamDataIntervalTimeout != 0 &&
+ (c.Gateway.ImageStreamDataIntervalTimeout < 60 || c.Gateway.ImageStreamDataIntervalTimeout > 1800) {
+ return fmt.Errorf("gateway.image_stream_data_interval_timeout must be 0 or between 60-1800 seconds")
+ }
+ if c.Gateway.ImageStreamKeepaliveInterval < 0 {
+ return fmt.Errorf("gateway.image_stream_keepalive_interval must be non-negative")
+ }
+ if c.Gateway.ImageStreamKeepaliveInterval != 0 &&
+ (c.Gateway.ImageStreamKeepaliveInterval < 5 || c.Gateway.ImageStreamKeepaliveInterval > 60) {
+ return fmt.Errorf("gateway.image_stream_keepalive_interval must be 0 or between 5-60 seconds")
+ }
// 兼容旧键 sticky_previous_response_ttl_seconds
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index 6ba86aa1..a47de2f8 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -1282,6 +1282,46 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 },
wantErr: "gateway.stream_data_interval_timeout must be non-negative",
},
+ {
+ name: "gateway image stream keepalive range",
+ mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = 4 },
+ wantErr: "gateway.image_stream_keepalive_interval",
+ },
+ {
+ name: "gateway image stream keepalive negative",
+ mutate: func(c *Config) { c.Gateway.ImageStreamKeepaliveInterval = -1 },
+ wantErr: "gateway.image_stream_keepalive_interval must be non-negative",
+ },
+ {
+ name: "gateway image stream data interval range",
+ mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = 30 },
+ wantErr: "gateway.image_stream_data_interval_timeout",
+ },
+ {
+ name: "gateway image stream data interval negative",
+ mutate: func(c *Config) { c.Gateway.ImageStreamDataIntervalTimeout = -1 },
+ wantErr: "gateway.image_stream_data_interval_timeout must be non-negative",
+ },
+ {
+ name: "gateway image concurrency max negative",
+ mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxConcurrentRequests = -1 },
+ wantErr: "gateway.image_concurrency.max_concurrent_requests must be non-negative",
+ },
+ {
+ name: "gateway image concurrency overflow mode invalid",
+ mutate: func(c *Config) { c.Gateway.ImageConcurrency.OverflowMode = "queue" },
+ wantErr: "gateway.image_concurrency.overflow_mode",
+ },
+ {
+ name: "gateway image concurrency wait timeout negative",
+ mutate: func(c *Config) { c.Gateway.ImageConcurrency.WaitTimeoutSeconds = -1 },
+ wantErr: "gateway.image_concurrency.wait_timeout_seconds must be non-negative",
+ },
+ {
+ name: "gateway image concurrency max waiting negative",
+ mutate: func(c *Config) { c.Gateway.ImageConcurrency.MaxWaitingRequests = -1 },
+ wantErr: "gateway.image_concurrency.max_waiting_requests must be non-negative",
+ },
{
name: "gateway max line size",
mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 },
@@ -1754,3 +1794,41 @@ func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds)
}
}
+
+// TestLoad_DefaultGatewayImageStreamConfig checks the defaults Load applies to
+// the image stream timeouts and to gateway.image_concurrency.
+func TestLoad_DefaultGatewayImageStreamConfig(t *testing.T) {
+	resetViperWithJWTSecret(t)
+	cfg, err := Load()
+	if err != nil {
+		t.Fatalf("Load() error: %v", err)
+	}
+
+	gw := &cfg.Gateway
+	// Fatalf ends the test, so evaluating the cases in order reports the first mismatch only.
+	switch {
+	case gw.StreamDataIntervalTimeout != 180:
+		t.Fatalf("stream_data_interval_timeout = %d, want 180", gw.StreamDataIntervalTimeout)
+	case gw.StreamKeepaliveInterval != 10:
+		t.Fatalf("stream_keepalive_interval = %d, want 10", gw.StreamKeepaliveInterval)
+	// Image stream defaults.
+	case gw.ImageStreamDataIntervalTimeout != 900:
+		t.Fatalf("image_stream_data_interval_timeout = %d, want 900", gw.ImageStreamDataIntervalTimeout)
+	case gw.ImageStreamKeepaliveInterval != 10:
+		t.Fatalf("image_stream_keepalive_interval = %d, want 10", gw.ImageStreamKeepaliveInterval)
+	// Image concurrency defaults.
+	case gw.ImageConcurrency.Enabled:
+		t.Fatalf("image_concurrency.enabled = true, want false")
+	case gw.ImageConcurrency.MaxConcurrentRequests != 0:
+		t.Fatalf("image_concurrency.max_concurrent_requests = %d, want 0", gw.ImageConcurrency.MaxConcurrentRequests)
+	case gw.ImageConcurrency.OverflowMode != ImageConcurrencyOverflowModeReject:
+		t.Fatalf("image_concurrency.overflow_mode = %q, want %q", gw.ImageConcurrency.OverflowMode, ImageConcurrencyOverflowModeReject)
+	case gw.ImageConcurrency.WaitTimeoutSeconds != 30:
+		t.Fatalf("image_concurrency.wait_timeout_seconds = %d, want 30", gw.ImageConcurrency.WaitTimeoutSeconds)
+	case gw.ImageConcurrency.MaxWaitingRequests != 100:
+		t.Fatalf("image_concurrency.max_waiting_requests = %d, want 100", gw.ImageConcurrency.MaxWaitingRequests)
+	// The image stream timeout default must exceed the ordinary stream timeout default.
+	case gw.ImageStreamDataIntervalTimeout <= gw.StreamDataIntervalTimeout:
+		t.Fatalf("image stream timeout = %d, want greater than ordinary stream timeout %d", gw.ImageStreamDataIntervalTimeout, gw.StreamDataIntervalTimeout)
+	}
+}
diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go
index 65e5ec78..3667bbcd 100644
--- a/backend/internal/handler/admin/group_handler.go
+++ b/backend/internal/handler/admin/group_handler.go
@@ -92,6 +92,9 @@ type CreateGroupRequest struct {
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
+ AllowImageGeneration bool `json:"allow_image_generation"`
+ ImageRateIndependent bool `json:"image_rate_independent"`
+ ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
@@ -129,6 +132,9 @@ type UpdateGroupRequest struct {
WeeklyLimitUSD optionalLimitField `json:"weekly_limit_usd"`
MonthlyLimitUSD optionalLimitField `json:"monthly_limit_usd"`
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
+ AllowImageGeneration *bool `json:"allow_image_generation"`
+ ImageRateIndependent *bool `json:"image_rate_independent"`
+ ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
@@ -251,6 +257,9 @@ func (h *GroupHandler) Create(c *gin.Context) {
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
+ AllowImageGeneration: req.AllowImageGeneration,
+ ImageRateIndependent: req.ImageRateIndependent,
+ ImageRateMultiplier: req.ImageRateMultiplier,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
@@ -303,6 +312,9 @@ func (h *GroupHandler) Update(c *gin.Context) {
DailyLimitUSD: req.DailyLimitUSD.ToServiceInput(),
WeeklyLimitUSD: req.WeeklyLimitUSD.ToServiceInput(),
MonthlyLimitUSD: req.MonthlyLimitUSD.ToServiceInput(),
+ AllowImageGeneration: req.AllowImageGeneration,
+ ImageRateIndependent: req.ImageRateIndependent,
+ ImageRateMultiplier: req.ImageRateMultiplier,
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index f7503c2e..2559b112 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -176,6 +176,9 @@ func groupFromServiceBase(g *service.Group) Group {
DailyLimitUSD: g.DailyLimitUSD,
WeeklyLimitUSD: g.WeeklyLimitUSD,
MonthlyLimitUSD: g.MonthlyLimitUSD,
+ AllowImageGeneration: g.AllowImageGeneration,
+ ImageRateIndependent: g.ImageRateIndependent,
+ ImageRateMultiplier: g.ImageRateMultiplier,
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 5cc2f8e4..e15a916e 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -94,9 +94,12 @@ type Group struct {
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置(仅 antigravity 平台使用)
- ImagePrice1K *float64 `json:"image_price_1k"`
- ImagePrice2K *float64 `json:"image_price_2k"`
- ImagePrice4K *float64 `json:"image_price_4k"`
+ AllowImageGeneration bool `json:"allow_image_generation"`
+ ImageRateIndependent bool `json:"image_rate_independent"`
+ ImageRateMultiplier float64 `json:"image_rate_multiplier"`
+ ImagePrice1K *float64 `json:"image_price_1k"`
+ ImagePrice2K *float64 `json:"image_price_2k"`
+ ImagePrice4K *float64 `json:"image_price_4k"`
// Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"`
diff --git a/backend/internal/handler/image_concurrency_limiter.go b/backend/internal/handler/image_concurrency_limiter.go
new file mode 100644
index 00000000..6e7cbb67
--- /dev/null
+++ b/backend/internal/handler/image_concurrency_limiter.go
@@ -0,0 +1,161 @@
+package handler
+
+import (
+	"context"
+	"sync"
+	"time"
+)
+
+// imageConcurrencyLimiter caps how many image generation requests this process
+// handles at once and can let excess requests queue for a free slot. The zero
+// value is ready to use.
+type imageConcurrencyLimiter struct {
+	mu      sync.Mutex
+	notify  chan struct{} // closed and replaced on every release to wake all waiters
+	limit   int           // limit passed by the most recent caller
+	active  int           // slots currently held
+	waiting int           // callers currently queued for a slot
+	enabled bool          // enabled flag passed by the most recent caller
+}
+
+// TryAcquire claims a slot without blocking. It returns (nil, true) when
+// limiting is off (enabled is false or limit <= 0), (release, true) when a slot
+// was claimed, and (nil, false) when every slot is taken.
+func (l *imageConcurrencyLimiter) TryAcquire(enabled bool, limit int) (func(), bool) {
+	return l.acquire(context.Background(), enabled, limit, false, 0, 0)
+}
+
+// Acquire claims a slot like TryAcquire. With wait set and every slot taken it
+// queues for up to timeout, or until ctx is done, unless maxWaiting callers are
+// already queued (maxWaiting <= 0 leaves the queue unbounded). The returned
+// release function may be called more than once.
+func (l *imageConcurrencyLimiter) Acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) {
+	return l.acquire(ctx, enabled, limit, wait, timeout, maxWaiting)
+}
+
+// acquire implements TryAcquire and Acquire.
+func (l *imageConcurrencyLimiter) acquire(ctx context.Context, enabled bool, limit int, wait bool, timeout time.Duration, maxWaiting int) (func(), bool) {
+	if !enabled || limit <= 0 {
+		return nil, true
+	}
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	if wait && timeout <= 0 {
+		// No time budget to queue with: still take a slot that is free rather
+		// than rejecting a request the limiter has room for.
+		wait = false
+	}
+	if maxWaiting < 0 {
+		maxWaiting = 0
+	}
+
+	release, acquired, waitRelease, notify := l.tryAcquireLocked(enabled, limit, wait, maxWaiting)
+	if acquired {
+		return release, true
+	}
+	if !wait || notify == nil {
+		return nil, false
+	}
+	// Stay counted as a waiter until this call returns, so a wake-up that loses
+	// the race for the slot is not pushed out of the queue by later arrivals.
+	defer waitRelease()
+
+	waitCtx, cancel := context.WithTimeout(ctx, timeout)
+	defer cancel()
+	for {
+		if !l.waitForSlot(waitCtx, notify) {
+			return nil, false
+		}
+		if release, acquired, notify = l.retryAcquire(limit); acquired {
+			return release, true
+		}
+	}
+}
+
+// tryAcquireLocked makes the first attempt for a caller, holding mu for the
+// whole decision. It returns a release function when a slot was claimed. When
+// the caller was queued instead it returns a function that removes it from the
+// waiting count and the channel that the next release will close.
+func (l *imageConcurrencyLimiter) tryAcquireLocked(enabled bool, limit int, wait bool, maxWaiting int) (func(), bool, func(), <-chan struct{}) {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+
+	if l.notify == nil {
+		l.notify = make(chan struct{})
+	}
+	l.enabled = enabled
+	l.limit = limit
+	if l.active < l.limit {
+		l.active++
+		return l.releaseFunc(), true, nil, nil
+	}
+	if !wait {
+		return nil, false, nil, nil
+	}
+	if maxWaiting > 0 && l.waiting >= maxWaiting {
+		return nil, false, nil, nil
+	}
+	l.waiting++
+	return nil, false, l.waiterReleaseFunc(), l.notify
+}
+
+// retryAcquire is the wake-up path for a caller that is already counted in
+// waiting. It claims a slot if one is free and otherwise returns the channel to
+// block on next; both are read under one hold of mu, so a release that happens
+// before the caller blocks again cannot be missed.
+func (l *imageConcurrencyLimiter) retryAcquire(limit int) (func(), bool, <-chan struct{}) {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+
+	l.limit = limit
+	if l.active < l.limit {
+		l.active++
+		return l.releaseFunc(), true, nil
+	}
+	return nil, false, l.notify
+}
+
+// waitForSlot blocks until notify is closed (true) or ctx is done (false).
+func (l *imageConcurrencyLimiter) waitForSlot(ctx context.Context, notify <-chan struct{}) bool {
+	select {
+	case <-notify:
+		return true
+	case <-ctx.Done():
+		return false
+	}
+}
+
+// releaseFunc returns the release callback for one claimed slot. Its first call
+// frees the slot and wakes every waiter; later calls do nothing.
+func (l *imageConcurrencyLimiter) releaseFunc() func() {
+	var once sync.Once
+	return func() {
+		once.Do(func() {
+			l.mu.Lock()
+			defer l.mu.Unlock()
+			if l.active > 0 {
+				l.active--
+			}
+			if l.notify != nil {
+				close(l.notify)
+				l.notify = make(chan struct{})
+			}
+		})
+	}
+}
+
+// waiterReleaseFunc returns a callback whose first call removes one caller from
+// the waiting count; later calls do nothing.
+func (l *imageConcurrencyLimiter) waiterReleaseFunc() func() {
+	var once sync.Once
+	return func() {
+		once.Do(func() {
+			l.mu.Lock()
+			defer l.mu.Unlock()
+			if l.waiting > 0 {
+				l.waiting--
+			}
+		})
+	}
+}
diff --git a/backend/internal/handler/image_concurrency_limiter_test.go b/backend/internal/handler/image_concurrency_limiter_test.go
new file mode 100644
index 00000000..20147f16
--- /dev/null
+++ b/backend/internal/handler/image_concurrency_limiter_test.go
@@ -0,0 +1,230 @@
+package handler
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+// TestImageConcurrencyLimiter_DefaultDisabledAllowsRequests verifies that a disabled limiter admits the call and returns no release func.
+func TestImageConcurrencyLimiter_DefaultDisabledAllowsRequests(t *testing.T) {
+	var limiter imageConcurrencyLimiter
+
+	releaseSlot, admitted := limiter.TryAcquire(false, 1)
+	require.True(t, admitted)
+	require.Nil(t, releaseSlot)
+}
+
+// TestImageConcurrencyLimiter_RejectsWhenLimitReachedAndAllowsAfterRelease verifies that a full limiter rejects and admits again after release.
+func TestImageConcurrencyLimiter_RejectsWhenLimitReachedAndAllowsAfterRelease(t *testing.T) {
+	limiter := new(imageConcurrencyLimiter)
+	holdSlot, ok := limiter.TryAcquire(true, 1)
+	require.True(t, ok)
+	require.NotNil(t, holdSlot)
+
+	rejected, ok := limiter.TryAcquire(true, 1)
+	require.False(t, ok)
+	require.Nil(t, rejected)
+
+	holdSlot()
+	reuseSlot, ok := limiter.TryAcquire(true, 1)
+	require.True(t, ok)
+	require.NotNil(t, reuseSlot)
+	reuseSlot()
+}
+
+// TestImageConcurrencyLimiter_WaitsUntilSlotReleased verifies that a queued caller gets the slot once the holder releases it.
+func TestImageConcurrencyLimiter_WaitsUntilSlotReleased(t *testing.T) {
+	limiter := &imageConcurrencyLimiter{}
+	release, acquired := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
+	require.True(t, acquired)
+	require.NotNil(t, release)
+
+	acquiredCh := make(chan func(), 1)
+	go func() {
+		// No require here: FailNow must run on the test goroutine. A failed wait sends nil, which is caught below.
+		waitRelease, _ := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
+		acquiredCh <- waitRelease
+	}()
+	time.Sleep(20 * time.Millisecond)
+	release()
+
+	select {
+	case waitRelease := <-acquiredCh:
+		require.NotNil(t, waitRelease)
+		waitRelease()
+	case <-time.After(time.Second):
+		t.Fatal("timed out waiting for image concurrency slot")
+	}
+}
+
+// TestImageConcurrencyLimiter_WaitTimesOut verifies that a queued caller gives up once its wait timeout elapses.
+func TestImageConcurrencyLimiter_WaitTimesOut(t *testing.T) {
+	limiter := new(imageConcurrencyLimiter)
+	holdSlot, ok := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
+	require.True(t, ok)
+	require.NotNil(t, holdSlot)
+	defer holdSlot()
+
+	timedOut, ok := limiter.Acquire(context.Background(), true, 1, true, 10*time.Millisecond, 1)
+	require.False(t, ok)
+	require.Nil(t, timedOut)
+}
+
+// TestImageConcurrencyLimiter_MaxWaitingRequestsRejectsOverflow verifies that a caller is rejected once the wait queue is full.
+func TestImageConcurrencyLimiter_MaxWaitingRequestsRejectsOverflow(t *testing.T) {
+	limiter := new(imageConcurrencyLimiter)
+	holdSlot, ok := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
+	require.True(t, ok)
+	require.NotNil(t, holdSlot)
+	defer holdSlot()
+
+	queued, drained := make(chan struct{}), make(chan struct{})
+	go func() {
+		defer close(drained)
+		close(queued)
+		if free, got := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1); got && free != nil {
+			free()
+		}
+	}()
+	<-queued
+	time.Sleep(20 * time.Millisecond)
+
+	// The goroutine above is expected to occupy the one-entry queue by now, so this caller overflows it.
+	overflow, ok := limiter.Acquire(context.Background(), true, 1, true, time.Second, 1)
+	require.False(t, ok)
+	require.Nil(t, overflow)
+	// Free the slot so the queued goroutine can finish.
+	holdSlot()
+	<-drained
+}
+
+// TestOpenAIGatewayHandlerAcquireImageGenerationSlot_Returns429WhenFull verifies that reject mode answers a second caller with a 429 rate_limit_error.
+func TestOpenAIGatewayHandlerAcquireImageGenerationSlot_Returns429WhenFull(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", nil)
+
+	limits := config.ImageConcurrencyConfig{
+		Enabled:               true,
+		MaxConcurrentRequests: 1,
+		OverflowMode:          config.ImageConcurrencyOverflowModeReject,
+	}
+	h := &OpenAIGatewayHandler{
+		cfg:          &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: limits}},
+		imageLimiter: &imageConcurrencyLimiter{},
+	}
+
+	holdSlot, ok := h.acquireImageGenerationSlot(c, false)
+	require.True(t, ok)
+	require.NotNil(t, holdSlot)
+	defer holdSlot()
+
+	// The only slot is held, so this call is expected to write the rejection to rec.
+	blocked, ok := h.acquireImageGenerationSlot(c, false)
+	require.False(t, ok)
+	require.Nil(t, blocked)
+	body := rec.Body
+	require.Equal(t, http.StatusTooManyRequests, rec.Code)
+	require.Equal(t, "rate_limit_error", gjson.GetBytes(body.Bytes(), "error.type").String())
+	require.Contains(t, body.String(), "Image generation concurrency limit exceeded")
+}
+
+func TestOpenAIGatewayHandlerResponses_ImageIntentRejectedByImageConcurrency(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	const payload = `{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload))
+	groupID := int64(1)
+	apiKey := &service.APIKey{
+		ID:      10,
+		GroupID: &groupID,
+		Group:   &service.Group{ID: groupID, AllowImageGeneration: true},
+		User:    &service.User{ID: 20},
+	}
+	c.Set(string(middleware2.ContextKeyAPIKey), apiKey)
+	c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1})
+
+	imageCfg := config.ImageConcurrencyConfig{
+		Enabled:               true,
+		MaxConcurrentRequests: 1,
+		OverflowMode:          config.ImageConcurrencyOverflowModeReject,
+	}
+	h := &OpenAIGatewayHandler{
+		gatewayService:      &service.OpenAIGatewayService{},
+		billingCacheService: &service.BillingCacheService{},
+		apiKeyService:       &service.APIKeyService{},
+		concurrencyHelper:   &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
+		cfg:                 &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: imageCfg}},
+		imageLimiter:        &imageConcurrencyLimiter{},
+	}
+
+	// Occupy the only image slot up front so the handler's own acquisition is refused.
+	heldRelease, held := h.acquireImageGenerationSlot(c, false)
+	require.True(t, held)
+	require.NotNil(t, heldRelease)
+	defer heldRelease()
+	rec.Body.Reset()
+	rec.Code = 0
+
+	h.Responses(c)
+
+	require.Equal(t, http.StatusTooManyRequests, rec.Code)
+	require.Equal(t, "rate_limit_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
+	require.Contains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
+}
+
+func TestOpenAIGatewayHandlerResponses_TextOnlyNotRejectedByImageConcurrency(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	const payload = `{"model":"gpt-5.4","input":"write code"}`
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload))
+	groupID := int64(1)
+	apiKey := &service.APIKey{
+		ID:      10,
+		GroupID: &groupID,
+		Group:   &service.Group{ID: groupID, AllowImageGeneration: true},
+		User:    &service.User{ID: 20},
+	}
+	c.Set(string(middleware2.ContextKeyAPIKey), apiKey)
+	c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 20, Concurrency: 1})
+
+	imageCfg := config.ImageConcurrencyConfig{
+		Enabled:               true,
+		MaxConcurrentRequests: 1,
+		OverflowMode:          config.ImageConcurrencyOverflowModeReject,
+	}
+	h := &OpenAIGatewayHandler{
+		gatewayService:      &service.OpenAIGatewayService{},
+		billingCacheService: service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, &config.Config{RunMode: config.RunModeSimple}),
+		apiKeyService:       &service.APIKeyService{},
+		concurrencyHelper:   &ConcurrencyHelper{concurrencyService: service.NewConcurrencyService(&helperConcurrencyCacheStub{userSeq: []bool{true}})},
+		cfg:                 &config.Config{Gateway: config.GatewayConfig{ImageConcurrency: imageCfg}},
+		imageLimiter:        &imageConcurrencyLimiter{},
+	}
+	// The only image slot is taken, yet a request without image tools must not see the image 429.
+	heldRelease, held := h.acquireImageGenerationSlot(c, false)
+	require.True(t, held)
+	require.NotNil(t, heldRelease)
+	defer heldRelease()
+	rec.Body.Reset()
+	rec.Code = 0
+
+	h.Responses(c)
+
+	require.NotEqual(t, http.StatusTooManyRequests, rec.Code)
+	require.NotContains(t, rec.Body.String(), "Image generation concurrency limit exceeded")
+}
diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go
index 23844508..06ab9d52 100644
--- a/backend/internal/handler/openai_chat_completions.go
+++ b/backend/internal/handler/openai_chat_completions.go
@@ -187,52 +187,60 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
- var failoverErr *service.UpstreamFailoverError
- if errors.As(err, &failoverErr) {
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
- // Pool mode: retry on the same account
- if failoverErr.RetryableOnSameAccount {
- retryLimit := account.GetPoolModeRetryCount()
- if sameAccountRetryCount[account.ID] < retryLimit {
- sameAccountRetryCount[account.ID]++
- reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
- zap.Int64("account_id", account.ID),
- zap.Int("upstream_status", failoverErr.StatusCode),
- zap.Int("retry_limit", retryLimit),
- zap.Int("retry_count", sameAccountRetryCount[account.ID]),
- )
- select {
- case <-c.Request.Context().Done():
- return
- case <-time.After(sameAccountRetryDelay):
- }
- continue
- }
- }
- h.gatewayService.RecordOpenAIAccountSwitch()
- failedAccountIDs[account.ID] = struct{}{}
- lastFailoverErr = failoverErr
- if switchCount >= maxAccountSwitches {
- h.handleFailoverExhausted(c, failoverErr, streamStarted)
- return
- }
- switchCount++
- reqLog.Warn("openai_chat_completions.upstream_failover_switching",
+ if result != nil && result.ImageCount > 0 {
+ reqLog.Warn("openai_chat_completions.forward_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
- zap.Int("upstream_status", failoverErr.StatusCode),
- zap.Int("switch_count", switchCount),
- zap.Int("max_switches", maxAccountSwitches),
+ zap.Int("image_count", result.ImageCount),
+ zap.Error(err),
)
- continue
+ } else {
+ var failoverErr *service.UpstreamFailoverError
+ if errors.As(err, &failoverErr) {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ // Pool mode: retry on the same account
+ if failoverErr.RetryableOnSameAccount {
+ retryLimit := account.GetPoolModeRetryCount()
+ if sameAccountRetryCount[account.ID] < retryLimit {
+ sameAccountRetryCount[account.ID]++
+ reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("retry_limit", retryLimit),
+ zap.Int("retry_count", sameAccountRetryCount[account.ID]),
+ )
+ select {
+ case <-c.Request.Context().Done():
+ return
+ case <-time.After(sameAccountRetryDelay):
+ }
+ continue
+ }
+ }
+ h.gatewayService.RecordOpenAIAccountSwitch()
+ failedAccountIDs[account.ID] = struct{}{}
+ lastFailoverErr = failoverErr
+ if switchCount >= maxAccountSwitches {
+ h.handleFailoverExhausted(c, failoverErr, streamStarted)
+ return
+ }
+ switchCount++
+ reqLog.Warn("openai_chat_completions.upstream_failover_switching",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("switch_count", switchCount),
+ zap.Int("max_switches", maxAccountSwitches),
+ )
+ continue
+ }
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
+ reqLog.Warn("openai_chat_completions.forward_failed",
+ zap.Int64("account_id", account.ID),
+ zap.Bool("fallback_error_response_written", wroteFallback),
+ zap.Error(err),
+ )
+ return
}
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
- wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
- reqLog.Warn("openai_chat_completions.forward_failed",
- zap.Int64("account_id", account.ID),
- zap.Bool("fallback_error_response_written", wroteFallback),
- zap.Error(err),
- )
- return
}
if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
@@ -242,16 +250,18 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
+ inboundEndpoint := GetInboundEndpoint(c)
+ upstreamEndpoint := resolveRawCCUpstreamEndpoint(c, account)
- h.submitUsageRecordTask(func(ctx context.Context) {
+ h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
- InboundEndpoint: GetInboundEndpoint(c),
- UpstreamEndpoint: resolveRawCCUpstreamEndpoint(c, account),
+ InboundEndpoint: inboundEndpoint,
+ UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index b5eec393..5966c163 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -33,6 +33,7 @@ type OpenAIGatewayHandler struct {
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
+ imageLimiter *imageConcurrencyLimiter
maxAccountSwitches int
cfg *config.Config
}
@@ -69,6 +70,7 @@ func NewOpenAIGatewayHandler(
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
+ imageLimiter: &imageConcurrencyLimiter{},
maxAccountSwitches: maxAccountSwitches,
cfg: cfg,
}
@@ -187,6 +189,23 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
+ imageIntent := service.IsImageGenerationIntent("/v1/responses", reqModel, body)
+ if imageIntent && !service.GroupAllowsImageGeneration(apiKey.Group) {
+ h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
+ return
+ }
+ var imageReleaseFunc func()
+ if imageIntent {
+ var imageAcquired bool
+ imageReleaseFunc, imageAcquired = h.acquireImageGenerationSlot(c, streamStarted)
+ if !imageAcquired {
+ return
+ }
+ if imageReleaseFunc != nil {
+ defer imageReleaseFunc()
+ }
+ }
+
// 解析渠道级模型映射
channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
@@ -318,57 +337,65 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
- var failoverErr *service.UpstreamFailoverError
- if errors.As(err, &failoverErr) {
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
- // 池模式:同账号重试
- if failoverErr.RetryableOnSameAccount {
- retryLimit := account.GetPoolModeRetryCount()
- if sameAccountRetryCount[account.ID] < retryLimit {
- sameAccountRetryCount[account.ID]++
- reqLog.Warn("openai.pool_mode_same_account_retry",
- zap.Int64("account_id", account.ID),
- zap.Int("upstream_status", failoverErr.StatusCode),
- zap.Int("retry_limit", retryLimit),
- zap.Int("retry_count", sameAccountRetryCount[account.ID]),
- )
- select {
- case <-c.Request.Context().Done():
- return
- case <-time.After(sameAccountRetryDelay):
+ if result != nil && result.ImageCount > 0 {
+ reqLog.Warn("openai.forward_partial_error_with_image_result",
+ zap.Int64("account_id", account.ID),
+ zap.Int("image_count", result.ImageCount),
+ zap.Error(err),
+ )
+ } else {
+ var failoverErr *service.UpstreamFailoverError
+ if errors.As(err, &failoverErr) {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ // 池模式:同账号重试
+ if failoverErr.RetryableOnSameAccount {
+ retryLimit := account.GetPoolModeRetryCount()
+ if sameAccountRetryCount[account.ID] < retryLimit {
+ sameAccountRetryCount[account.ID]++
+ reqLog.Warn("openai.pool_mode_same_account_retry",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("retry_limit", retryLimit),
+ zap.Int("retry_count", sameAccountRetryCount[account.ID]),
+ )
+ select {
+ case <-c.Request.Context().Done():
+ return
+ case <-time.After(sameAccountRetryDelay):
+ }
+ continue
}
- continue
}
+ h.gatewayService.RecordOpenAIAccountSwitch()
+ failedAccountIDs[account.ID] = struct{}{}
+ lastFailoverErr = failoverErr
+ if switchCount >= maxAccountSwitches {
+ h.handleFailoverExhausted(c, failoverErr, streamStarted)
+ return
+ }
+ switchCount++
+ reqLog.Warn("openai.upstream_failover_switching",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("switch_count", switchCount),
+ zap.Int("max_switches", maxAccountSwitches),
+ )
+ continue
}
- h.gatewayService.RecordOpenAIAccountSwitch()
- failedAccountIDs[account.ID] = struct{}{}
- lastFailoverErr = failoverErr
- if switchCount >= maxAccountSwitches {
- h.handleFailoverExhausted(c, failoverErr, streamStarted)
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
+ fields := []zap.Field{
+ zap.Int64("account_id", account.ID),
+ zap.Bool("fallback_error_response_written", wroteFallback),
+ zap.Error(err),
+ }
+ if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
+ reqLog.Warn("openai.forward_failed", fields...)
return
}
- switchCount++
- reqLog.Warn("openai.upstream_failover_switching",
- zap.Int64("account_id", account.ID),
- zap.Int("upstream_status", failoverErr.StatusCode),
- zap.Int("switch_count", switchCount),
- zap.Int("max_switches", maxAccountSwitches),
- )
- continue
- }
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
- wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
- fields := []zap.Field{
- zap.Int64("account_id", account.ID),
- zap.Bool("fallback_error_response_written", wroteFallback),
- zap.Error(err),
- }
- if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
- reqLog.Warn("openai.forward_failed", fields...)
+ reqLog.Error("openai.forward_failed", fields...)
return
}
- reqLog.Error("openai.forward_failed", fields...)
- return
}
if result != nil {
if account.Type == service.AccountTypeOAuth {
@@ -383,17 +410,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
+ inboundEndpoint := GetInboundEndpoint(c)
+ upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
- h.submitUsageRecordTask(func(ctx context.Context) {
+ h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
- InboundEndpoint: GetInboundEndpoint(c),
- UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
+ InboundEndpoint: inboundEndpoint,
+ UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
@@ -701,52 +730,60 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
- var failoverErr *service.UpstreamFailoverError
- if errors.As(err, &failoverErr) {
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
- // 池模式:同账号重试
- if failoverErr.RetryableOnSameAccount {
- retryLimit := account.GetPoolModeRetryCount()
- if sameAccountRetryCount[account.ID] < retryLimit {
- sameAccountRetryCount[account.ID]++
- reqLog.Warn("openai_messages.pool_mode_same_account_retry",
- zap.Int64("account_id", account.ID),
- zap.Int("upstream_status", failoverErr.StatusCode),
- zap.Int("retry_limit", retryLimit),
- zap.Int("retry_count", sameAccountRetryCount[account.ID]),
- )
- select {
- case <-c.Request.Context().Done():
- return
- case <-time.After(sameAccountRetryDelay):
- }
- continue
- }
- }
- h.gatewayService.RecordOpenAIAccountSwitch()
- failedAccountIDs[account.ID] = struct{}{}
- lastFailoverErr = failoverErr
- if switchCount >= maxAccountSwitches {
- h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
- return
- }
- switchCount++
- reqLog.Warn("openai_messages.upstream_failover_switching",
+ if result != nil && result.ImageCount > 0 {
+ reqLog.Warn("openai_messages.forward_partial_error_with_image_result",
zap.Int64("account_id", account.ID),
- zap.Int("upstream_status", failoverErr.StatusCode),
- zap.Int("switch_count", switchCount),
- zap.Int("max_switches", maxAccountSwitches),
+ zap.Int("image_count", result.ImageCount),
+ zap.Error(err),
)
- continue
+ } else {
+ var failoverErr *service.UpstreamFailoverError
+ if errors.As(err, &failoverErr) {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ // 池模式:同账号重试
+ if failoverErr.RetryableOnSameAccount {
+ retryLimit := account.GetPoolModeRetryCount()
+ if sameAccountRetryCount[account.ID] < retryLimit {
+ sameAccountRetryCount[account.ID]++
+ reqLog.Warn("openai_messages.pool_mode_same_account_retry",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("retry_limit", retryLimit),
+ zap.Int("retry_count", sameAccountRetryCount[account.ID]),
+ )
+ select {
+ case <-c.Request.Context().Done():
+ return
+ case <-time.After(sameAccountRetryDelay):
+ }
+ continue
+ }
+ }
+ h.gatewayService.RecordOpenAIAccountSwitch()
+ failedAccountIDs[account.ID] = struct{}{}
+ lastFailoverErr = failoverErr
+ if switchCount >= maxAccountSwitches {
+ h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
+ return
+ }
+ switchCount++
+ reqLog.Warn("openai_messages.upstream_failover_switching",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("switch_count", switchCount),
+ zap.Int("max_switches", maxAccountSwitches),
+ )
+ continue
+ }
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
+ reqLog.Warn("openai_messages.forward_failed",
+ zap.Int64("account_id", account.ID),
+ zap.Bool("fallback_error_response_written", wroteFallback),
+ zap.Error(err),
+ )
+ return
}
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
- wroteFallback := h.ensureAnthropicErrorResponse(c, streamStarted)
- reqLog.Warn("openai_messages.forward_failed",
- zap.Int64("account_id", account.ID),
- zap.Bool("fallback_error_response_written", wroteFallback),
- zap.Error(err),
- )
- return
}
if result != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
@@ -757,16 +794,18 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
requestPayloadHash := service.HashUsageRequestPayload(body)
+ inboundEndpoint := GetInboundEndpoint(c)
+ upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
- h.submitUsageRecordTask(func(ctx context.Context) {
+ h.submitOpenAIUsageRecordTask(result, func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
- InboundEndpoint: GetInboundEndpoint(c),
- UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
+ InboundEndpoint: inboundEndpoint,
+ UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
@@ -1114,6 +1153,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
+ if service.IsImageGenerationIntent("/v1/responses", reqModel, firstMessage) && !service.GroupAllowsImageGeneration(apiKey.Group) {
+ closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, service.ImageGenerationPermissionMessage())
+ return
+ }
+
// 解析渠道级模型映射
channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
@@ -1257,22 +1301,34 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
},
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
releaseTurnSlots()
- if turnErr != nil || result == nil {
+ if turnErr != nil {
+ if result == nil || result.ImageCount <= 0 {
+ return
+ }
+ reqLog.Warn("openai.websocket_partial_error_with_image_result",
+ zap.Int64("account_id", account.ID),
+ zap.Int("image_count", result.ImageCount),
+ zap.Error(turnErr),
+ )
+ }
+ if result == nil {
return
}
if account.Type == service.AccountTypeOAuth {
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(ctx, account.ID, result.ResponseHeaders)
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
- h.submitUsageRecordTask(func(taskCtx context.Context) {
+ inboundEndpoint := GetInboundEndpoint(c)
+ upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
+ h.submitOpenAIUsageRecordTask(result, func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
- InboundEndpoint: GetInboundEndpoint(c),
- UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
+ InboundEndpoint: inboundEndpoint,
+ UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
@@ -1440,6 +1496,60 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
task(ctx)
}
+func (h *OpenAIGatewayHandler) submitOpenAIUsageRecordTask(result *service.OpenAIForwardResult, task service.UsageRecordTask) {
+	if result == nil || result.ImageCount <= 0 { // no image output: regular submit path
+		h.submitUsageRecordTask(task)
+		return
+	}
+	h.submitMandatoryUsageRecordTask(task) // results carrying images take the mandatory path
+}
+
+// submitMandatoryUsageRecordTask hands task to the usage worker pool and, when the pool is absent or
+// reports the submission as dropped, runs it inline under a 10-second context with panic recovery.
+func (h *OpenAIGatewayHandler) submitMandatoryUsageRecordTask(task service.UsageRecordTask) {
+	if task == nil {
+		return
+	}
+	if pool := h.usageRecordWorkerPool; pool != nil {
+		if pool.Submit(task) != service.UsageRecordSubmitModeDropped {
+			return
+		}
+		usageLog := logger.L().With(zap.String("component", "handler.openai_gateway.usage"))
+		usageLog.Warn("openai.usage_record_task_mandatory_sync_fallback")
+	}
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+	defer func() {
+		recovered := recover()
+		if recovered == nil {
+			return
+		}
+		logger.L().With(zap.String("component", "handler.openai_gateway.usage"), zap.Any("panic", recovered)).Error("openai.usage_record_task_panic_recovered")
+	}()
+	task(ctx)
+}
+
+// acquireImageGenerationSlot asks the handler's image limiter for a slot using the gateway image
+// concurrency settings. It returns (release, true) when a slot is granted; otherwise it writes a 429
+// rate_limit_error via handleStreamingAwareError and returns (nil, false). When the handler, its
+// config, or its limiter is nil, it returns (nil, true) without limiting.
+func (h *OpenAIGatewayHandler) acquireImageGenerationSlot(c *gin.Context, streamStarted bool) (func(), bool) {
+	if h == nil || h.cfg == nil || h.imageLimiter == nil {
+		return nil, true
+	}
+	limits := h.cfg.Gateway.ImageConcurrency
+	waitOnOverflow := strings.TrimSpace(limits.OverflowMode) == config.ImageConcurrencyOverflowModeWait
+	waitTimeout := time.Duration(limits.WaitTimeoutSeconds) * time.Second
+	reqCtx := c.Request.Context()
+	release, ok := h.imageLimiter.Acquire(reqCtx, limits.Enabled, limits.MaxConcurrentRequests, waitOnOverflow, waitTimeout, limits.MaxWaitingRequests)
+	if !ok {
+		// No slot granted: answer with a 429 in the streaming-aware error format.
+		h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Image generation concurrency limit exceeded, please retry later", streamStarted)
+		return nil, false
+	}
+	return release, true
+}
+
// handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go
index 4d0078a7..eba701f1 100644
--- a/backend/internal/handler/openai_images.go
+++ b/backend/internal/handler/openai_images.go
@@ -81,6 +81,18 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
zap.String("capability", string(parsed.RequiredCapability)),
)
+ if !service.GroupAllowsImageGeneration(apiKey.Group) {
+ h.errorResponse(c, http.StatusForbidden, "permission_error", service.ImageGenerationPermissionMessage())
+ return
+ }
+ imageReleaseFunc, acquired := h.acquireImageGenerationSlot(c, streamStarted)
+ if !acquired {
+ return
+ }
+ if imageReleaseFunc != nil {
+ defer imageReleaseFunc()
+ }
+
if parsed.Multipart {
setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
} else {
@@ -188,62 +200,69 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
- if err == nil && result != nil && result.FirstTokenMs != nil {
+ if result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
- var failoverErr *service.UpstreamFailoverError
- if errors.As(err, &failoverErr) {
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
- if failoverErr.RetryableOnSameAccount {
- retryLimit := account.GetPoolModeRetryCount()
- if sameAccountRetryCount[account.ID] < retryLimit {
- sameAccountRetryCount[account.ID]++
- reqLog.Warn("openai.images.pool_mode_same_account_retry",
- zap.Int64("account_id", account.ID),
- zap.Int("upstream_status", failoverErr.StatusCode),
- zap.Int("retry_limit", retryLimit),
- zap.Int("retry_count", sameAccountRetryCount[account.ID]),
- )
- select {
- case <-c.Request.Context().Done():
- return
- case <-time.After(sameAccountRetryDelay):
+ if result != nil && result.ImageCount > 0 {
+ reqLog.Warn("openai.images.forward_partial_error_with_image_result",
+ zap.Int64("account_id", account.ID),
+ zap.Int("image_count", result.ImageCount),
+ zap.Error(err),
+ )
+ } else {
+ var failoverErr *service.UpstreamFailoverError
+ if errors.As(err, &failoverErr) {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ if failoverErr.RetryableOnSameAccount {
+ retryLimit := account.GetPoolModeRetryCount()
+ if sameAccountRetryCount[account.ID] < retryLimit {
+ sameAccountRetryCount[account.ID]++
+ reqLog.Warn("openai.images.pool_mode_same_account_retry",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("retry_limit", retryLimit),
+ zap.Int("retry_count", sameAccountRetryCount[account.ID]),
+ )
+ select {
+ case <-c.Request.Context().Done():
+ return
+ case <-time.After(sameAccountRetryDelay):
+ }
+ continue
}
- continue
}
+ h.gatewayService.RecordOpenAIAccountSwitch()
+ failedAccountIDs[account.ID] = struct{}{}
+ lastFailoverErr = failoverErr
+ if switchCount >= maxAccountSwitches {
+ h.handleFailoverExhausted(c, failoverErr, streamStarted)
+ return
+ }
+ switchCount++
+ reqLog.Warn("openai.images.upstream_failover_switching",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("switch_count", switchCount),
+ zap.Int("max_switches", maxAccountSwitches),
+ )
+ continue
}
- h.gatewayService.RecordOpenAIAccountSwitch()
- failedAccountIDs[account.ID] = struct{}{}
- lastFailoverErr = failoverErr
- if switchCount >= maxAccountSwitches {
- h.handleFailoverExhausted(c, failoverErr, streamStarted)
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
+ fields := []zap.Field{
+ zap.Int64("account_id", account.ID),
+ zap.Bool("fallback_error_response_written", wroteFallback),
+ zap.Error(err),
+ }
+ if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
+ reqLog.Warn("openai.images.forward_failed", fields...)
return
}
- switchCount++
- reqLog.Warn("openai.images.upstream_failover_switching",
- zap.Int64("account_id", account.ID),
- zap.Int("upstream_status", failoverErr.StatusCode),
- zap.Int("switch_count", switchCount),
- zap.Int("max_switches", maxAccountSwitches),
- )
- continue
- }
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
- wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
- fields := []zap.Field{
- zap.Int64("account_id", account.ID),
- zap.Bool("fallback_error_response_written", wroteFallback),
- zap.Error(err),
- }
- if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
- reqLog.Warn("openai.images.forward_failed", fields...)
+ reqLog.Error("openai.images.forward_failed", fields...)
return
}
- reqLog.Error("openai.images.forward_failed", fields...)
- return
}
-
if result != nil {
if account.Type == service.AccountTypeOAuth {
h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
@@ -259,21 +278,27 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
if parsed.Multipart {
requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed()))
}
+ inboundEndpoint := GetInboundEndpoint(c)
+ upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
- h.submitUsageRecordTask(func(ctx context.Context) {
+ upstreamModel := ""
+ if result != nil {
+ upstreamModel = result.UpstreamModel
+ }
+ h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: account,
Subscription: subscription,
- InboundEndpoint: GetInboundEndpoint(c),
- UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
+ InboundEndpoint: inboundEndpoint,
+ UpstreamEndpoint: upstreamEndpoint,
UserAgent: userAgent,
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
- ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel),
+ ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, upstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.images"),
diff --git a/backend/internal/handler/openai_images_controls_test.go b/backend/internal/handler/openai_images_controls_test.go
new file mode 100644
index 00000000..cebcccac
--- /dev/null
+++ b/backend/internal/handler/openai_images_controls_test.go
@@ -0,0 +1,49 @@
+package handler
+
+import (
+ "bytes"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestOpenAIGatewayHandlerImages_DisabledGroupRejectsBeforeScheduling(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	payload := []byte(`{"model":"gpt-image-2","prompt":"draw","size":"1024x1024"}`)
+	rec := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(rec)
+	c.Request = httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(payload))
+	c.Request.Header.Set("Content-Type", "application/json")
+
+	// API key bound to a group whose AllowImageGeneration flag is off.
+	groupID := int64(111)
+	apiKey := &service.APIKey{
+		ID:      222,
+		GroupID: &groupID,
+		Group:   &service.Group{ID: groupID, AllowImageGeneration: false},
+		User:    &service.User{ID: 333},
+	}
+	c.Set(string(middleware2.ContextKeyAPIKey), apiKey)
+	c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 333, Concurrency: 1})
+
+	// Handler wired with zero-value services; no cfg or image limiter is set.
+	h := &OpenAIGatewayHandler{
+		gatewayService:      &service.OpenAIGatewayService{},
+		billingCacheService: &service.BillingCacheService{},
+		apiKeyService:       &service.APIKeyService{},
+		concurrencyHelper:   &ConcurrencyHelper{concurrencyService: &service.ConcurrencyService{}},
+	}
+
+	h.Images(c)
+
+	require.Equal(t, http.StatusForbidden, rec.Code)
+	require.Equal(t, "permission_error", gjson.GetBytes(rec.Body.Bytes(), "error.type").String())
+	require.Contains(t, rec.Body.String(), service.ImageGenerationPermissionMessage())
+}
diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go
index 5c945815..e4c2837a 100644
--- a/backend/internal/handler/usage_record_submit_task_test.go
+++ b/backend/internal/handler/usage_record_submit_task_test.go
@@ -129,3 +129,63 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
}
+
+func TestOpenAIGatewayHandlerSubmitMandatoryUsageRecordTask_DroppedTaskSyncFallback(t *testing.T) {
+ pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
+ WorkerCount: 1,
+ QueueSize: 1,
+ TaskTimeout: time.Second,
+ OverflowPolicy: "drop",
+ OverflowSamplePercent: 0,
+ AutoScaleEnabled: false,
+ })
+ t.Cleanup(pool.Stop)
+ h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
+ // Park the single worker on a blocking task, then submit one more task (presumably filling the one-slot queue - confirm).
+ block := make(chan struct{})
+ release := make(chan struct{})
+ pool.Submit(func(ctx context.Context) {
+ close(block) // signals that the worker has started this task
+ <-release // holds the worker until the test releases it
+ })
+ <-block
+ pool.Submit(func(ctx context.Context) {})
+ // With the pool saturated, the mandatory task is expected to run even though the worker is still blocked.
+ var called atomic.Bool
+ h.submitMandatoryUsageRecordTask(func(ctx context.Context) {
+ called.Store(true)
+ })
+ close(release)
+
+ require.True(t, called.Load(), "mandatory usage task must run synchronously when async submit is dropped")
+}
+
+func TestOpenAIGatewayHandlerSubmitOpenAIUsageRecordTask_ImageResultUsesMandatoryFallback(t *testing.T) {
+ pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{
+ WorkerCount: 1,
+ QueueSize: 1,
+ TaskTimeout: time.Second,
+ OverflowPolicy: "drop",
+ OverflowSamplePercent: 0,
+ AutoScaleEnabled: false,
+ })
+ t.Cleanup(pool.Stop)
+ h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
+ // Park the single worker on a blocking task, then submit one more task (presumably filling the one-slot queue - confirm).
+ block := make(chan struct{})
+ release := make(chan struct{})
+ pool.Submit(func(ctx context.Context) {
+ close(block) // signals that the worker has started this task
+ <-release // holds the worker until the test releases it
+ })
+ <-block
+ pool.Submit(func(ctx context.Context) {})
+ // A forward result with ImageCount > 0 is submitted; its usage task is expected to run despite the saturated pool.
+ var called atomic.Bool
+ h.submitOpenAIUsageRecordTask(&service.OpenAIForwardResult{ImageCount: 1}, func(ctx context.Context) {
+ called.Store(true)
+ })
+ close(release)
+
+ require.True(t, called.Load(), "image usage task must be mandatory when async submit is dropped")
+}
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index 3a527405..68895475 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -166,6 +166,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldDailyLimitUsd,
group.FieldWeeklyLimitUsd,
group.FieldMonthlyLimitUsd,
+ group.FieldAllowImageGeneration,
+ group.FieldImageRateIndependent,
+ group.FieldImageRateMultiplier,
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
@@ -699,6 +702,9 @@ func groupEntityToService(g *dbent.Group) *service.Group {
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
+ AllowImageGeneration: g.AllowImageGeneration,
+ ImageRateIndependent: g.ImageRateIndependent,
+ ImageRateMultiplier: g.ImageRateMultiplier,
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go
index 5e16475a..112575f4 100644
--- a/backend/internal/repository/group_repo.go
+++ b/backend/internal/repository/group_repo.go
@@ -50,6 +50,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
+ SetAllowImageGeneration(groupIn.AllowImageGeneration).
+ SetImageRateIndependent(groupIn.ImageRateIndependent).
+ SetImageRateMultiplier(groupIn.ImageRateMultiplier).
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
@@ -120,6 +123,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
+ SetAllowImageGeneration(groupIn.AllowImageGeneration).
+ SetImageRateIndependent(groupIn.ImageRateIndependent).
+ SetImageRateMultiplier(groupIn.ImageRateMultiplier).
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 607b93dc..34f560fc 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -328,6 +328,9 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
+ "allow_image_generation": false,
+ "image_rate_independent": false,
+ "image_rate_multiplier": 0,
"claude_code_only": false,
"allow_messages_dispatch": false,
"fallback_group_id": null,
diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go
index 90ff450f..221021d8 100644
--- a/backend/internal/service/account_stats_pricing.go
+++ b/backend/internal/service/account_stats_pricing.go
@@ -230,7 +230,11 @@ func applyAccountStatsCost(
if model == "" {
model = requestedModel
}
+ requestCount := 1
+ if usageLog != nil && usageLog.ImageCount > 0 {
+ requestCount = usageLog.ImageCount
+ }
usageLog.AccountStatsCost = resolveAccountStatsCost(
- ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost,
+ ctx, cs, bs, accountID, groupID, model, tokens, requestCount, totalCost,
)
}
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index be4c23dc..793d60d8 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -189,11 +189,14 @@ type CreateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
- ImagePrice1K *float64
- ImagePrice2K *float64
- ImagePrice4K *float64
- ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
- FallbackGroupID *int64 // 降级分组 ID
+ AllowImageGeneration bool
+ ImageRateIndependent bool
+ ImageRateMultiplier *float64
+ ImagePrice1K *float64
+ ImagePrice2K *float64
+ ImagePrice4K *float64
+ ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
+ FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
@@ -226,11 +229,14 @@ type UpdateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
- ImagePrice1K *float64
- ImagePrice2K *float64
- ImagePrice4K *float64
- ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
- FallbackGroupID *int64 // 降级分组 ID
+ AllowImageGeneration *bool
+ ImageRateIndependent *bool
+ ImageRateMultiplier *float64
+ ImagePrice1K *float64
+ ImagePrice2K *float64
+ ImagePrice4K *float64
+ ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
+ FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
@@ -1557,6 +1563,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice1K := normalizePrice(input.ImagePrice1K)
imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K)
+ imageRateMultiplier := 1.0
+ if input.ImageRateMultiplier != nil {
+ if *input.ImageRateMultiplier < 0 {
+ return nil, errors.New("image_rate_multiplier must be >= 0")
+ }
+ imageRateMultiplier = *input.ImageRateMultiplier
+ }
// 校验降级分组
if input.FallbackGroupID != nil {
@@ -1624,6 +1637,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit,
+ AllowImageGeneration: input.AllowImageGeneration,
+ ImageRateIndependent: input.ImageRateIndependent,
+ ImageRateMultiplier: imageRateMultiplier,
ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K,
@@ -1800,6 +1816,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
// 图片生成计费配置:负数表示清除(使用默认价格)
+ if input.AllowImageGeneration != nil {
+ group.AllowImageGeneration = *input.AllowImageGeneration
+ }
+ if input.ImageRateIndependent != nil {
+ group.ImageRateIndependent = *input.ImageRateIndependent
+ }
+ if input.ImageRateMultiplier != nil {
+ if *input.ImageRateMultiplier < 0 {
+ return nil, errors.New("image_rate_multiplier must be >= 0")
+ }
+ group.ImageRateMultiplier = *input.ImageRateMultiplier
+ }
if input.ImagePrice1K != nil {
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
}
diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go
index eef02240..0a2020ea 100644
--- a/backend/internal/service/admin_service_group_test.go
+++ b/backend/internal/service/admin_service_group_test.go
@@ -266,6 +266,50 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K)
}
+func TestAdminService_UpdateGroup_PreservesImageGenerationControlsWhenOmitted(t *testing.T) {
+ imageMultiplier := 0.5
+ existingGroup := &Group{
+ ID: 1,
+ Name: "existing-group",
+ Platform: PlatformOpenAI,
+ Status: StatusActive,
+ AllowImageGeneration: true,
+ ImageRateIndependent: true,
+ ImageRateMultiplier: imageMultiplier,
+ }
+ repo := &groupRepoStubForAdmin{getByID: existingGroup}
+ svc := &adminServiceImpl{groupRepo: repo}
+ // The update sets only Description; the three image generation controls are left nil in the input.
+ group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
+ Description: "updated",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, group)
+ require.NotNil(t, repo.updated)
+ require.True(t, repo.updated.AllowImageGeneration) // unchanged from the existing group
+ require.True(t, repo.updated.ImageRateIndependent) // unchanged from the existing group
+ require.InDelta(t, 0.5, repo.updated.ImageRateMultiplier, 1e-12)
+}
+
+func TestAdminService_UpdateGroup_RejectsNegativeImageRateMultiplier(t *testing.T) {
+ existingGroup := &Group{
+ ID: 1,
+ Name: "existing-group",
+ Platform: PlatformOpenAI,
+ Status: StatusActive,
+ ImageRateMultiplier: 1,
+ }
+ repo := &groupRepoStubForAdmin{getByID: existingGroup}
+ svc := &adminServiceImpl{groupRepo: repo}
+ negative := -0.1
+ // A negative multiplier must produce an error, and the stub must record no update.
+ _, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
+ ImageRateMultiplier: &negative,
+ })
+ require.Error(t, err)
+ require.Nil(t, repo.updated)
+}
+
func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
existingGroup := &Group{
ID: 1,
diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go
index 1a1c78b8..4432ad7d 100644
--- a/backend/internal/service/api_key_auth_cache.go
+++ b/backend/internal/service/api_key_auth_cache.go
@@ -63,6 +63,9 @@ type APIKeyAuthGroupSnapshot struct {
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
+ AllowImageGeneration bool `json:"allow_image_generation"`
+ ImageRateIndependent bool `json:"image_rate_independent"`
+ ImageRateMultiplier float64 `json:"image_rate_multiplier"`
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go
index 974ea66e..0f9d4214 100644
--- a/backend/internal/service/api_key_auth_cache_impl.go
+++ b/backend/internal/service/api_key_auth_cache_impl.go
@@ -14,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto"
)
-const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
+const apiKeyAuthSnapshotVersion = 8 // v8: added group image generation controls
type apiKeyAuthCacheConfig struct {
l1Size int
@@ -255,6 +255,9 @@ func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey)
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
+ AllowImageGeneration: apiKey.Group.AllowImageGeneration,
+ ImageRateIndependent: apiKey.Group.ImageRateIndependent,
+ ImageRateMultiplier: apiKey.Group.ImageRateMultiplier,
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
@@ -321,6 +324,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
+ AllowImageGeneration: snapshot.Group.AllowImageGeneration,
+ ImageRateIndependent: snapshot.Group.ImageRateIndependent,
+ ImageRateMultiplier: snapshot.Group.ImageRateMultiplier,
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,
diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go
index cb502a2e..a9c21884 100644
--- a/backend/internal/service/billing_service.go
+++ b/backend/internal/service/billing_service.go
@@ -294,8 +294,7 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
}
// OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。
- if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
- normalized := normalizeCodexModel(modelLower)
+ if normalized := normalizeKnownOpenAICodexModel(modelLower); normalized != "" {
switch normalized {
case "gpt-5.5":
return s.fallbackPrices["gpt-5.5"]
@@ -644,13 +643,10 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens
}
func isOpenAIGPT54Model(model string) bool {
- trimmed := strings.TrimSpace(strings.ToLower(model))
- // 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel
- // 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)误识别为 gpt-5.4。
- if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") {
- return false
- }
- normalized := normalizeCodexModel(trimmed)
+ // 仅当模型字符串实际属于已知 GPT-5/Codex 族时才做归一判定,避免
+ // normalizeCodexModel 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)
+ // 误识别为 gpt-5.4。
+ normalized := normalizeKnownOpenAICodexModel(model)
return normalized == "gpt-5.4" || normalized == "gpt-5.5"
}
diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go
index 222abd69..df3e3a0a 100644
--- a/backend/internal/service/billing_service_test.go
+++ b/backend/internal/service/billing_service_test.go
@@ -137,6 +137,35 @@ func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
}
+func TestGetModelPricing_OpenAICompactAliasesFallback(t *testing.T) {
+ svc := newTestBillingService()
+ // Each case is a compact model spelling (no hyphen after "gpt") with the per-token prices and long-context threshold expected for it.
+ tests := []struct {
+ model string
+ inputPrice float64
+ outputPrice float64
+ cacheRead float64
+ longContext int
+ }{
+ {model: "gpt5.5", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000},
+ {model: "openai/gpt5.4", inputPrice: 2.5e-6, outputPrice: 15e-6, cacheRead: 0.25e-6, longContext: 272000},
+ {model: "gpt5.4-mini", inputPrice: 7.5e-7, outputPrice: 4.5e-6, cacheRead: 7.5e-8, longContext: 0},
+ {model: "gpt5.3codexspark", inputPrice: 1.5e-6, outputPrice: 12e-6, cacheRead: 0.15e-6, longContext: 0},
+ }
+ // Run one subtest per model so a failure names the offending alias.
+ for _, tt := range tests {
+ t.Run(tt.model, func(t *testing.T) {
+ pricing, err := svc.GetModelPricing(tt.model)
+ require.NoError(t, err)
+ require.NotNil(t, pricing)
+ require.InDelta(t, tt.inputPrice, pricing.InputPricePerToken, 1e-12)
+ require.InDelta(t, tt.outputPrice, pricing.OutputPricePerToken, 1e-12)
+ require.InDelta(t, tt.cacheRead, pricing.CacheReadPricePerToken, 1e-12)
+ require.Equal(t, tt.longContext, pricing.LongContextInputThreshold)
+ })
+ }
+}
+
func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) {
svc := newTestBillingService()
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 67d19720..b9bd992e 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -8367,6 +8367,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
groupDefault := apiKey.Group.RateMultiplier
multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
}
+ imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier)
// 确定计费模型
billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
@@ -8384,7 +8385,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
}
// 计算费用
- cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, opts)
+ cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, imageMultiplier, opts)
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
@@ -8396,7 +8397,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
// 创建使用日志
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
- requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
+ requestedModel, multiplier, imageMultiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if apiKey.GroupID != nil {
@@ -8450,11 +8451,12 @@ func (s *GatewayService) calculateRecordUsageCost(
apiKey *APIKey,
billingModel string,
multiplier float64,
+ imageMultiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
// 图片生成计费
if result.ImageCount > 0 {
- return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier)
+ return s.calculateImageCost(ctx, result, apiKey, billingModel, imageMultiplier)
}
// Token 计费
@@ -8495,7 +8497,8 @@ func (s *GatewayService) calculateImageCost(
Model: billingModel,
GroupID: &gid,
Tokens: tokens,
- RequestCount: 1,
+ RequestCount: result.ImageCount,
+ SizeTier: result.ImageSize,
RateMultiplier: multiplier,
Resolver: s.resolver,
Resolved: resolved,
@@ -8580,6 +8583,7 @@ func (s *GatewayService) buildRecordUsageLog(
subscription *UserSubscription,
requestedModel string,
multiplier float64,
+ imageMultiplier float64,
accountRateMultiplier float64,
billingType int8,
cacheTTLOverridden bool,
@@ -8624,6 +8628,9 @@ func (s *GatewayService) buildRecordUsageLog(
SubscriptionID: optionalSubscriptionID(subscription),
CreatedAt: time.Now(),
}
+ if result.ImageCount > 0 {
+ usageLog.RateMultiplier = imageMultiplier
+ }
if cost != nil {
usageLog.InputCost = cost.InputCost
usageLog.OutputCost = cost.OutputCost
diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go
index bb4c5aa1..f6155352 100644
--- a/backend/internal/service/group.go
+++ b/backend/internal/service/group.go
@@ -26,9 +26,12 @@ type Group struct {
DefaultValidityDays int
// 图片生成计费配置(antigravity 和 gemini 平台使用)
- ImagePrice1K *float64
- ImagePrice2K *float64
- ImagePrice4K *float64
+ AllowImageGeneration bool
+ ImageRateIndependent bool
+ ImageRateMultiplier float64
+ ImagePrice1K *float64
+ ImagePrice2K *float64
+ ImagePrice4K *float64
// Claude Code 客户端限制
ClaudeCodeOnly bool
diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go
index 87174e03..93078aa6 100644
--- a/backend/internal/service/group_service.go
+++ b/backend/internal/service/group_service.go
@@ -45,19 +45,25 @@ type GroupSortOrderUpdate struct {
// CreateGroupRequest 创建分组请求
type CreateGroupRequest struct {
- Name string `json:"name"`
- Description string `json:"description"`
- RateMultiplier float64 `json:"rate_multiplier"`
- IsExclusive bool `json:"is_exclusive"`
+ Name string `json:"name"`
+ Description string `json:"description"`
+ RateMultiplier float64 `json:"rate_multiplier"`
+ IsExclusive bool `json:"is_exclusive"`
+ AllowImageGeneration bool `json:"allow_image_generation"`
+ ImageRateIndependent bool `json:"image_rate_independent"`
+ ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
}
// UpdateGroupRequest 更新分组请求
type UpdateGroupRequest struct {
- Name *string `json:"name"`
- Description *string `json:"description"`
- RateMultiplier *float64 `json:"rate_multiplier"`
- IsExclusive *bool `json:"is_exclusive"`
- Status *string `json:"status"`
+ Name *string `json:"name"`
+ Description *string `json:"description"`
+ RateMultiplier *float64 `json:"rate_multiplier"`
+ IsExclusive *bool `json:"is_exclusive"`
+ Status *string `json:"status"`
+ AllowImageGeneration *bool `json:"allow_image_generation"`
+ ImageRateIndependent *bool `json:"image_rate_independent"`
+ ImageRateMultiplier *float64 `json:"image_rate_multiplier"`
}
// GroupService 分组管理服务
@@ -76,6 +82,13 @@ func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthC
// Create 创建分组
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) {
+ imageRateMultiplier := 1.0
+ if req.ImageRateMultiplier != nil {
+ if *req.ImageRateMultiplier < 0 {
+ return nil, fmt.Errorf("image_rate_multiplier must be >= 0")
+ }
+ imageRateMultiplier = *req.ImageRateMultiplier
+ }
// 检查名称是否已存在
exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
if err != nil {
@@ -87,13 +100,16 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Gro
// 创建分组
group := &Group{
- Name: req.Name,
- Description: req.Description,
- Platform: PlatformAnthropic,
- RateMultiplier: req.RateMultiplier,
- IsExclusive: req.IsExclusive,
- Status: StatusActive,
- SubscriptionType: SubscriptionTypeStandard,
+ Name: req.Name,
+ Description: req.Description,
+ Platform: PlatformAnthropic,
+ RateMultiplier: req.RateMultiplier,
+ IsExclusive: req.IsExclusive,
+ Status: StatusActive,
+ SubscriptionType: SubscriptionTypeStandard,
+ AllowImageGeneration: req.AllowImageGeneration,
+ ImageRateIndependent: req.ImageRateIndependent,
+ ImageRateMultiplier: imageRateMultiplier,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
@@ -165,6 +181,18 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
if req.Status != nil {
group.Status = *req.Status
}
+ if req.AllowImageGeneration != nil {
+ group.AllowImageGeneration = *req.AllowImageGeneration
+ }
+ if req.ImageRateIndependent != nil {
+ group.ImageRateIndependent = *req.ImageRateIndependent
+ }
+ if req.ImageRateMultiplier != nil {
+ if *req.ImageRateMultiplier < 0 {
+ return nil, fmt.Errorf("image_rate_multiplier must be >= 0")
+ }
+ group.ImageRateMultiplier = *req.ImageRateMultiplier
+ }
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, fmt.Errorf("update group: %w", err)
diff --git a/backend/internal/service/image_billing_multiplier.go b/backend/internal/service/image_billing_multiplier.go
new file mode 100644
index 00000000..23ec5ac1
--- /dev/null
+++ b/backend/internal/service/image_billing_multiplier.go
@@ -0,0 +1,11 @@
+package service
+
+func resolveImageRateMultiplier(apiKey *APIKey, effectiveGroupMultiplier float64) float64 {
+	switch {
+	case apiKey == nil || apiKey.Group == nil || !apiKey.Group.ImageRateIndependent:
+		return effectiveGroupMultiplier // shared mode: images reuse the effective group multiplier
+	case apiKey.Group.ImageRateMultiplier < 0:
+		return 0 // a negative independent multiplier is clamped to zero
+	}
+	return apiKey.Group.ImageRateMultiplier
+}
diff --git a/backend/internal/service/image_generation_intent.go b/backend/internal/service/image_generation_intent.go
new file mode 100644
index 00000000..b6ef1065
--- /dev/null
+++ b/backend/internal/service/image_generation_intent.go
@@ -0,0 +1,220 @@
+package service
+
+import (
+ "encoding/json"
+ "strings"
+
+ "github.com/tidwall/gjson"
+)
+
+const (
+ openAIResponsesEndpoint = "/v1/responses"
+ openAIResponsesCompactEndpoint = "/v1/responses/compact"
+ imageGenerationPermissionMessage = "Image generation is not enabled for this group"
+)
+
+// ImageGenerationPermissionMessage returns the fixed error text reported when a group has image generation disabled.
+func ImageGenerationPermissionMessage() string {
+ return imageGenerationPermissionMessage
+}
+
+// GroupAllowsImageGeneration reports true for a nil group (ungrouped keys stay allowed) and otherwise returns the group's AllowImageGeneration flag.
+func GroupAllowsImageGeneration(group *Group) bool {
+ return group == nil || group.AllowImageGeneration
+}
+
+// IsImageGenerationIntent reports whether a request can produce generated images.
+// It checks, in order: the endpoint, the requested model, then the JSON body's
+// model, tools and tool_choice. An empty or invalid JSON body adds no intent.
+func IsImageGenerationIntent(endpoint string, requestedModel string, body []byte) bool {
+	if IsImageGenerationEndpoint(endpoint) || isOpenAIImageGenerationModel(requestedModel) {
+		return true
+	}
+	if len(body) == 0 || !gjson.ValidBytes(body) {
+		return false
+	}
+	bodyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
+	switch {
+	case isOpenAIImageGenerationModel(bodyModel):
+		return true
+	case openAIJSONToolsContainImageGeneration(gjson.GetBytes(body, "tools")):
+		return true
+	}
+	return openAIJSONToolChoiceSelectsImageGeneration(gjson.GetBytes(body, "tool_choice"))
+}
+
+// IsImageGenerationIntentMap is the counterpart of the byte-slice classifier for a
+// request body already decoded into a map (used after service-side request mutation).
+// A nil map adds no intent beyond the endpoint and the requested model.
+func IsImageGenerationIntentMap(endpoint string, requestedModel string, reqBody map[string]any) bool {
+	if IsImageGenerationEndpoint(endpoint) || isOpenAIImageGenerationModel(requestedModel) {
+		return true
+	}
+	if reqBody == nil {
+		return false
+	}
+	// Body-level signals, in order: model, declared tools, then tool_choice.
+	switch {
+	case isOpenAIImageGenerationModel(firstNonEmptyString(reqBody["model"])):
+		return true
+	case hasOpenAIImageGenerationTool(reqBody):
+		return true
+	}
+	return openAIAnyToolChoiceSelectsImageGeneration(reqBody["tool_choice"])
+}
+
+// IsImageGenerationEndpoint reports whether endpoint, once normalized, is one of the
+// dedicated image generation or image edit routes, with or without the /v1 prefix.
+func IsImageGenerationEndpoint(endpoint string) bool {
+	// Drop the optional version prefix so both spellings compare against one pair of routes.
+	route := strings.TrimPrefix(normalizeImageGenerationEndpoint(endpoint), "/v1")
+	isGenerate := route == "/images/generations"
+	isEdit := route == "/images/edits"
+	return isGenerate || isEdit
+}
+
+// normalizeImageGenerationEndpoint lower-cases and trims endpoint, then strips the
+// OpenAI origin, any query string and trailing slashes.
+func normalizeImageGenerationEndpoint(endpoint string) string {
+	path := strings.TrimSpace(strings.ToLower(endpoint))
+	if len(path) == 0 {
+		return path
+	}
+	path = strings.TrimPrefix(path, "https://api.openai.com") // accept absolute upstream URLs
+	path, _, _ = strings.Cut(path, "?")                       // discard any query string
+	return strings.TrimRight(path, "/")
+}
+
+// openAIJSONToolsContainImageGeneration reports whether tools is a JSON array with at
+// least one element whose trimmed "type" is exactly "image_generation".
+func openAIJSONToolsContainImageGeneration(tools gjson.Result) bool {
+	if !tools.IsArray() {
+		return false
+	}
+	for _, tool := range tools.Array() {
+		toolType := strings.TrimSpace(tool.Get("type").String())
+		if toolType == "image_generation" {
+			return true
+		}
+	}
+	return false
+}
+
+func openAIJSONToolChoiceSelectsImageGeneration(choice gjson.Result) bool {
+ if !choice.Exists() {
+ return false
+ }
+ if choice.Type == gjson.String {
+ return strings.TrimSpace(choice.String()) == "image_generation"
+ }
+ if !choice.IsObject() {
+ return false
+ }
+ if strings.TrimSpace(choice.Get("type").String()) == "image_generation" {
+ return true
+ }
+ if strings.TrimSpace(choice.Get("tool.type").String()) == "image_generation" {
+ return true
+ }
+ if strings.TrimSpace(choice.Get("function.name").String()) == "image_generation" {
+ return true
+ }
+ return false
+}
+
+// openAIAnyToolChoiceSelectsImageGeneration is the decoded-value variant: choice may be a string or a map naming the tool by type, tool.type or function.name.
+func openAIAnyToolChoiceSelectsImageGeneration(choice any) bool {
+	const imageTool = "image_generation"
+	if name, isString := choice.(string); isString {
+		return strings.TrimSpace(name) == imageTool
+	}
+	obj, isObject := choice.(map[string]any)
+	if !isObject {
+		return false
+	}
+	if strings.TrimSpace(firstNonEmptyString(obj["type"])) == imageTool {
+		return true
+	}
+	tool, hasTool := obj["tool"].(map[string]any)
+	fn, hasFn := obj["function"].(map[string]any)
+	return (hasTool && strings.TrimSpace(firstNonEmptyString(tool["type"])) == imageTool) || (hasFn && strings.TrimSpace(firstNonEmptyString(fn["name"])) == imageTool)
+}
+
+// getAPIKeyFromContext returns the *APIKey stored under "api_key" in c, or nil when c is nil, the key is absent, or the value has another type.
+func getAPIKeyFromContext(c interface{ Get(string) (any, bool) }) *APIKey {
+	if c == nil {
+		return nil
+	}
+	if stored, found := c.Get("api_key"); found {
+		key, _ := stored.(*APIKey) // a value of any other type yields nil
+		return key
+	}
+	return nil
+}
+
+func apiKeyGroup(apiKey *APIKey) *Group {
+	if apiKey != nil { // nil-safe accessor: a missing key has no group
+		return apiKey.Group
+	}
+	return nil
+}
+
+// cloneRequestMapForImageIntent decodes body into a freshly allocated map.
+// It returns nil for empty input and for any body that json.Unmarshal cannot
+// decode into a map[string]any (for example invalid JSON or a JSON array).
+func cloneRequestMapForImageIntent(body []byte) map[string]any {
+	var decoded map[string]any
+	if len(body) == 0 || json.Unmarshal(body, &decoded) != nil {
+		return nil // never hand back a partially decoded map
+	}
+	return decoded
+}
+
+func resolveOpenAIResponsesImageBillingConfig(reqBody map[string]any, fallbackModel string) (string, string, error) {
+ imageModel := ""
+ imageSize := ""
+ hasImageTool := false
+ if reqBody != nil {
+ rawTools, _ := reqBody["tools"].([]any)
+ for _, rawTool := range rawTools {
+ toolMap, ok := rawTool.(map[string]any)
+ if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" {
+ continue
+ }
+ hasImageTool = true
+ imageModel = strings.TrimSpace(firstNonEmptyString(toolMap["model"]))
+ imageSize = strings.TrimSpace(firstNonEmptyString(toolMap["size"]))
+ break
+ }
+ if imageSize == "" {
+ imageSize = strings.TrimSpace(firstNonEmptyString(reqBody["size"]))
+ }
+ }
+ if imageModel == "" && reqBody != nil {
+ bodyModel := strings.TrimSpace(firstNonEmptyString(reqBody["model"]))
+ if isOpenAIImageBillingModelAlias(bodyModel) || !hasImageTool {
+ imageModel = bodyModel
+ }
+ }
+ if imageModel == "" && hasImageTool {
+ imageModel = "gpt-image-2"
+ }
+ if imageModel == "" {
+ imageModel = strings.TrimSpace(fallbackModel)
+ }
+ sizeTier := normalizeOpenAIImageSizeTier(imageSize)
+ return imageModel, sizeTier, nil
+}
+
+// resolveOpenAIResponsesImageBillingConfigFromBody decodes body into a map and resolves the image billing model and size tier from it.
+func resolveOpenAIResponsesImageBillingConfigFromBody(body []byte, fallbackModel string) (string, string, error) {
+	return resolveOpenAIResponsesImageBillingConfig(cloneRequestMapForImageIntent(body), fallbackModel)
+}
+
+// isOpenAIImageBillingModelAlias reports whether a non-empty model name is accepted by
+// isOpenAIImageGenerationModel or contains "image" (compared after trimming and lower-casing).
+func isOpenAIImageBillingModelAlias(model string) bool {
+	name := strings.ToLower(strings.TrimSpace(model))
+	looksLikeImage := name != "" && (isOpenAIImageGenerationModel(name) || strings.Contains(name, "image"))
+	return looksLikeImage
+}
diff --git a/backend/internal/service/image_generation_intent_test.go b/backend/internal/service/image_generation_intent_test.go
new file mode 100644
index 00000000..5e7bec79
--- /dev/null
+++ b/backend/internal/service/image_generation_intent_test.go
@@ -0,0 +1,184 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestIsImageGenerationIntent covers detection by endpoint, model name, image tool and tool_choice, plus two text-only requests.
+func TestIsImageGenerationIntent(t *testing.T) {
+	cases := []struct {
+		name     string
+		endpoint string
+		model    string
+		body     []byte
+		expect   bool
+	}{
+		{
+			name:     "images endpoint",
+			endpoint: "/v1/images/generations",
+			body:     []byte(`{"model":"gpt-image-2"}`),
+			expect:   true,
+		},
+		{
+			name:     "image model",
+			endpoint: "/v1/responses",
+			model:    "gpt-image-2",
+			body:     []byte(`{"model":"gpt-image-2"}`),
+			expect:   true,
+		},
+		{
+			name:     "image tool",
+			endpoint: "/v1/responses",
+			model:    "gpt-5.4",
+			body:     []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation"}]}`),
+			expect:   true,
+		},
+		{
+			name:     "image tool choice",
+			endpoint: "/v1/responses",
+			model:    "gpt-5.4",
+			body:     []byte(`{"model":"gpt-5.4","tool_choice":{"type":"image_generation"}}`),
+			expect:   true,
+		},
+		{
+			name:     "required tool choice alone is text",
+			endpoint: "/v1/responses",
+			model:    "gpt-5.4",
+			body:     []byte(`{"model":"gpt-5.4","tool_choice":"required"}`),
+			expect:   false,
+		},
+		{
+			name:     "text only gpt 5.4",
+			endpoint: "/v1/responses",
+			model:    "gpt-5.4",
+			body:     []byte(`{"model":"gpt-5.4","input":"write code"}`),
+			expect:   false,
+		},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			require.Equal(t, tc.expect, IsImageGenerationIntent(tc.endpoint, tc.model, tc.body))
+		})
+	}
+}
+
+// TestResolveOpenAIResponsesImageBillingConfigUsesCurrentBodyModel checks that, when the image tool
+// names no model, an image-looking body model is billed instead of the fallback model.
+func TestResolveOpenAIResponsesImageBillingConfigUsesCurrentBodyModel(t *testing.T) {
+	body := []byte(`{"model":"mapped-image-model","tools":[{"type":"image_generation","size":"1024x1024"}]}`)
+	model, tier, err := resolveOpenAIResponsesImageBillingConfigFromBody(body, "requested-model")
+	require.NoError(t, err)
+	require.Equal(t, "mapped-image-model", model)
+	require.Equal(t, "1K", tier)
+}
+
+// TestResolveOpenAIResponsesImageBillingConfigToolModelWins checks that the model declared on the
+// image_generation tool takes precedence over the top-level body model.
+func TestResolveOpenAIResponsesImageBillingConfigToolModelWins(t *testing.T) {
+	body := []byte(`{"model":"mapped-text-model","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1536x1024"}]}`)
+	model, tier, err := resolveOpenAIResponsesImageBillingConfigFromBody(body, "requested-model")
+	require.NoError(t, err)
+	require.Equal(t, "gpt-image-2", model)
+	require.Equal(t, "2K", tier)
+}
+
+// TestResolveOpenAIResponsesImageBillingConfigSupportsOfficialAndCustomSizes maps tool-level and top-level sizes to their expected tiers.
+func TestResolveOpenAIResponsesImageBillingConfigSupportsOfficialAndCustomSizes(t *testing.T) {
+	cases := []struct {
+		name       string
+		body       []byte
+		expectTier string
+	}{
+		{
+			name:       "official 2k landscape",
+			body:       []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"2048x1152"}]}`),
+			expectTier: "2K",
+		},
+		{
+			name:       "official 4k landscape",
+			body:       []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-2","size":"3840x2160"}]}`),
+			expectTier: "4K",
+		},
+		{
+			name:       "custom valid 2k",
+			body:       []byte(`{"model":"gpt-5.5","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1280x768"}]}`),
+			expectTier: "2K",
+		},
+		{
+			name:       "default image tool model supports flexible size",
+			body:       []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","size":"2048x1152"}]}`),
+			expectTier: "2K",
+		},
+		{
+			name:       "top level image size is moved into billing",
+			body:       []byte(`{"model":"gpt-image-2","size":"2048x2048","tools":[{"type":"image_generation","model":"gpt-image-2"}]}`),
+			expectTier: "2K",
+		},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			model, tier, err := resolveOpenAIResponsesImageBillingConfigFromBody(tc.body, "requested-model")
+			require.NoError(t, err)
+			require.NotEmpty(t, model)
+			require.Equal(t, tc.expectTier, tier)
+		})
+	}
+}
+
+// TestResolveOpenAIResponsesImageBillingConfigDoesNotRejectUnknownSizes checks that a 2048x1152 size
+// on a gpt-image-1.5 tool resolves to a tier without an error.
+func TestResolveOpenAIResponsesImageBillingConfigDoesNotRejectUnknownSizes(t *testing.T) {
+	body := []byte(`{"model":"gpt-5.4","tools":[{"type":"image_generation","model":"gpt-image-1.5","size":"2048x1152"}]}`)
+	model, tier, err := resolveOpenAIResponsesImageBillingConfigFromBody(body, "requested-model")
+	require.NoError(t, err)
+	require.Equal(t, "gpt-image-1.5", model)
+	require.Equal(t, "2K", tier)
+}
+
+func TestOpenAIImageOutputCounterDeduplicatesFinalImages(t *testing.T) { // ig_1 is reported twice, ig_2 once: two images expected
+	images := newOpenAIImageOutputCounter()
+	images.AddSSEData([]byte(`{"type":"response.image_generation_call.partial_image","partial_image_b64":"abc"}`)) // partial frame, not a final image
+	images.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_1","type":"image_generation_call","result":"final-a"}}`))
+	images.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_1","type":"image_generation_call","result":"final-a"},{"id":"ig_2","type":"image_generation_call","result":"final-b"}]}}`))
+	require.Equal(t, 2, images.Count())
+}
+
+// TestOpenAIImageOutputCounterCountsImagesAPIStreamShapes covers three distinct event shapes, then "data" arrays counted by their longest length.
+func TestOpenAIImageOutputCounterCountsImagesAPIStreamShapes(t *testing.T) {
+	events := newOpenAIImageOutputCounter()
+	events.AddSSEData([]byte(`{"type":"image_generation.completed","id":"ig_complete","b64_json":"final-a"}`))
+	events.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_item","type":"image_generation_call","result":"final-b"}}`))
+	events.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_done","type":"image_generation_call","result":"final-c"}]}}`))
+	require.Equal(t, 3, events.Count())
+	arrays := newOpenAIImageOutputCounter()
+	arrays.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"}]}`))
+	arrays.AddSSEData([]byte(`{"data":[{"b64_json":"a"},{"b64_json":"b"},{"b64_json":"c"}]}`))
+	require.Equal(t, 3, arrays.Count())
+}
+
+func TestOpenAIImageOutputCounterCountsMultilineSSEDataPayload(t *testing.T) { // a newline inside the JSON payload must not hide the image
+	images := newOpenAIImageOutputCounter()
+	images.AddSSEData([]byte("{\"type\":\"image_generation.completed\",\n\"b64_json\":\"final-a\"}"))
+	require.Equal(t, 1, images.Count())
+}
+
+// TestOpenAIImageOutputCounterCountsMultilineSSEBodyPayload splits one JSON event across two data: lines and expects a single image.
+func TestOpenAIImageOutputCounterCountsMultilineSSEBodyPayload(t *testing.T) {
+	const stream = "data: {\"type\":\"image_generation.completed\",\n" +
+		"data: \"b64_json\":\"final-a\"}\n\n" +
+		"data: [DONE]\n\n"
+	images := newOpenAIImageOutputCounter()
+	images.AddSSEBody(stream)
+	require.Equal(t, 1, images.Count())
+}
+
+// TestOpenAIImageOutputCounterFallsBackForInvalidMultilineSSEBody puts two complete JSON objects on adjacent data: lines and expects both images.
+func TestOpenAIImageOutputCounterFallsBackForInvalidMultilineSSEBody(t *testing.T) {
+	const stream = "data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-a\"}\n" +
+		"data: {\"type\":\"image_generation.completed\",\"b64_json\":\"final-b\"}\n\n"
+	images := newOpenAIImageOutputCounter()
+	images.AddSSEBody(stream)
+	require.Equal(t, 2, images.Count())
+}
diff --git a/backend/internal/service/image_output_accounting.go b/backend/internal/service/image_output_accounting.go
new file mode 100644
index 00000000..219c0c59
--- /dev/null
+++ b/backend/internal/service/image_output_accounting.go
@@ -0,0 +1,149 @@
+package service
+
+import (
+ "crypto/sha256"
+ "encoding/hex"
+ "strings"
+
+ "github.com/tidwall/gjson"
+)
+
+type openAIImageOutputCounter struct { // tallies final image outputs found in OpenAI response payloads
+	seen         map[string]struct{} // dedupe keys of the output items already counted
+	count        int                 // number of distinct output items counted
+	maxDataCount int                 // longest "data" array observed so far
+}
+
+func newOpenAIImageOutputCounter() *openAIImageOutputCounter { // seen is allocated up front so items can be recorded immediately
+	return &openAIImageOutputCounter{seen: map[string]struct{}{}}
+}
+
+// Count returns the number of images observed so far: the larger of the
+// deduplicated output-item tally and the longest "data" array seen.
+// A nil counter reports zero.
+func (c *openAIImageOutputCounter) Count() int {
+	if c == nil {
+		return 0
+	}
+	return max(c.maxDataCount, c.count)
+}
+
+// AddJSONResponse tallies a JSON body: its "data" array plus items under "output" and "response.output"; nil, empty or invalid input is ignored.
+func (c *openAIImageOutputCounter) AddJSONResponse(body []byte) {
+	if c != nil && len(body) > 0 && gjson.ValidBytes(body) {
+		c.addDataArray(gjson.GetBytes(body, "data"))
+		c.addOutputArray(gjson.GetBytes(body, "output"))
+		c.addOutputArray(gjson.GetBytes(body, "response.output"))
+	}
+}
+
+// AddSSEData tallies images carried by one SSE data payload. Empty input, the "[DONE]" sentinel and
+// invalid JSON are ignored; a "data" array is always inspected, other fields depend on the event "type".
+func (c *openAIImageOutputCounter) AddSSEData(data []byte) {
+	if c == nil || len(data) == 0 || strings.TrimSpace(string(data)) == "[DONE]" || !gjson.ValidBytes(data) {
+		return
+	}
+	event := gjson.ParseBytes(data)
+	c.addDataArray(event.Get("data"))
+	switch strings.TrimSpace(event.Get("type").String()) {
+	case "response.output_item.done":
+		c.addImageOutputItem(event.Get("item"))
+	case "response.completed", "response.done":
+		c.addOutputArray(event.Get("response.output"))
+	case "image_generation.completed":
+		// Prefer a nested "item", then "output"; otherwise the event itself is the image.
+		payload := event
+		if item := event.Get("item"); item.Exists() {
+			payload = item
+		} else if output := event.Get("output"); output.Exists() {
+			payload = output
+		}
+		c.addImageOutputItem(payload)
+	}
+}
+
+// AddSSEBody hands a whole SSE body to forEachOpenAISSEDataPayload with AddSSEData as the callback; blank bodies are ignored.
+func (c *openAIImageOutputCounter) AddSSEBody(body string) {
+	if c != nil && strings.TrimSpace(body) != "" {
+		forEachOpenAISSEDataPayload(body, c.AddSSEData)
+	}
+}
+
+// addDataArray raises maxDataCount to the length of data when data is a longer array; lengths are not summed.
+func (c *openAIImageOutputCounter) addDataArray(data gjson.Result) {
+	if !data.IsArray() {
+		return
+	}
+	if n := len(data.Array()); n > c.maxDataCount {
+		c.maxDataCount = n
+	}
+}
+
+// addOutputArray runs addImageOutputItem over every element of output; non-array values are ignored.
+func (c *openAIImageOutputCounter) addOutputArray(output gjson.Result) {
+	if !output.IsArray() {
+		return
+	}
+	for _, element := range output.Array() {
+		c.addImageOutputItem(element)
+	}
+}
+
+// addImageOutputItem counts item once if it is a final image: an object typed "", "image_generation_call" or
+// "image_generation.completed", without "partial_image" in its raw JSON, deduplicated by id, call_id or a hash of its payload.
+func (c *openAIImageOutputCounter) addImageOutputItem(item gjson.Result) {
+	if !item.Exists() || !item.IsObject() {
+		return
+	}
+	kind := strings.TrimSpace(item.Get("type").String())
+	if kind != "" && kind != "image_generation_call" && kind != "image_generation.completed" {
+		return
+	}
+	if strings.Contains(strings.ToLower(item.Raw), "partial_image") {
+		return
+	}
+	var payload string // first non-blank of result, b64_json, url
+	for _, field := range []string{"result", "b64_json", "url"} {
+		if payload = strings.TrimSpace(item.Get(field).String()); payload != "" {
+			break
+		}
+	}
+	if payload == "" && kind != "image_generation.completed" {
+		return
+	}
+	var key string // first non-blank of id, call_id; else the payload hash
+	for _, field := range []string{"id", "call_id"} {
+		if key = strings.TrimSpace(item.Get(field).String()); key != "" {
+			break
+		}
+	}
+	if key == "" {
+		key = hashOpenAIImageOutputResult(payload)
+	}
+	if _, dup := c.seen[key]; key == "" || dup {
+		return
+	}
+	c.seen[key] = struct{}{}
+	c.count++
+}
+
+// hashOpenAIImageOutputResult returns the hex-encoded SHA-256 of result with surrounding
+// whitespace trimmed, or "" when nothing remains after trimming.
+func hashOpenAIImageOutputResult(result string) string {
+	if trimmed := strings.TrimSpace(result); trimmed != "" {
+		digest := sha256.Sum256([]byte(trimmed))
+		return hex.EncodeToString(digest[:])
+	}
+	return ""
+}
+
+func countOpenAIResponseImageOutputsFromJSONBytes(body []byte) int { // one-shot count for a non-streaming JSON body
+	images := newOpenAIImageOutputCounter()
+	images.AddJSONResponse(body)
+	return images.Count()
+}
+
+func countOpenAIImageOutputsFromSSEBody(body string) int { // one-shot count for a complete SSE body
+	images := newOpenAIImageOutputCounter()
+	images.AddSSEBody(body)
+	return images.Count()
+}
diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go
index 7a0a6636..f29d3607 100644
--- a/backend/internal/service/openai_account_scheduler.go
+++ b/backend/internal/service/openai_account_scheduler.go
@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"hash/fnv"
+ "log/slog"
"math"
"sort"
"strconv"
@@ -345,7 +346,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
_ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
return nil, nil
}
- if !s.isAccountRequestCompatible(account, req) {
+ if !s.isAccountRequestCompatible(ctx, account, req) {
return nil, nil
}
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
@@ -621,7 +622,7 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
}
- if !s.isAccountRequestCompatible(account, req) {
+ if !s.isAccountRequestCompatible(ctx, account, req) {
continue
}
if !s.isAccountTransportCompatible(account, req.RequiredTransport) {
@@ -822,11 +823,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
- if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
+ if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
- if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
+ if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
@@ -853,11 +854,11 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
// WaitPlan.MaxConcurrency 使用 Concurrency(非 EffectiveLoadFactor),因为 WaitPlan 控制的是 Redis 实际并发槽位等待。
for _, candidate := range selectionOrder {
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
- if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
+ if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
- if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(fresh, req) {
+ if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
@@ -888,13 +889,18 @@ func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Ac
return s.service.isOpenAIAccountTransportCompatible(account, requiredTransport)
}
-func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Account, req OpenAIAccountScheduleRequest) bool {
+func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.Context, account *Account, req OpenAIAccountScheduleRequest) bool {
if account == nil {
return false
}
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
return false
}
+ if req.GroupID != nil && s != nil && s.service != nil &&
+ s.service.needsUpstreamChannelRestrictionCheck(ctx, req.GroupID) &&
+ s.service.isUpstreamModelRestrictedByChannel(ctx, *req.GroupID, account, req.RequestedModel, req.RequireCompact) {
+ return false
+ }
return account.SupportsOpenAIImageCapability(req.RequiredImageCapability)
}
@@ -1106,6 +1112,13 @@ func (s *OpenAIGatewayService) selectAccountWithScheduler(
}
}
+ if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
+ slog.Warn("channel pricing restriction blocked request",
+ "group_id", derefGroupID(groupID),
+ "model", requestedModel)
+ return nil, decision, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
+ }
+
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil && accountID > 0 {
diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go
index de98b50d..f96bf81f 100644
--- a/backend/internal/service/openai_codex_transform.go
+++ b/backend/internal/service/openai_codex_transform.go
@@ -485,12 +485,14 @@ func normalizeKnownCodexModel(model string) (string, bool) {
return model, true
}
- modelID := model
- if strings.Contains(modelID, "/") {
- parts := strings.Split(modelID, "/")
- modelID = parts[len(parts)-1]
- }
+ modelID := lastOpenAIModelSegment(model)
+ if normalized := canonicalizeOpenAIModelAliasSpelling(modelID); normalized != "" {
+ modelID = normalized
+ }
+ if mapped := normalizeKnownOpenAICodexModel(modelID); mapped != "" {
+ return mapped, true
+ }
key := codexModelLookupKey(modelID)
if key == "" {
return "", false
diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go
index 87bb7162..3add4779 100644
--- a/backend/internal/service/openai_codex_transform_test.go
+++ b/backend/internal/service/openai_codex_transform_test.go
@@ -804,15 +804,25 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
func TestNormalizeCodexModel_Gpt53(t *testing.T) {
cases := map[string]string{
"gpt-5.4": "gpt-5.4",
+ "gpt5.5": "gpt-5.5",
+ "openai/gpt5.5": "gpt-5.5",
+ "gpt5.4": "gpt-5.4",
"gpt-5.4-high": "gpt-5.4",
"gpt-5.4-chat-latest": "gpt-5.4",
"gpt 5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini",
+ "gpt5.4-mini": "gpt-5.4-mini",
+ "gpt5.4mini": "gpt-5.4-mini",
"gpt 5.4 mini": "gpt-5.4-mini",
"gpt-5.3": "gpt-5.3-codex",
+ "gpt5.3": "gpt-5.3-codex",
"gpt-5.3-codex": "gpt-5.3-codex",
+ "gpt5.3-codex": "gpt-5.3-codex",
+ "gpt5.3codex": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
"gpt-5.3-codex-spark": "gpt-5.3-codex-spark",
+ "gpt5.3-codex-spark": "gpt-5.3-codex-spark",
+ "gpt5.3codexspark": "gpt-5.3-codex-spark",
"gpt 5.3 codex spark": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark",
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark",
diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go
index 4722c82d..3791c5a8 100644
--- a/backend/internal/service/openai_gateway_record_usage_test.go
+++ b/backend/internal/service/openai_gateway_record_usage_test.go
@@ -52,6 +52,12 @@ func (s *openAIRecordUsageBillingRepoStub) Apply(ctx context.Context, cmd *Usage
return &UsageBillingApplyResult{Applied: true}, nil
}
+func TestOpenAIGatewayServiceRecordUsage_RejectsNilInput(t *testing.T) { // a nil input and an input without fields must both fail
+	gateway := &OpenAIGatewayService{}
+	require.Error(t, gateway.RecordUsage(context.Background(), nil))
+	require.Error(t, gateway.RecordUsage(context.Background(), &OpenAIRecordUsageInput{}))
+}
+
type openAIRecordUsageUserRepoStub struct {
UserRepository
@@ -1081,6 +1087,101 @@ func TestOpenAIGatewayServiceRecordUsage_ChannelMappedOverridesBillingModelWhenM
require.True(t, usageRepo.lastLog.ActualCost > 0, "cost must not be zero")
}
+// TestOpenAIGatewayServiceRecordUsage_BillsCompactOpenAIModelAlias expects the alias "gpt5.5" to cost the same as "gpt-5.5" while the log keeps the requested and upstream names.
+func TestOpenAIGatewayServiceRecordUsage_BillsCompactOpenAIModelAlias(t *testing.T) {
+	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+	users := &openAIRecordUsageUserRepoStub{}
+	subs := &openAIRecordUsageSubRepoStub{}
+	gateway := newOpenAIRecordUsageServiceForTest(logRepo, users, subs, nil)
+	tokens := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
+
+	want, err := gateway.billingService.CalculateCost("gpt-5.5", UsageTokens{
+		InputTokens:  20,
+		OutputTokens: 10,
+	}, 1.1)
+	require.NoError(t, err)
+	err = gateway.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+		Result: &OpenAIForwardResult{
+			RequestID:     "resp_compact_openai_alias",
+			Model:         "gpt5.5",
+			UpstreamModel: "gpt-5.4",
+			Usage:         tokens,
+			Duration:      time.Second,
+		},
+		APIKey:  &APIKey{ID: 10},
+		User:    &User{ID: 20},
+		Account: &Account{ID: 30},
+	})
+
+	require.NoError(t, err)
+	require.NotNil(t, logRepo.lastLog)
+	require.Equal(t, "gpt5.5", logRepo.lastLog.Model)
+	require.NotNil(t, logRepo.lastLog.UpstreamModel)
+	require.Equal(t, "gpt-5.4", *logRepo.lastLog.UpstreamModel)
+	require.InDelta(t, want.ActualCost, logRepo.lastLog.ActualCost, 1e-12)
+	require.True(t, logRepo.lastLog.ActualCost > 0, "cost must not be zero")
+	require.InDelta(t, want.ActualCost, users.lastAmount, 1e-12)
+}
+
+// TestOpenAIGatewayServiceRecordUsage_FallsBackToUpstreamModelWhenPrimaryUnpriceable expects the "gpt-5.4" upstream price when the requested and billing model names cannot be priced.
+func TestOpenAIGatewayServiceRecordUsage_FallsBackToUpstreamModelWhenPrimaryUnpriceable(t *testing.T) {
+	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+	users := &openAIRecordUsageUserRepoStub{}
+	subs := &openAIRecordUsageSubRepoStub{}
+	gateway := newOpenAIRecordUsageServiceForTest(logRepo, users, subs, nil)
+	tokens := OpenAIUsage{InputTokens: 20, OutputTokens: 10}
+
+	want, err := gateway.billingService.CalculateCost("gpt-5.4", UsageTokens{
+		InputTokens:  20,
+		OutputTokens: 10,
+	}, 1.1)
+	require.NoError(t, err)
+	err = gateway.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+		Result: &OpenAIForwardResult{
+			RequestID:     "resp_unpriceable_primary_upstream_fallback",
+			Model:         "not-priceable-alias",
+			BillingModel:  "not-priceable-alias",
+			UpstreamModel: "gpt-5.4",
+			Usage:         tokens,
+			Duration:      time.Second,
+		},
+		APIKey:  &APIKey{ID: 10},
+		User:    &User{ID: 20},
+		Account: &Account{ID: 30},
+	})
+
+	require.NoError(t, err)
+	require.NotNil(t, logRepo.lastLog)
+	require.InDelta(t, want.ActualCost, logRepo.lastLog.ActualCost, 1e-12)
+	require.True(t, logRepo.lastLog.ActualCost > 0, "cost must not be zero")
+	require.InDelta(t, want.ActualCost, users.lastAmount, 1e-12)
+}
+
+// TestOpenAIGatewayServiceRecordUsage_ReturnsErrorWhenTokenModelCannotBePriced expects an error and no log, deduction or subscription calls when no upstream model is available as a pricing fallback.
+func TestOpenAIGatewayServiceRecordUsage_ReturnsErrorWhenTokenModelCannotBePriced(t *testing.T) {
+	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+	users := &openAIRecordUsageUserRepoStub{}
+	subs := &openAIRecordUsageSubRepoStub{}
+	gateway := newOpenAIRecordUsageServiceForTest(logRepo, users, subs, nil)
+	err := gateway.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+		Result: &OpenAIForwardResult{
+			RequestID: "resp_unpriceable_without_upstream",
+			Model:     "not-priceable-alias",
+			Usage:     OpenAIUsage{InputTokens: 20, OutputTokens: 10},
+			Duration:  time.Second,
+		},
+		APIKey:  &APIKey{ID: 10},
+		User:    &User{ID: 20},
+		Account: &Account{ID: 30},
+	})
+
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "calculate OpenAI usage cost failed")
+	require.Equal(t, 0, logRepo.calls)
+	require.Equal(t, 0, users.deductCalls)
+	require.Equal(t, 0, subs.incrementCalls)
+}
+
func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFields(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
userRepo := &openAIRecordUsageUserRepoStub{}
@@ -1209,3 +1310,278 @@ func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTo
require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12)
require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12)
}
+
+// TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierPreservesExistingBehavior expects the group rate 0.15 to scale a 0.2 image price to 0.03 when ImageRateIndependent is false.
+func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierPreservesExistingBehavior(t *testing.T) {
+	unitPrice := 0.2
+	gid := int64(121)
+	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+	gateway := newOpenAIRecordUsageServiceForTest(logRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
+
+	err := gateway.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+		Result: &OpenAIForwardResult{
+			RequestID:  "resp_image_shared_multiplier",
+			Model:      "gpt-image-2",
+			ImageCount: 1,
+			ImageSize:  "1K",
+			Duration:   time.Second,
+		},
+		APIKey: &APIKey{
+			ID:      10121,
+			GroupID: i64p(gid),
+			Group: &Group{
+				ID:                   gid,
+				RateMultiplier:       0.15,
+				ImageRateIndependent: false,
+				ImageRateMultiplier:  1,
+				ImagePrice1K:         &unitPrice,
+			},
+		},
+		User:    &User{ID: 20121},
+		Account: &Account{ID: 30121},
+	})
+
+	require.NoError(t, err)
+	require.NotNil(t, logRepo.lastLog)
+	require.InDelta(t, 0.2, logRepo.lastLog.TotalCost, 1e-12)
+	require.InDelta(t, 0.03, logRepo.lastLog.ActualCost, 1e-12)
+	require.InDelta(t, 0.15, logRepo.lastLog.RateMultiplier, 1e-12)
+	require.NotNil(t, logRepo.lastLog.BillingMode)
+	require.Equal(t, string(BillingModeImage), *logRepo.lastLog.BillingMode)
+}
+
+// TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierUsesUserGroupOverride expects the per-user group rate 0.2 to replace the group rate 0.15 for shared-multiplier image billing.
+func TestOpenAIGatewayServiceRecordUsage_ImageSharedMultiplierUsesUserGroupOverride(t *testing.T) {
+	unitPrice := 0.5
+	userRate := 0.2
+	gid := int64(125)
+	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+	gateway := newOpenAIRecordUsageServiceForTest(
+		logRepo,
+		&openAIRecordUsageUserRepoStub{},
+		&openAIRecordUsageSubRepoStub{},
+		&openAIUserGroupRateRepoStub{rate: &userRate},
+	)
+
+	err := gateway.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+		Result: &OpenAIForwardResult{
+			RequestID:  "resp_image_user_group_override",
+			Model:      "gpt-image-2",
+			ImageCount: 1,
+			ImageSize:  "1K",
+			Duration:   time.Second,
+		},
+		APIKey: &APIKey{
+			ID:      10125,
+			GroupID: i64p(gid),
+			Group: &Group{
+				ID:                   gid,
+				RateMultiplier:       0.15,
+				ImageRateIndependent: false,
+				ImageRateMultiplier:  1,
+				ImagePrice1K:         &unitPrice,
+			},
+		},
+		User:    &User{ID: 20125},
+		Account: &Account{ID: 30125},
+	})
+
+	require.NoError(t, err)
+	require.NotNil(t, logRepo.lastLog)
+	require.InDelta(t, 0.5, logRepo.lastLog.TotalCost, 1e-12)
+	require.InDelta(t, 0.1, logRepo.lastLog.ActualCost, 1e-12)
+	require.InDelta(t, 0.2, logRepo.lastLog.RateMultiplier, 1e-12)
+}
+
+// TestOpenAIGatewayServiceRecordUsage_ImageIndependentMultiplierUsesImageRate expects ImageRateMultiplier (1) rather than the group rate (0.15) when ImageRateIndependent is true.
+func TestOpenAIGatewayServiceRecordUsage_ImageIndependentMultiplierUsesImageRate(t *testing.T) {
+	unitPrice := 0.2
+	gid := int64(122)
+	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+	gateway := newOpenAIRecordUsageServiceForTest(logRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
+
+	err := gateway.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+		Result: &OpenAIForwardResult{
+			RequestID:  "resp_image_independent_multiplier",
+			Model:      "gpt-image-2",
+			ImageCount: 1,
+			ImageSize:  "1K",
+			Duration:   time.Second,
+		},
+		APIKey: &APIKey{
+			ID:      10122,
+			GroupID: i64p(gid),
+			Group: &Group{
+				ID:                   gid,
+				RateMultiplier:       0.15,
+				ImageRateIndependent: true,
+				ImageRateMultiplier:  1,
+				ImagePrice1K:         &unitPrice,
+			},
+		},
+		User:    &User{ID: 20122},
+		Account: &Account{ID: 30122},
+	})
+
+	require.NoError(t, err)
+	require.NotNil(t, logRepo.lastLog)
+	require.InDelta(t, 0.2, logRepo.lastLog.TotalCost, 1e-12)
+	require.InDelta(t, 0.2, logRepo.lastLog.ActualCost, 1e-12)
+	require.InDelta(t, 1.0, logRepo.lastLog.RateMultiplier, 1e-12)
+	require.NotNil(t, logRepo.lastLog.BillingMode)
+	require.Equal(t, string(BillingModeImage), *logRepo.lastLog.BillingMode)
+}
+
+// TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndSharedMultiplier expects a 0.25 channel price times three images, scaled by the shared group rate 0.15.
+func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndSharedMultiplier(t *testing.T) {
+	gid := int64(123)
+	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+	gateway := newOpenAIRecordUsageServiceForTest(logRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
+	gateway.resolver = newOpenAIImageChannelPricingResolverForTest(t, gid, "gpt-image-2", 0.25)
+	err := gateway.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+		Result: &OpenAIForwardResult{
+			RequestID:  "resp_image_channel_shared",
+			Model:      "gpt-image-2",
+			ImageCount: 3,
+			ImageSize:  "1K",
+			Duration:   time.Second,
+		},
+		APIKey: &APIKey{
+			ID:      10123,
+			GroupID: i64p(gid),
+			Group: &Group{
+				ID:                   gid,
+				RateMultiplier:       0.15,
+				ImageRateIndependent: false,
+				ImageRateMultiplier:  1,
+			},
+		},
+		User:    &User{ID: 20123},
+		Account: &Account{ID: 30123},
+	})
+
+	require.NoError(t, err)
+	require.NotNil(t, logRepo.lastLog)
+	require.InDelta(t, 0.75, logRepo.lastLog.TotalCost, 1e-12)
+	require.InDelta(t, 0.1125, logRepo.lastLog.ActualCost, 1e-12)
+	require.InDelta(t, 0.15, logRepo.lastLog.RateMultiplier, 1e-12)
+	require.Equal(t, 3, logRepo.lastLog.ImageCount)
+	require.NotNil(t, logRepo.lastLog.BillingMode)
+	require.Equal(t, string(BillingModeImage), *logRepo.lastLog.BillingMode)
+}
+
+// TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndIndependentMultiplier expects a 0.25 channel price times three images, scaled by the independent image rate 1.
+func TestOpenAIGatewayServiceRecordUsage_ChannelImageBillingUsesImageCountAndIndependentMultiplier(t *testing.T) {
+	gid := int64(124)
+	logRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+	gateway := newOpenAIRecordUsageServiceForTest(logRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
+	gateway.resolver = newOpenAIImageChannelPricingResolverForTest(t, gid, "gpt-image-2", 0.25)
+	err := gateway.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+		Result: &OpenAIForwardResult{
+			RequestID:  "resp_image_channel_independent",
+			Model:      "gpt-image-2",
+			ImageCount: 3,
+			ImageSize:  "1K",
+			Duration:   time.Second,
+		},
+		APIKey: &APIKey{
+			ID:      10124,
+			GroupID: i64p(gid),
+			Group: &Group{
+				ID:                   gid,
+				RateMultiplier:       0.15,
+				ImageRateIndependent: true,
+				ImageRateMultiplier:  1,
+			},
+		},
+		User:    &User{ID: 20124},
+		Account: &Account{ID: 30124},
+	})
+
+	require.NoError(t, err)
+	require.NotNil(t, logRepo.lastLog)
+	require.InDelta(t, 0.75, logRepo.lastLog.TotalCost, 1e-12)
+	require.InDelta(t, 0.75, logRepo.lastLog.ActualCost, 1e-12)
+	require.InDelta(t, 1.0, logRepo.lastLog.RateMultiplier, 1e-12)
+	require.Equal(t, 3, logRepo.lastLog.ImageCount)
+	require.NotNil(t, logRepo.lastLog.BillingMode)
+	require.Equal(t, string(BillingModeImage), *logRepo.lastLog.BillingMode)
+}
+
+// newOpenAIImageChannelPricingResolverForTest builds a ModelPricingResolver whose channel cache holds one
+// active channel for groupID with image-mode, per-request pricing of price for model.
+func newOpenAIImageChannelPricingResolverForTest(t *testing.T, groupID int64, model string, price float64) *ModelPricingResolver {
+	t.Helper()
+	cache := newEmptyChannelCache()
+	pricing := &ChannelModelPricing{BillingMode: BillingModeImage, PerRequestPrice: &price}
+	cache.pricingByGroupModel[channelModelKey{groupID: groupID, model: model}] = pricing
+	cache.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive}
+	cache.groupPlatform[groupID] = ""
+	cache.loadedAt = time.Now()
+	channels := &ChannelService{}
+	channels.cache.Store(cache)
+	return NewModelPricingResolver(channels, NewBillingService(&config.Config{}, nil))
+}
+
+// TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesImageCount expects two images at a 0.25 channel price to total 0.5 in image billing mode.
+func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesImageCount(t *testing.T) {
+	gid := int64(126)
+	billing := NewBillingService(&config.Config{}, nil)
+	gateway := &GatewayService{
+		billingService: billing,
+		resolver:       newOpenAIImageChannelPricingResolverForTest(t, gid, "gemini-image", 0.25),
+	}
+	got := gateway.calculateRecordUsageCost(
+		context.Background(),
+		&ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "1K"},
+		&APIKey{GroupID: i64p(gid), Group: &Group{ID: gid}},
+		"gemini-image",
+		0.15,
+		1.0,
+		nil,
+	)
+
+	require.NotNil(t, got)
+	require.Equal(t, string(BillingModeImage), got.BillingMode)
+	require.InDelta(t, 0.5, got.TotalCost, 1e-12)
+	require.InDelta(t, 0.5, got.ActualCost, 1e-12)
+}
+
+// TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesSizeTier expects the "4K" interval price (0.40) rather than the default price (0.10) for two 4K images.
+func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesSizeTier(t *testing.T) {
+	gid := int64(127)
+	basePrice := 0.10
+	tier4KPrice := 0.40
+	cache := newEmptyChannelCache()
+	cache.pricingByGroupModel[channelModelKey{groupID: gid, model: "gemini-image"}] = &ChannelModelPricing{
+		BillingMode:     BillingModeImage,
+		PerRequestPrice: &basePrice,
+		Intervals: []PricingInterval{{
+			TierLabel:       "4K",
+			PerRequestPrice: &tier4KPrice,
+		}},
+	}
+	cache.channelByGroupID[gid] = &Channel{ID: gid, Status: StatusActive}
+	cache.loadedAt = time.Now()
+	channels := &ChannelService{}
+	channels.cache.Store(cache)
+	gateway := &GatewayService{
+		billingService: NewBillingService(&config.Config{}, nil),
+		resolver:       NewModelPricingResolver(channels, NewBillingService(&config.Config{}, nil)),
+	}
+
+	got := gateway.calculateRecordUsageCost(
+		context.Background(),
+		&ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: "4K"},
+		&APIKey{GroupID: i64p(gid), Group: &Group{ID: gid}},
+		"gemini-image",
+		1.0,
+		1.0,
+		nil,
+	)
+
+	require.NotNil(t, got)
+	require.Equal(t, string(BillingModeImage), got.BillingMode)
+	require.InDelta(t, 0.80, got.TotalCost, 1e-12)
+	require.InDelta(t, 0.80, got.ActualCost, 1e-12)
+}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index b818fa4a..edd821ce 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -2049,6 +2049,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
promptCacheKey = strings.TrimSpace(v)
}
}
+ apiKey := getAPIKeyFromContext(c)
+ imageGenerationAllowed := GroupAllowsImageGeneration(nil)
+ if apiKey != nil {
+ imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group)
+ }
+ if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
+ setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
+ c.JSON(http.StatusForbidden, gin.H{
+ "error": gin.H{
+ "type": "permission_error",
+ "message": ImageGenerationPermissionMessage(),
+ },
+ })
+ return nil, errors.New("image generation disabled for group")
+ }
// Track if body needs re-serialization
bodyModified := false
@@ -2108,7 +2123,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
markPatchSet("instructions", "You are a helpful coding assistant.")
}
- if isCodexCLI && ensureOpenAIResponsesImageGenerationTool(reqBody) {
+ if isCodexCLI && imageGenerationAllowed && ensureOpenAIResponsesImageGenerationTool(reqBody) {
bodyModified = true
disablePatch()
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client")
@@ -2119,7 +2134,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
disablePatch()
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload")
}
- if isCodexCLI && applyCodexImageGenerationBridgeInstructions(reqBody) {
+ if isCodexCLI && imageGenerationAllowed && applyCodexImageGenerationBridgeInstructions(reqBody) {
bodyModified = true
disablePatch()
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions")
@@ -2134,7 +2149,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
markPatchSet("model", billingModel)
}
upstreamModel := billingModel
- if normalizeOpenAIResponsesImageOnlyModel(reqBody) {
+ if imageGenerationAllowed && normalizeOpenAIResponsesImageOnlyModel(reqBody) {
bodyModified = true
disablePatch()
if model, ok := reqBody["model"].(string); ok {
@@ -2355,6 +2370,34 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
}
+ if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed {
+ setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
+ c.JSON(http.StatusForbidden, gin.H{
+ "error": gin.H{
+ "type": "permission_error",
+ "message": ImageGenerationPermissionMessage(),
+ },
+ })
+ return nil, errors.New("image generation disabled for group")
+ }
+ imageBillingModel := ""
+ imageSizeTier := ""
+ if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) {
+ var imageCfgErr error
+ imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfig(reqBody, billingModel)
+ if imageCfgErr != nil {
+ setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "")
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": gin.H{
+ "type": "invalid_request_error",
+ "message": imageCfgErr.Error(),
+ "param": "size",
+ },
+ })
+ return nil, imageCfgErr
+ }
+ }
+
// Re-serialize body only if modified
if bodyModified {
serializedByPatch := false
@@ -2592,6 +2635,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
wsAttempts,
)
wsResult.UpstreamModel = upstreamModel
+ if wsResult.ImageCount > 0 {
+ wsResult.ImageSize = imageSizeTier
+ wsResult.BillingModel = imageBillingModel
+ }
return wsResult, nil
}
s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr)
@@ -2695,6 +2742,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Handle normal response
var usage *OpenAIUsage
var firstTokenMs *int
+ imageCount := 0
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel)
if err != nil {
@@ -2702,11 +2750,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
+ imageCount = streamResult.imageCount
} else {
- usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel)
+ nonStreamResult, err := s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel)
if err != nil {
return nil, err
}
+ usage = nonStreamResult.usage
+ imageCount = nonStreamResult.imageCount
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
@@ -2723,7 +2774,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
serviceTier := extractOpenAIServiceTier(reqBody)
- return &OpenAIForwardResult{
+ forwardResult := &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel,
@@ -2734,7 +2785,13 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
OpenAIWSMode: false,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
- }, nil
+ }
+ if imageCount > 0 {
+ forwardResult.ImageCount = imageCount
+ forwardResult.ImageSize = imageSizeTier
+ forwardResult.BillingModel = imageBillingModel
+ }
+ return forwardResult, nil
}
}
@@ -2823,6 +2880,35 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
}
body = updatedBody
+ apiKey := getAPIKeyFromContext(c)
+ if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) && !GroupAllowsImageGeneration(apiKeyGroup(apiKey)) {
+ setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "")
+ c.JSON(http.StatusForbidden, gin.H{
+ "error": gin.H{
+ "type": "permission_error",
+ "message": ImageGenerationPermissionMessage(),
+ },
+ })
+ return nil, errors.New("image generation disabled for group")
+ }
+ imageBillingModel := ""
+ imageSizeTier := ""
+ if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) {
+ var imageCfgErr error
+ imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(body, reqModel)
+ if imageCfgErr != nil {
+ setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "")
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": gin.H{
+ "type": "invalid_request_error",
+ "message": imageCfgErr.Error(),
+ "param": "size",
+ },
+ })
+ return nil, imageCfgErr
+ }
+ }
+
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
account.ID,
@@ -2905,6 +2991,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
var usage *OpenAIUsage
var firstTokenMs *int
+ imageCount := 0
if reqStream {
result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel)
if err != nil {
@@ -2912,11 +2999,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
}
usage = result.usage
firstTokenMs = result.firstTokenMs
+ imageCount = result.imageCount
} else {
- usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel)
+ result, err := s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel)
if err != nil {
return nil, err
}
+ usage = result.usage
+ imageCount = result.imageCount
}
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
@@ -2927,7 +3017,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
usage = &OpenAIUsage{}
}
- return &OpenAIForwardResult{
+ forwardResult := &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: reqModel,
@@ -2938,7 +3028,13 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
OpenAIWSMode: false,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
- }, nil
+ }
+ if imageCount > 0 {
+ forwardResult.ImageCount = imageCount
+ forwardResult.ImageSize = imageSizeTier
+ forwardResult.BillingModel = imageBillingModel
+ }
+ return forwardResult, nil
}
func logOpenAIPassthroughInstructionsRejected(
@@ -3233,6 +3329,13 @@ func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string {
type openaiStreamingResultPassthrough struct {
usage *OpenAIUsage
firstTokenMs *int
+ imageCount int
+}
+
+type openaiNonStreamingResultPassthrough struct {
+ *OpenAIUsage
+ usage *OpenAIUsage
+ imageCount int
}
func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool {
@@ -3369,6 +3472,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
}
usage := &OpenAIUsage{}
+ imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int
clientDisconnected := false
sawDone := false
@@ -3400,6 +3504,9 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
defer putSSEScannerBuf64K(scanBuf)
needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel)
+ resultWithUsage := func() *openaiStreamingResultPassthrough {
+ return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()}
+ }
for scanner.Scan() {
line := scanner.Text()
@@ -3419,7 +3526,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
if eventType == "response.failed" {
failedMessage = extractOpenAISSEErrorMessage(dataBytes)
if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
+ return resultWithUsage(),
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage)
}
forceFlushFailedEvent = true
@@ -3431,6 +3538,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
if openAIStreamEventIsTerminal(trimmedData) {
sawTerminalEvent = true
}
+ imageCounter.AddSSEData(dataBytes)
lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType)
if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
@@ -3460,28 +3568,28 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
}
if err := scanner.Err(); err != nil {
if sawTerminalEvent && !sawFailedEvent {
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
+ return resultWithUsage(), nil
}
if sawFailedEvent {
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
+ return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete: %w", err)
+ return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
}
if errors.Is(err, bufio.ErrTooLong) {
logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err
+ return resultWithUsage(), err
}
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
msg := "OpenAI stream disconnected before completion"
if errText := strings.TrimSpace(err.Error()); errText != "" {
msg += ": " + errText
}
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
+ return resultWithUsage(),
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg)
}
if clientDisconnected {
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream usage incomplete after disconnect: %w", err)
+ return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", err)
}
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
@@ -3489,10 +3597,10 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
upstreamRequestID,
err,
)
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
+ return resultWithUsage(), fmt.Errorf("stream read error: %w", err)
}
if sawFailedEvent {
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("upstream response failed: %s", failedMessage)
+ return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
}
if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
logger.FromContext(ctx).With(
@@ -3501,13 +3609,13 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
zap.String("upstream_request_id", upstreamRequestID),
).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs},
+ return resultWithUsage(),
s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event")
}
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, errors.New("stream usage incomplete: missing terminal event")
+ return resultWithUsage(), errors.New("stream usage incomplete: missing terminal event")
}
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil
+ return resultWithUsage(), nil
}
func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
@@ -3516,7 +3624,7 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
c *gin.Context,
originalModel string,
mappedModel string,
-) (*OpenAIUsage, error) {
+) (*openaiNonStreamingResultPassthrough, error) {
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
return nil, err
@@ -3553,14 +3661,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
c.Data(resp.StatusCode, contentType, body)
- return usage, nil
+ return &openaiNonStreamingResultPassthrough{
+ OpenAIUsage: usage,
+ usage: usage,
+ imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
+ }, nil
}
// handlePassthroughSSEToJSON converts an SSE response body into a JSON
// response for the passthrough path. It mirrors handleSSEToJSON while
// preserving passthrough payloads, except compact-only model remapping may
// rewrite model fields back to the original requested model.
-func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*OpenAIUsage, error) {
+func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*openaiNonStreamingResultPassthrough, error) {
bodyText := string(body)
finalResponse, ok := extractCodexFinalResponse(bodyText)
@@ -3611,7 +3723,11 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c
}
c.Data(resp.StatusCode, contentType, body)
- return usage, nil
+ return &openaiNonStreamingResultPassthrough{
+ OpenAIUsage: usage,
+ usage: usage,
+ imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
+ }, nil
}
func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) {
@@ -4025,6 +4141,13 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
type openaiStreamingResult struct {
usage *OpenAIUsage
firstTokenMs *int
+ imageCount int
+}
+
+type openaiNonStreamingResult struct {
+ *OpenAIUsage
+ usage *OpenAIUsage
+ imageCount int
}
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
@@ -4058,6 +4181,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
usage := &OpenAIUsage{}
+ imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
@@ -4136,7 +4260,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
needModelReplace := originalModel != mappedModel
resultWithUsage := func() *openaiStreamingResult {
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}
+ return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()}
}
finalizeStream := func() (*openaiStreamingResult, error) {
if !sawTerminalEvent {
@@ -4231,6 +4355,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
forceFlushFailedEvent = true
sawFailedEvent = true
}
+ imageCounter.AddSSEData(dataBytes)
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
@@ -4496,7 +4621,7 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
}, true
}
-func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
+func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*openaiNonStreamingResult, error) {
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
return nil, err
@@ -4542,7 +4667,11 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
c.Data(resp.StatusCode, contentType, body)
- return usage, nil
+ return &openaiNonStreamingResult{
+ OpenAIUsage: usage,
+ usage: usage,
+ imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
+ }, nil
}
func isEventStreamResponse(header http.Header) bool {
@@ -4550,7 +4679,7 @@ func isEventStreamResponse(header http.Header) bool {
return strings.Contains(contentType, "text/event-stream")
}
-func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
+func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*openaiNonStreamingResult, error) {
bodyText := string(body)
finalResponse, ok := extractCodexFinalResponse(bodyText)
@@ -4602,21 +4731,29 @@ func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Conte
}
c.Data(resp.StatusCode, contentType, body)
- return usage, nil
+ return &openaiNonStreamingResult{
+ OpenAIUsage: usage,
+ usage: usage,
+ imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
+ }, nil
}
func extractOpenAISSETerminalEvent(body string) (string, []byte, bool) {
- lines := strings.Split(body, "\n")
- for _, line := range lines {
- data, ok := extractOpenAISSEDataLine(line)
- if !ok || data == "" || data == "[DONE]" {
- continue
+ var terminalType string
+ var terminalPayload []byte
+ forEachOpenAISSEDataPayload(body, func(data []byte) {
+ if terminalPayload != nil {
+ return
}
- eventType := strings.TrimSpace(gjson.Get(data, "type").String())
+ eventType := strings.TrimSpace(gjson.GetBytes(data, "type").String())
switch eventType {
case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
- return eventType, []byte(data), true
+ terminalType = eventType
+ terminalPayload = append([]byte(nil), data...)
}
+ })
+ if terminalPayload != nil {
+ return terminalType, terminalPayload, true
}
return "", nil, false
}
@@ -4651,21 +4788,20 @@ func (s *OpenAIGatewayService) writeOpenAINonStreamingProtocolError(resp *http.R
}
func extractCodexFinalResponse(body string) ([]byte, bool) {
- lines := strings.Split(body, "\n")
- for _, line := range lines {
- data, ok := extractOpenAISSEDataLine(line)
- if !ok {
- continue
+ var finalResponse []byte
+ forEachOpenAISSEDataPayload(body, func(data []byte) {
+ if finalResponse != nil {
+ return
}
- if data == "" || data == "[DONE]" {
- continue
- }
- eventType := gjson.Get(data, "type").String()
+ eventType := gjson.GetBytes(data, "type").String()
if eventType == "response.done" || eventType == "response.completed" {
- if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" {
- return []byte(response.Raw), true
+ if response := gjson.GetBytes(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" {
+ finalResponse = []byte(response.Raw)
}
}
+ })
+ if finalResponse != nil {
+ return finalResponse, true
}
return nil, false
}
@@ -4677,21 +4813,15 @@ func reconstructResponseOutputFromSSE(bodyText string) ([]byte, bool) {
acc := apicompat.NewBufferedResponseAccumulator()
imageOutputs := make([]json.RawMessage, 0, 1)
seenImages := make(map[string]struct{})
- lines := strings.Split(bodyText, "\n")
- for _, line := range lines {
- data, ok := extractOpenAISSEDataLine(line)
- if !ok || data == "" || data == "[DONE]" {
- continue
- }
- if imageOutput, ok := extractImageGenerationOutputFromSSEData([]byte(data), seenImages); ok {
+ forEachOpenAISSEDataPayload(bodyText, func(data []byte) {
+ if imageOutput, ok := extractImageGenerationOutputFromSSEData(data, seenImages); ok {
imageOutputs = append(imageOutputs, imageOutput)
}
var event apicompat.ResponsesStreamEvent
- if err := json.Unmarshal([]byte(data), &event); err != nil {
- continue
+ if err := json.Unmarshal(data, &event); err == nil {
+ acc.ProcessEvent(&event)
}
- acc.ProcessEvent(&event)
- }
+ })
if !acc.HasContent() && len(imageOutputs) == 0 {
return nil, false
}
@@ -4744,17 +4874,9 @@ func extractImageGenerationOutputFromSSEData(data []byte, seen map[string]struct
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
usage := &OpenAIUsage{}
- lines := strings.Split(body, "\n")
- for _, line := range lines {
- data, ok := extractOpenAISSEDataLine(line)
- if !ok {
- continue
- }
- if data == "" || data == "[DONE]" {
- continue
- }
- s.parseSSEUsageBytes([]byte(data), usage)
- }
+ forEachOpenAISSEDataPayload(body, func(data []byte) {
+ s.parseSSEUsageBytes(data, usage)
+ })
return usage
}
@@ -5036,8 +5158,14 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
+ if input == nil {
+ return errors.New("openai usage input is nil")
+ }
result := input.Result
- if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
+ if result == nil {
+ return errors.New("openai usage result is nil")
+ }
+ if s.rateLimitService != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
}
@@ -5074,6 +5202,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
}
multiplier = resolver.Resolve(ctx, user.ID, *apiKey.GroupID, apiKey.Group.RateMultiplier)
}
+ imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier)
var cost *CostBreakdown
var err error
@@ -5087,13 +5216,21 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
billingModel = input.OriginalModel
}
+ billingModels := usageBillingModelCandidates(
+ billingModel,
+ result.BillingModel,
+ input.ChannelMappedModel,
+ input.OriginalModel,
+ result.UpstreamModel,
+ result.Model,
+ )
serviceTier := ""
if result.ServiceTier != nil {
serviceTier = strings.TrimSpace(*result.ServiceTier)
}
- cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, tokens, serviceTier)
+ cost, err = s.calculateOpenAIRecordUsageCost(ctx, result, apiKey, billingModels, multiplier, imageMultiplier, tokens, serviceTier)
if err != nil {
- cost = &CostBreakdown{ActualCost: 0}
+ return err
}
// Determine billing type
@@ -5143,7 +5280,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.TotalCost = cost.TotalCost
usageLog.ActualCost = cost.ActualCost
}
- usageLog.RateMultiplier = multiplier
+ if result.ImageCount > 0 {
+ usageLog.RateMultiplier = imageMultiplier
+ } else {
+ usageLog.RateMultiplier = multiplier
+ }
usageLog.AccountRateMultiplier = &accountRateMultiplier
usageLog.BillingType = billingType
usageLog.Stream = result.Stream
@@ -5224,14 +5365,45 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
ctx context.Context,
result *OpenAIForwardResult,
apiKey *APIKey,
+ billingModels []string,
+ multiplier float64,
+ imageMultiplier float64,
+ tokens UsageTokens,
+ serviceTier string,
+) (*CostBreakdown, error) {
+ billingModel := firstUsageBillingModel(billingModels)
+ if result != nil && result.ImageCount > 0 {
+ return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, imageMultiplier), nil
+ }
+ if len(billingModels) == 0 || billingModel == "" {
+ return nil, errors.New("openai usage billing model is empty")
+ }
+ var lastErr error
+ for _, candidate := range billingModels {
+ candidate = strings.TrimSpace(candidate)
+ if candidate == "" {
+ continue
+ }
+ cost, err := s.calculateOpenAIRecordUsageTokenCost(ctx, apiKey, candidate, multiplier, tokens, serviceTier)
+ if err == nil {
+ return cost, nil
+ }
+ lastErr = err
+ }
+ if lastErr == nil {
+ lastErr = errors.New("no non-empty billing model candidates")
+ }
+ return nil, fmt.Errorf("calculate OpenAI usage cost failed for billing models %s: %w", strings.Join(billingModels, ","), lastErr)
+}
+
+func (s *OpenAIGatewayService) calculateOpenAIRecordUsageTokenCost(
+ ctx context.Context,
+ apiKey *APIKey,
billingModel string,
multiplier float64,
tokens UsageTokens,
serviceTier string,
) (*CostBreakdown, error) {
- if result != nil && result.ImageCount > 0 {
- return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil
- }
if s.resolver != nil && apiKey.Group != nil {
gid := apiKey.Group.ID
return s.billingService.CalculateCostUnified(CostInput{
@@ -5262,7 +5434,7 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost(
Ctx: ctx,
Model: billingModel,
GroupID: &gid,
- RequestCount: 1,
+ RequestCount: result.ImageCount,
SizeTier: result.ImageSize,
RateMultiplier: multiplier,
Resolver: s.resolver,
diff --git a/backend/internal/service/openai_image_generation_controls_test.go b/backend/internal/service/openai_image_generation_controls_test.go
new file mode 100644
index 00000000..76dc8053
--- /dev/null
+++ b/backend/internal/service/openai_image_generation_controls_test.go
@@ -0,0 +1,215 @@
+package service
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestOpenAIGatewayServiceForward_RejectsDisabledImageGenerationIntents(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ tests := []struct {
+ name string
+ body []byte
+ }{
+ {
+ name: "image model",
+ body: []byte(`{"model":"gpt-image-2","input":"draw"}`),
+ },
+ {
+ name: "image tool",
+ body: []byte(`{"model":"gpt-5.4","input":"draw","tools":[{"type":"image_generation"}]}`),
+ },
+ {
+ name: "image tool choice",
+ body: []byte(`{"model":"gpt-5.4","input":"draw","tool_choice":{"type":"image_generation"}}`),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ upstream := &httpUpstreamRecorder{}
+ svc := newOpenAIImageGenerationControlTestService(upstream)
+ c, recorder := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0")
+ account := newOpenAIImageGenerationControlTestAccount()
+
+ result, err := svc.Forward(context.Background(), c, account, tt.body)
+
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+ require.Equal(t, "permission_error", gjson.GetBytes(recorder.Body.Bytes(), "error.type").String())
+ require.Nil(t, upstream.lastReq, "disabled image request must not reach upstream")
+ })
+ }
+}
+
+func TestOpenAIGatewayServiceForward_DisabledGroupAllowsTextOnlyResponses(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"resp_text","model":"gpt-5.4","usage":{"input_tokens":3,"output_tokens":2}}`)),
+ },
+ }
+ svc := newOpenAIImageGenerationControlTestService(upstream)
+ c, recorder := newOpenAIImageGenerationControlTestContext(false, "unit-test-agent/1.0")
+ account := newOpenAIImageGenerationControlTestAccount()
+
+ result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`))
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, 3, result.Usage.InputTokens)
+ require.Equal(t, 2, result.Usage.OutputTokens)
+ require.Equal(t, 0, result.ImageCount)
+ require.NotNil(t, upstream.lastReq)
+}
+
+func TestOpenAIGatewayServiceForward_CodexImageInjectionRespectsGroupCapability(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ tests := []struct {
+ name string
+ allowImages bool
+ wantInjected bool
+ }{
+ {name: "disabled group skips injection", allowImages: false, wantInjected: false},
+ {name: "enabled group injects image tool", allowImages: true, wantInjected: true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"resp_codex","model":"gpt-5.4","usage":{"input_tokens":1,"output_tokens":1}}`)),
+ },
+ }
+ svc := newOpenAIImageGenerationControlTestService(upstream)
+ c, _ := newOpenAIImageGenerationControlTestContext(tt.allowImages, "codex_cli_rs/0.98.0")
+ account := newOpenAIImageGenerationControlTestAccount()
+
+ result, err := svc.Forward(context.Background(), c, account, []byte(`{"model":"gpt-5.4","input":"write code","stream":false}`))
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, upstream.lastReq)
+ hasImageTool := gjson.GetBytes(upstream.lastBody, `tools.#(type=="image_generation")`).Exists()
+ require.Equal(t, tt.wantInjected, hasImageTool)
+ instructions := gjson.GetBytes(upstream.lastBody, "instructions").String()
+ require.Equal(t, tt.wantInjected, strings.Contains(instructions, "image_generation"))
+ })
+ }
+}
+
+func TestOpenAIGatewayServiceHandleResponsesImageOutputs_NonStreaming(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
+ c, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0")
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{
+ "id":"resp_image_json",
+ "model":"gpt-5.4",
+ "output":[{"id":"ig_json_1","type":"image_generation_call","result":"final-image"}],
+ "usage":{"input_tokens":7,"output_tokens":3,"output_tokens_details":{"image_tokens":2}}
+ }`)),
+ }
+
+ result, err := svc.handleNonStreamingResponse(context.Background(), resp, c, &Account{ID: 1, Type: AccountTypeAPIKey}, "gpt-5.4", "gpt-5.4")
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, 1, result.imageCount)
+ require.NotNil(t, result.usage)
+ require.Equal(t, 7, result.usage.InputTokens)
+ require.Equal(t, 3, result.usage.OutputTokens)
+ require.Equal(t, 2, result.usage.ImageOutputTokens)
+}
+
+func TestOpenAIGatewayServiceHandleResponsesImageOutputs_Streaming(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ svc := newOpenAIImageGenerationControlTestService(&httpUpstreamRecorder{})
+ c, _ := newOpenAIImageGenerationControlTestContext(true, "unit-test-agent/1.0")
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}},
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}}\n\n" +
+ "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_image_stream\",\"model\":\"gpt-5.5\",\"output\":[{\"id\":\"ig_stream_1\",\"type\":\"image_generation_call\",\"result\":\"final-image\"}],\"usage\":{\"input_tokens\":11,\"output_tokens\":5,\"output_tokens_details\":{\"image_tokens\":4}}}}\n\n",
+ )),
+ }
+
+ result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "gpt-5.5", "gpt-5.5")
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, 1, result.imageCount)
+ require.NotNil(t, result.usage)
+ require.Equal(t, 11, result.usage.InputTokens)
+ require.Equal(t, 5, result.usage.OutputTokens)
+ require.Equal(t, 4, result.usage.ImageOutputTokens)
+}
+
+func newOpenAIImageGenerationControlTestService(upstream *httpUpstreamRecorder) *OpenAIGatewayService {
+ cfg := &config.Config{}
+ return &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: upstream,
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+}
+
+func newOpenAIImageGenerationControlTestContext(allowImages bool, userAgent string) (*gin.Context, *httptest.ResponseRecorder) {
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", userAgent)
+ groupID := int64(4242)
+ c.Set("api_key", &APIKey{
+ ID: 2424,
+ GroupID: &groupID,
+ Group: &Group{
+ ID: groupID,
+ AllowImageGeneration: allowImages,
+ RateMultiplier: 1,
+ ImageRateMultiplier: 1,
+ },
+ })
+ return c, recorder
+}
+
+func newOpenAIImageGenerationControlTestAccount() *Account {
+ return &Account{
+ ID: 5151,
+ Name: "openai-image-controls",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ }
+}
diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go
index 3da76525..04be5164 100644
--- a/backend/internal/service/openai_images.go
+++ b/backend/internal/service/openai_images.go
@@ -16,6 +16,7 @@ import (
"net/textproto"
"strconv"
"strings"
+ "sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -468,14 +469,54 @@ func isOpenAINativeImageOption(name string) bool {
}
func normalizeOpenAIImageSizeTier(size string) string {
- switch strings.ToLower(strings.TrimSpace(size)) {
+ trimmed := strings.TrimSpace(size)
+ normalized := strings.ToLower(trimmed)
+ switch normalized {
+ case "", "auto":
+ return "2K"
case "1024x1024":
return "1K"
- case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "", "auto":
+ case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "2048x2048", "2048x1152", "1152x2048":
return "2K"
- default:
+ case "3840x2160", "2160x3840":
+ return "4K"
+ }
+ width, height, ok := parseOpenAIImageSizeDimensions(trimmed)
+ if !ok {
return "2K"
}
+ return classifyUnknownOpenAIImageSizeTier(width, height)
+}
+
+const (
+ openAIImage2KMaxPixels = 2560 * 1440
+)
+
+func parseOpenAIImageSizeDimensions(size string) (int, int, bool) {
+ trimmed := strings.TrimSpace(size)
+ parts := strings.Split(strings.ToLower(trimmed), "x")
+ if len(parts) != 2 {
+ return 0, 0, false
+ }
+ width, err := strconv.Atoi(strings.TrimSpace(parts[0]))
+ if err != nil {
+ return 0, 0, false
+ }
+ height, err := strconv.Atoi(strings.TrimSpace(parts[1]))
+ if err != nil {
+ return 0, 0, false
+ }
+ if width <= 0 || height <= 0 {
+ return 0, 0, false
+ }
+ return width, height, true
+}
+
+func classifyUnknownOpenAIImageSizeTier(width int, height int) string {
+ if height > 0 && width > openAIImage2KMaxPixels/height {
+ return "4K"
+ }
+ return "2K"
}
func (s *OpenAIGatewayService) ForwardImages(
@@ -535,11 +576,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
setOpsUpstreamRequestBody(c, forwardBody)
}
- token, _, err := s.GetAccessToken(ctx, account)
+ upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream)
+ defer releaseUpstreamCtx()
+
+ token, _, err := s.GetAccessToken(upstreamCtx, account)
if err != nil {
return nil, err
}
- upstreamReq, err := s.buildOpenAIImagesRequest(ctx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint)
+ upstreamReq, err := s.buildOpenAIImagesRequest(upstreamCtx, c, account, forwardBody, forwardContentType, token, parsed.Endpoint)
if err != nil {
return nil, err
}
@@ -582,14 +626,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
Kind: "failover",
Message: upstreamMsg,
})
- s.handleFailoverSideEffects(ctx, resp, account)
+ s.handleFailoverSideEffects(upstreamCtx, resp, account)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
- return s.handleErrorResponse(ctx, resp, c, account, forwardBody)
+ return s.handleErrorResponse(upstreamCtx, resp, c, account, forwardBody)
}
defer func() { _ = resp.Body.Close() }()
@@ -599,6 +643,20 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
if parsed.Stream && isEventStreamResponse(resp.Header) {
streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
if err != nil {
+ if streamCount > 0 {
+ return &OpenAIForwardResult{
+ RequestID: resp.Header.Get("x-request-id"),
+ Usage: streamUsage,
+ Model: requestModel,
+ UpstreamModel: upstreamModel,
+ Stream: parsed.Stream,
+ ResponseHeaders: resp.Header.Clone(),
+ Duration: time.Since(startTime),
+ FirstTokenMs: ttft,
+ ImageCount: streamCount,
+ ImageSize: parsed.SizeTier,
+ }, err
+ }
return nil, err
}
usage = streamUsage
@@ -807,66 +865,205 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer")
}
- reader := bufio.NewReader(resp.Body)
usage := OpenAIUsage{}
- imageCount := 0
+ imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int
+ clientDisconnected := false
+ lastDownstreamWriteAt := time.Now()
var fallbackBody bytes.Buffer
fallbackBytes := int64(0)
fallbackLimit := resolveUpstreamResponseReadLimit(s.cfg)
seenSSEData := false
fallbackTooLarge := false
+ var sseData openAISSEDataAccumulator
+
+ processSSEData := func(dataBytes []byte) {
+ seenSSEData = true
+ fallbackBody.Reset()
+ fallbackBytes = 0
+ mergeOpenAIUsage(&usage, dataBytes)
+ imageCounter.AddSSEData(dataBytes)
+ }
+
+ flushSSEEvent := func() {
+ sseData.Flush(processSSEData)
+ }
+
+ processLine := func(line []byte) {
+ if len(line) == 0 {
+ return
+ }
+ if firstTokenMs == nil {
+ ms := int(time.Since(startTime).Milliseconds())
+ firstTokenMs = &ms
+ }
+ if !clientDisconnected {
+ if _, writeErr := c.Writer.Write(line); writeErr != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing")
+ } else {
+ flusher.Flush()
+ lastDownstreamWriteAt = time.Now()
+ }
+ }
+
+ trimmedLine := strings.TrimRight(string(line), "\r\n")
+ if _, ok := extractOpenAISSEDataLine(trimmedLine); ok || strings.TrimSpace(trimmedLine) == "" {
+ sseData.AddLine(trimmedLine, processSSEData)
+ return
+ }
+ if !seenSSEData && !fallbackTooLarge {
+ fallbackBytes += int64(len(line))
+ if fallbackBytes <= fallbackLimit {
+ _, _ = fallbackBody.Write(line)
+ } else {
+ fallbackTooLarge = true
+ fallbackBody.Reset()
+ }
+ }
+ }
+
+ finalizeFallbackBody := func() {
+ if seenSSEData || fallbackBody.Len() == 0 {
+ return
+ }
+ body := bytes.TrimSpace(fallbackBody.Bytes())
+ if len(body) == 0 {
+ return
+ }
+ mergeOpenAIUsage(&usage, body)
+ imageCounter.AddJSONResponse(body)
+ }
+
+ streamInterval := s.openAIImageStreamDataInterval()
+ keepaliveInterval := s.openAIImageStreamKeepaliveInterval()
+ if streamInterval <= 0 && keepaliveInterval <= 0 {
+ reader := bufio.NewReader(resp.Body)
+ for {
+ line, err := reader.ReadBytes('\n')
+ processLine(line)
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ flushSSEEvent()
+ return usage, imageCounter.Count(), firstTokenMs, err
+ }
+ }
+ flushSSEEvent()
+ finalizeFallbackBody()
+ return usage, imageCounter.Count(), firstTokenMs, nil
+ }
+
+ type readEvent struct {
+ line []byte
+ err error
+ }
+ events := make(chan readEvent, 16)
+ done := make(chan struct{})
+ sendEvent := func(ev readEvent) bool {
+ select {
+ case events <- ev:
+ return true
+ case <-done:
+ return false
+ }
+ }
+ var lastReadAt int64
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ go func() {
+ defer close(events)
+ reader := bufio.NewReader(resp.Body)
+ for {
+ line, err := reader.ReadBytes('\n')
+ if len(line) > 0 {
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ }
+ if len(line) > 0 && !sendEvent(readEvent{line: line}) {
+ return
+ }
+ if err == io.EOF {
+ return
+ }
+ if err != nil {
+ _ = sendEvent(readEvent{err: err})
+ return
+ }
+ }
+ }()
+ defer close(done)
+
+ var intervalTicker *time.Ticker
+ if streamInterval > 0 {
+ intervalTicker = time.NewTicker(streamInterval)
+ defer intervalTicker.Stop()
+ }
+ var intervalCh <-chan time.Time
+ if intervalTicker != nil {
+ intervalCh = intervalTicker.C
+ }
+
+ var keepaliveTicker *time.Ticker
+ if keepaliveInterval > 0 {
+ keepaliveTicker = time.NewTicker(keepaliveInterval)
+ defer keepaliveTicker.Stop()
+ }
+ var keepaliveCh <-chan time.Time
+ if keepaliveTicker != nil {
+ keepaliveCh = keepaliveTicker.C
+ }
for {
- line, err := reader.ReadBytes('\n')
- if len(line) > 0 {
- if firstTokenMs == nil {
- ms := int(time.Since(startTime).Milliseconds())
- firstTokenMs = &ms
+ select {
+ case ev, ok := <-events:
+ if !ok {
+ flushSSEEvent()
+ finalizeFallbackBody()
+ return usage, imageCounter.Count(), firstTokenMs, nil
}
- if _, writeErr := c.Writer.Write(line); writeErr != nil {
- return OpenAIUsage{}, 0, firstTokenMs, writeErr
+ if ev.err != nil {
+ flushSSEEvent()
+ return usage, imageCounter.Count(), firstTokenMs, ev.err
+ }
+ processLine(ev.line)
+ case <-intervalCh:
+ lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
+ if time.Since(lastRead) < streamInterval {
+ continue
+ }
+ if clientDisconnected {
+ return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
+ }
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream data interval timeout: interval=%s", streamInterval)
+ _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval)))
+ return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream data interval timeout")
+ case <-keepaliveCh:
+ if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval {
+ continue
+ }
+ if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected during keepalive, continue draining upstream for billing")
+ continue
}
flusher.Flush()
+ lastDownstreamWriteAt = time.Now()
+ }
+ }
+}
- if data, ok := extractOpenAISSEDataLine(strings.TrimRight(string(line), "\r\n")); ok {
- if data != "" && data != "[DONE]" {
- seenSSEData = true
- fallbackBody.Reset()
- fallbackBytes = 0
- dataBytes := []byte(data)
- mergeOpenAIUsage(&usage, dataBytes)
- if count := extractOpenAIImagesBillableCountFromJSONBytes(dataBytes); count > imageCount {
- imageCount = count
- }
- }
- } else if !seenSSEData && !fallbackTooLarge {
- fallbackBytes += int64(len(line))
- if fallbackBytes <= fallbackLimit {
- _, _ = fallbackBody.Write(line)
- } else {
- fallbackTooLarge = true
- fallbackBody.Reset()
- }
- }
- }
- if err == io.EOF {
- break
- }
- if err != nil {
- return OpenAIUsage{}, 0, firstTokenMs, err
- }
+func (s *OpenAIGatewayService) openAIImageStreamDataInterval() time.Duration {
+ if s == nil || s.cfg == nil || s.cfg.Gateway.ImageStreamDataIntervalTimeout <= 0 {
+ return 0
}
- if !seenSSEData && fallbackBody.Len() > 0 {
- body := bytes.TrimSpace(fallbackBody.Bytes())
- if len(body) > 0 {
- mergeOpenAIUsage(&usage, body)
- if count := extractOpenAIImagesBillableCountFromJSONBytes(body); count > imageCount {
- imageCount = count
- }
- }
+ return time.Duration(s.cfg.Gateway.ImageStreamDataIntervalTimeout) * time.Second
+}
+
+func (s *OpenAIGatewayService) openAIImageStreamKeepaliveInterval() time.Duration {
+ if s == nil || s.cfg == nil || s.cfg.Gateway.ImageStreamKeepaliveInterval <= 0 {
+ return 0
}
- return usage, imageCount, firstTokenMs, nil
+ return time.Duration(s.cfg.Gateway.ImageStreamKeepaliveInterval) * time.Second
}
func extractOpenAIImagesBillableCountFromJSONBytes(body []byte) int {
@@ -913,14 +1110,7 @@ func mergeOpenAIUsage(dst *OpenAIUsage, body []byte) {
}
func extractOpenAIImageCountFromJSONBytes(body []byte) int {
- if len(body) == 0 || !gjson.ValidBytes(body) {
- return 0
- }
- data := gjson.GetBytes(body, "data")
- if data.Exists() && data.IsArray() {
- return len(data.Array())
- }
- return 0
+ return countOpenAIResponseImageOutputsFromJSONBytes(body)
}
type openAIImagePointerInfo struct {
diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go
index 64d995e1..25cd8228 100644
--- a/backend/internal/service/openai_images_responses.go
+++ b/backend/internal/service/openai_images_responses.go
@@ -9,6 +9,7 @@ import (
"io"
"net/http"
"strings"
+ "sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -361,21 +362,21 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
var (
fallbackResults []openAIResponsesImageResult
fallbackSeen = make(map[string]struct{})
+ finalResults []openAIResponsesImageResult
+ finalMeta openAIResponsesImageResult
+ collectErr error
createdAt int64
usageRaw []byte
foundFinal bool
responseMeta openAIResponsesImageResult
)
- for _, line := range bytes.Split(body, []byte("\n")) {
- line = bytes.TrimRight(line, "\r")
- data, ok := extractOpenAISSEDataLine(string(line))
- if !ok || data == "" || data == "[DONE]" {
- continue
+ forEachOpenAISSEDataPayload(string(body), func(payload []byte) {
+ if collectErr != nil || len(finalResults) > 0 {
+ return
}
- payload := []byte(data)
if !gjson.ValidBytes(payload) {
- continue
+ return
}
if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok {
mergeOpenAIResponsesImageMeta(&responseMeta, meta)
@@ -388,7 +389,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
case "response.output_item.done":
result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload)
if err != nil {
- return nil, 0, nil, openAIResponsesImageResult{}, false, err
+ collectErr = err
+ return
}
if ok {
mergeOpenAIResponsesImageMeta(&result, responseMeta)
@@ -397,7 +399,8 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
case "response.completed":
results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload)
if err != nil {
- return nil, 0, nil, openAIResponsesImageResult{}, false, err
+ collectErr = err
+ return
}
foundFinal = true
if completedAt > 0 {
@@ -408,14 +411,24 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
}
if len(results) > 0 {
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
- return results, createdAt, usageRaw, firstMeta, true, nil
+ finalResults = results
+ finalMeta = firstMeta
+ return
}
if len(fallbackResults) > 0 {
firstMeta = fallbackResults[0]
mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta)
- return fallbackResults, createdAt, usageRaw, firstMeta, true, nil
+ finalResults = fallbackResults
+ finalMeta = firstMeta
+ return
}
}
+ })
+ if collectErr != nil {
+ return nil, 0, nil, openAIResponsesImageResult{}, false, collectErr
+ }
+ if len(finalResults) > 0 {
+ return finalResults, createdAt, usageRaw, finalMeta, true, nil
}
if len(fallbackResults) > 0 {
@@ -505,6 +518,30 @@ func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flus
return nil
}
+func (s *OpenAIGatewayService) tryWriteOpenAIImagesStreamEvent(
+ c *gin.Context,
+ flusher http.Flusher,
+ clientDisconnected *bool,
+ lastWriteAt *time.Time,
+ eventName string,
+ payload []byte,
+) bool {
+ if clientDisconnected != nil && *clientDisconnected {
+ return false
+ }
+ if err := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); err != nil {
+ if clientDisconnected != nil {
+ *clientDisconnected = true
+ }
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream client disconnected, continue draining upstream for billing")
+ return false
+ }
+ if lastWriteAt != nil {
+ *lastWriteAt = time.Now()
+ }
+ return true
+}
+
func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
resp *http.Response,
c *gin.Context,
@@ -517,15 +554,9 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
}
var usage OpenAIUsage
- for _, line := range bytes.Split(body, []byte("\n")) {
- line = bytes.TrimRight(line, "\r")
- data, ok := extractOpenAISSEDataLine(string(line))
- if !ok || data == "" || data == "[DONE]" {
- continue
- }
- dataBytes := []byte(data)
- s.parseSSEUsageBytes(dataBytes, &usage)
- }
+ forEachOpenAISSEDataPayload(string(body), func(data []byte) {
+ s.parseSSEUsageBytes(data, &usage)
+ })
results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body)
if err != nil {
return OpenAIUsage{}, 0, err
@@ -570,7 +601,6 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
format = "b64_json"
}
- reader := bufio.NewReader(resp.Body)
usage := OpenAIUsage{}
imageCount := 0
var firstTokenMs *int
@@ -579,141 +609,307 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
pendingSeen := make(map[string]struct{})
streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)}
var createdAt int64
+ clientDisconnected := false
+ lastDownstreamWriteAt := time.Now()
+ var sseData openAISSEDataAccumulator
+ var processDataErr error
+ processDataDone := false
+
+ processData := func(dataBytes []byte) {
+ if processDataDone || processDataErr != nil {
+ return
+ }
+ if firstTokenMs == nil {
+ ms := int(time.Since(startTime).Milliseconds())
+ firstTokenMs = &ms
+ }
+ s.parseSSEUsageBytes(dataBytes, &usage)
+ if !gjson.ValidBytes(dataBytes) {
+ return
+ }
+ if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
+ mergeOpenAIResponsesImageMeta(&streamMeta, meta)
+ if eventCreatedAt > 0 {
+ createdAt = eventCreatedAt
+ }
+ }
+ switch gjson.GetBytes(dataBytes, "type").String() {
+ case "response.image_generation_call.partial_image":
+ b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
+ if b64 == "" {
+ return
+ }
+ eventName := streamPrefix + ".partial_image"
+ partialMeta := streamMeta
+ mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
+ OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
+ Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
+ })
+ payload := buildOpenAIImagesStreamPartialPayload(
+ eventName,
+ b64,
+ gjson.GetBytes(dataBytes, "partial_image_index").Int(),
+ format,
+ createdAt,
+ partialMeta,
+ )
+ s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
+ case "response.output_item.done":
+ img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes)
+ if extractErr != nil {
+ s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
+ processDataErr = extractErr
+ processDataDone = true
+ return
+ }
+ if !ok {
+ return
+ }
+ mergeOpenAIResponsesImageMeta(&streamMeta, img)
+ mergeOpenAIResponsesImageMeta(&img, streamMeta)
+ key := openAIResponsesImageResultKey(itemID, img)
+ if _, exists := emitted[key]; exists {
+ return
+ }
+ if _, exists := pendingSeen[key]; exists {
+ return
+ }
+ pendingSeen[key] = struct{}{}
+ pendingResults = append(pendingResults, img)
+ case "response.completed":
+ results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
+ if extractErr != nil {
+ s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
+ processDataErr = extractErr
+ processDataDone = true
+ return
+ }
+ mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
+ finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
+ finalSeen := make(map[string]struct{})
+ for _, img := range results {
+ mergeOpenAIResponsesImageMeta(&img, streamMeta)
+ appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
+ }
+ for _, img := range pendingResults {
+ mergeOpenAIResponsesImageMeta(&img, streamMeta)
+ appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
+ }
+ if len(finalResults) == 0 {
+ outputErr := fmt.Errorf("upstream did not return image output")
+ s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(outputErr.Error()))
+ processDataErr = outputErr
+ processDataDone = true
+ return
+ }
+ eventName := streamPrefix + ".completed"
+ for _, img := range finalResults {
+ key := openAIResponsesImageResultKey("", img)
+ if _, exists := emitted[key]; exists {
+ continue
+ }
+ payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
+ emitted[key] = struct{}{}
+ s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
+ }
+ imageCount = len(emitted)
+ processDataDone = true
+ }
+ }
+
+ processLine := func(line []byte) (bool, error) {
+ if len(line) == 0 {
+ return false, nil
+ }
+ sseData.AddLine(string(line), processData)
+ if processDataErr != nil {
+ return true, processDataErr
+ }
+ return processDataDone, nil
+ }
+
+ flushData := func() (bool, error) {
+ sseData.Flush(processData)
+ if processDataErr != nil {
+ return true, processDataErr
+ }
+ return processDataDone, nil
+ }
+
+ finalizePending := func() error {
+ if imageCount > 0 {
+ return nil
+ }
+ if len(pendingResults) > 0 {
+ eventName := streamPrefix + ".completed"
+ for _, img := range pendingResults {
+ mergeOpenAIResponsesImageMeta(&img, streamMeta)
+ key := openAIResponsesImageResultKey("", img)
+ if _, exists := emitted[key]; exists {
+ continue
+ }
+ payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil)
+ emitted[key] = struct{}{}
+ s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
+ }
+ imageCount = len(emitted)
+ return nil
+ }
+
+ streamErr := fmt.Errorf("stream disconnected before image generation completed")
+ s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
+ return streamErr
+ }
+
+ streamInterval := s.openAIImageStreamDataInterval()
+ keepaliveInterval := s.openAIImageStreamKeepaliveInterval()
+ if streamInterval <= 0 && keepaliveInterval <= 0 {
+ reader := bufio.NewReader(resp.Body)
+ for {
+ line, err := reader.ReadBytes('\n')
+ done, processErr := processLine(line)
+ if processErr != nil {
+ return usage, imageCount, firstTokenMs, processErr
+ }
+ if done {
+ return usage, imageCount, firstTokenMs, nil
+ }
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ if done, processErr := flushData(); processErr != nil {
+ return usage, imageCount, firstTokenMs, processErr
+ } else if done {
+ return usage, imageCount, firstTokenMs, nil
+ }
+ s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
+ return usage, imageCount, firstTokenMs, err
+ }
+ }
+ if done, processErr := flushData(); processErr != nil {
+ return usage, imageCount, firstTokenMs, processErr
+ } else if done {
+ return usage, imageCount, firstTokenMs, nil
+ }
+ if err := finalizePending(); err != nil {
+ return usage, imageCount, firstTokenMs, err
+ }
+ return usage, imageCount, firstTokenMs, nil
+ }
+
+ type readEvent struct {
+ line []byte
+ err error
+ }
+ events := make(chan readEvent, 16)
+ done := make(chan struct{})
+ sendEvent := func(ev readEvent) bool {
+ select {
+ case events <- ev:
+ return true
+ case <-done:
+ return false
+ }
+ }
+ var lastReadAt int64
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ go func() {
+ defer close(events)
+ reader := bufio.NewReader(resp.Body)
+ for {
+ line, err := reader.ReadBytes('\n')
+ if len(line) > 0 {
+ atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+ }
+ if len(line) > 0 && !sendEvent(readEvent{line: line}) {
+ return
+ }
+ if err == io.EOF {
+ return
+ }
+ if err != nil {
+ _ = sendEvent(readEvent{err: err})
+ return
+ }
+ }
+ }()
+ defer close(done)
+
+ var intervalTicker *time.Ticker
+ if streamInterval > 0 {
+ intervalTicker = time.NewTicker(streamInterval)
+ defer intervalTicker.Stop()
+ }
+ var intervalCh <-chan time.Time
+ if intervalTicker != nil {
+ intervalCh = intervalTicker.C
+ }
+
+ var keepaliveTicker *time.Ticker
+ if keepaliveInterval > 0 {
+ keepaliveTicker = time.NewTicker(keepaliveInterval)
+ defer keepaliveTicker.Stop()
+ }
+ var keepaliveCh <-chan time.Time
+ if keepaliveTicker != nil {
+ keepaliveCh = keepaliveTicker.C
+ }
for {
- line, err := reader.ReadBytes('\n')
- if len(line) > 0 {
- trimmedLine := strings.TrimRight(string(line), "\r\n")
- data, ok := extractOpenAISSEDataLine(trimmedLine)
- if ok && data != "" && data != "[DONE]" {
- if firstTokenMs == nil {
- ms := int(time.Since(startTime).Milliseconds())
- firstTokenMs = &ms
+ select {
+ case ev, ok := <-events:
+ if !ok {
+ if done, processErr := flushData(); processErr != nil {
+ return usage, imageCount, firstTokenMs, processErr
+ } else if done {
+ return usage, imageCount, firstTokenMs, nil
}
- dataBytes := []byte(data)
- s.parseSSEUsageBytes(dataBytes, &usage)
- if gjson.ValidBytes(dataBytes) {
- if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok {
- mergeOpenAIResponsesImageMeta(&streamMeta, meta)
- if eventCreatedAt > 0 {
- createdAt = eventCreatedAt
- }
- }
- switch gjson.GetBytes(dataBytes, "type").String() {
- case "response.image_generation_call.partial_image":
- b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String())
- if b64 != "" {
- eventName := streamPrefix + ".partial_image"
- partialMeta := streamMeta
- mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{
- OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()),
- Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()),
- })
- payload := buildOpenAIImagesStreamPartialPayload(
- eventName,
- b64,
- gjson.GetBytes(dataBytes, "partial_image_index").Int(),
- format,
- createdAt,
- partialMeta,
- )
- if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
- return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
- }
- }
- case "response.output_item.done":
- img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes)
- if extractErr != nil {
- _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
- return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
- }
- if !ok {
- break
- }
- mergeOpenAIResponsesImageMeta(&streamMeta, img)
- mergeOpenAIResponsesImageMeta(&img, streamMeta)
- key := openAIResponsesImageResultKey(itemID, img)
- if _, exists := emitted[key]; exists {
- break
- }
- if _, exists := pendingSeen[key]; exists {
- break
- }
- pendingSeen[key] = struct{}{}
- pendingResults = append(pendingResults, img)
- case "response.completed":
- results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes)
- if extractErr != nil {
- _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error()))
- return OpenAIUsage{}, imageCount, firstTokenMs, extractErr
- }
- mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta)
- finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults))
- finalSeen := make(map[string]struct{})
- for _, img := range results {
- mergeOpenAIResponsesImageMeta(&img, streamMeta)
- appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
- }
- for _, img := range pendingResults {
- mergeOpenAIResponsesImageMeta(&img, streamMeta)
- appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img)
- }
- if len(finalResults) == 0 {
- err = fmt.Errorf("upstream did not return image output")
- _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
- return OpenAIUsage{}, imageCount, firstTokenMs, err
- }
- eventName := streamPrefix + ".completed"
- for _, img := range finalResults {
- key := openAIResponsesImageResultKey("", img)
- if _, exists := emitted[key]; exists {
- continue
- }
- payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw)
- if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
- return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
- }
- emitted[key] = struct{}{}
- }
- imageCount = len(emitted)
- return usage, imageCount, firstTokenMs, nil
- }
+ if err := finalizePending(); err != nil {
+ return usage, imageCount, firstTokenMs, err
}
+ return usage, imageCount, firstTokenMs, nil
}
- }
- if err == io.EOF {
- break
- }
- if err != nil {
- _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
- return OpenAIUsage{}, imageCount, firstTokenMs, err
- }
- }
-
- if imageCount > 0 {
- return usage, imageCount, firstTokenMs, nil
- }
- if len(pendingResults) > 0 {
- eventName := streamPrefix + ".completed"
- for _, img := range pendingResults {
- mergeOpenAIResponsesImageMeta(&img, streamMeta)
- key := openAIResponsesImageResultKey("", img)
- if _, exists := emitted[key]; exists {
+ if ev.err != nil {
+ if done, processErr := flushData(); processErr != nil {
+ return usage, imageCount, firstTokenMs, processErr
+ } else if done {
+ return usage, imageCount, firstTokenMs, nil
+ }
+ s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(ev.err.Error()))
+ return usage, imageCount, firstTokenMs, ev.err
+ }
+ done, processErr := processLine(ev.line)
+ if processErr != nil {
+ return usage, imageCount, firstTokenMs, processErr
+ }
+ if done {
+ return usage, imageCount, firstTokenMs, nil
+ }
+ case <-intervalCh:
+ lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
+ if time.Since(lastRead) < streamInterval {
continue
}
- payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil)
- if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil {
- return OpenAIUsage{}, imageCount, firstTokenMs, writeErr
+ if clientDisconnected {
+ return usage, imageCount, firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
}
- emitted[key] = struct{}{}
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream data interval timeout: interval=%s", streamInterval)
+ s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval)))
+ return usage, imageCount, firstTokenMs, fmt.Errorf("image stream data interval timeout")
+ case <-keepaliveCh:
+ if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval {
+ continue
+ }
+ if _, writeErr := io.WriteString(c.Writer, ":\n\n"); writeErr != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream client disconnected during keepalive, continue draining upstream for billing")
+ continue
+ }
+ flusher.Flush()
+ lastDownstreamWriteAt = time.Now()
}
- imageCount = len(emitted)
- return usage, imageCount, firstTokenMs, nil
}
-
- streamErr := fmt.Errorf("stream disconnected before image generation completed")
- _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error()))
- return OpenAIUsage{}, imageCount, firstTokenMs, streamErr
}
func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
@@ -752,7 +948,10 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
)
}
- token, _, err := s.GetAccessToken(ctx, account)
+ upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream)
+ defer releaseUpstreamCtx()
+
+ token, _, err := s.GetAccessToken(upstreamCtx, account)
if err != nil {
return nil, err
}
@@ -763,7 +962,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
}
setOpsUpstreamRequestBody(c, responsesBody)
- upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
+ upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
if err != nil {
return nil, err
}
@@ -808,14 +1007,14 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
Kind: "failover",
Message: upstreamMsg,
})
- s.handleFailoverSideEffects(ctx, resp, account)
+ s.handleFailoverSideEffects(upstreamCtx, resp, account)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,
RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
}
}
- return s.handleErrorResponse(ctx, resp, c, account, responsesBody)
+ return s.handleErrorResponse(upstreamCtx, resp, c, account, responsesBody)
}
defer func() { _ = resp.Body.Close() }()
@@ -827,6 +1026,20 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
if parsed.Stream {
usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel)
if err != nil {
+ if imageCount > 0 {
+ return &OpenAIForwardResult{
+ RequestID: resp.Header.Get("x-request-id"),
+ Usage: usage,
+ Model: requestModel,
+ UpstreamModel: requestModel,
+ Stream: parsed.Stream,
+ ResponseHeaders: resp.Header.Clone(),
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ ImageCount: imageCount,
+ ImageSize: parsed.SizeTier,
+ }, err
+ }
return nil, err
}
} else {
diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go
index 681e0e8e..fa4a4415 100644
--- a/backend/internal/service/openai_images_test.go
+++ b/backend/internal/service/openai_images_test.go
@@ -3,6 +3,7 @@ package service
import (
"bytes"
"context"
+ "errors"
"io"
"mime/multipart"
"net/http"
@@ -17,6 +18,20 @@ import (
"github.com/tidwall/gjson"
)
+type failingOpenAIImageWriter struct { // Test double: wraps a gin.ResponseWriter and makes Write fail once a write budget is used up.
+ gin.ResponseWriter
+ failAfter int // number of Write calls that are forwarded before Write starts returning an error
+ writes int // number of Write calls forwarded to the wrapped writer so far
+}
+
+func (w *failingOpenAIImageWriter) Write(p []byte) (int, error) { // Write forwards p to the wrapped writer for the first failAfter calls, then returns an error on every later call.
+ if w.writes >= w.failAfter {
+ return 0, errors.New("write failed: client disconnected") // budget used up: report zero bytes written plus an error
+ }
+ w.writes++ // counted before delegating, so the budget is consumed even if the wrapped Write fails
+ return w.ResponseWriter.Write(p)
+}
+
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","stream":true}`)
@@ -75,6 +90,100 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T
require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability)
}
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCustomSizes(t *testing.T) { // Table test: each size must parse without error, be kept verbatim in parsed.Size, and map to the listed SizeTier.
+ gin.SetMode(gin.TestMode)
+
+ tests := []struct {
+ size string // value sent in the request "size" field
+ wantTier string // expected parsed.SizeTier
+ }{
+ {size: "1024x1024", wantTier: "1K"},
+ {size: "1536x1024", wantTier: "2K"},
+ {size: "1024x1536", wantTier: "2K"},
+ {size: "2048x2048", wantTier: "2K"},
+ {size: "2048x1152", wantTier: "2K"},
+ {size: "3840x2160", wantTier: "4K"},
+ {size: "2160x3840", wantTier: "4K"},
+ {size: "1024X768", wantTier: "2K"},
+ {size: "1280x768", wantTier: "2K"},
+ {size: "2560x1440", wantTier: "2K"},
+ {size: "2560x1600", wantTier: "4K"},
+ {size: "auto", wantTier: "2K"},
+ }
+
+ svc := &OpenAIGatewayService{}
+ for _, tt := range tests {
+ t.Run(tt.size, func(t *testing.T) {
+ body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tt.size + `"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+ require.NotNil(t, parsed)
+ require.Equal(t, tt.size, parsed.Size) // the original spelling of the size is preserved
+ require.Equal(t, tt.wantTier, parsed.SizeTier)
+ })
+ }
+}
+
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_UnknownSizesDoNotBlockPassthrough(t *testing.T) { // Table test: unusual or malformed sizes must still parse without error, keep their text in parsed.Size, and get the listed SizeTier.
+ gin.SetMode(gin.TestMode)
+
+ tests := []struct {
+ size string // value sent in the request "size" field
+ wantTier string // expected parsed.SizeTier
+ }{
+ {size: "2048x1153", wantTier: "2K"},
+ {size: "4096x1024", wantTier: "4K"},
+ {size: "3840x1024", wantTier: "4K"},
+ {size: "512x512", wantTier: "2K"},
+ {size: "invalid", wantTier: "2K"}, // not a WIDTHxHEIGHT value at all
+ {size: "999999999999999999999999999x2", wantTier: "2K"}, // width has more digits than a 64-bit integer can hold
+ }
+
+ svc := &OpenAIGatewayService{}
+ for _, tt := range tests {
+ t.Run(tt.size, func(t *testing.T) {
+ body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"` + tt.size + `"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err) // parsing must not reject the request because of the size
+ require.NotNil(t, parsed)
+ require.Equal(t, tt.size, parsed.Size)
+ require.Equal(t, tt.wantTier, parsed.SizeTier)
+ })
+ }
+}
+
+func TestOpenAIGatewayServiceParseOpenAIImagesRequest_LegacyImageModelUnknownSizePassthrough(t *testing.T) { // Checks that model "gpt-image-1.5" with size "2048x1152" parses without error, keeps the size text, and gets tier "2K".
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"model":"gpt-image-1.5","prompt":"draw a cat","size":"2048x1152"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+ require.NotNil(t, parsed)
+ require.Equal(t, "2048x1152", parsed.Size)
+ require.Equal(t, "2K", parsed.SizeTier)
+}
+
func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -543,6 +652,57 @@ func TestOpenAIGatewayServiceForwardImages_APIKeyStreamRawJSONEventStreamFallbac
require.Equal(t, "ZmluYWw=", gjson.Get(rec.Body.String(), "data.0.b64_json").String())
}
+func TestOpenAIGatewayServiceForwardImages_APIKeyStreamMultilineSSEDataBillsImage(t *testing.T) { // API-key streaming: a completed event whose JSON is split across several data: lines must still yield one image and its usage.
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{
+ cfg: &config.Config{},
+ httpUpstream: &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"text/event-stream"},
+ "X-Request-Id": []string{"req_img_stream_multiline"},
+ },
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"image_generation.completed\",\n" + // first of three data: lines that together form one JSON object
+ "data: \"usage\":{\"input_tokens\":10,\"output_tokens\":18,\"output_tokens_details\":{\"image_tokens\":8}},\n" +
+ "data: \"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" +
+ "data: [DONE]\n\n",
+ )),
+ },
+ },
+ }
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+
+ account := &Account{
+ ID: 8,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "test-api-key",
+ },
+ }
+
+ result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.True(t, result.Stream)
+ require.Equal(t, 1, result.ImageCount)
+ require.Equal(t, 10, result.Usage.InputTokens) // usage values come from the split event above
+ require.Equal(t, 18, result.Usage.OutputTokens)
+ require.Equal(t, 8, result.Usage.ImageOutputTokens)
+}
+
func TestExtractOpenAIImagesBillableCountFromJSONBytes_CompletedEvent(t *testing.T) {
body := []byte(`{"type":"image_generation.completed","b64_json":"ZmluYWw=","usage":{"input_tokens":10,"output_tokens":18}}`)
@@ -686,6 +846,61 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *tes
require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists())
}
+func TestOpenAIGatewayServiceForwardImages_APIKeyStreamingDrainsAfterClientDisconnect(t *testing.T) { // API-key streaming: when downstream writes start failing, the result must still report the completed image and its usage.
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+ c.Writer = &failingOpenAIImageWriter{ResponseWriter: c.Writer, failAfter: 1} // only the first downstream Write succeeds; later ones return an error
+
+ svc := &OpenAIGatewayService{
+ cfg: &config.Config{
+ Gateway: config.GatewayConfig{
+ ImageStreamDataIntervalTimeout: 1,
+ ImageStreamKeepaliveInterval: 0,
+ },
+ },
+ httpUpstream: &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"text/event-stream"},
+ "X-Request-Id": []string{"req_img_stream_disconnect_apikey"},
+ },
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"image_generation.partial_image\",\"b64_json\":\"cGFydGlhbA==\"}\n\n" +
+ "data: {\"type\":\"image_generation.completed\",\"usage\":{\"input_tokens\":3,\"output_tokens\":4,\"output_tokens_details\":{\"image_tokens\":2}},\"b64_json\":\"ZmluYWw=\",\"output_format\":\"png\"}\n\n" +
+ "data: [DONE]\n\n",
+ )),
+ },
+ },
+ }
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+
+ account := &Account{
+ ID: 8,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "test-api-key",
+ },
+ }
+
+ result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
+ require.NoError(t, err) // the failed downstream writes must not surface as a forwarding error
+ require.NotNil(t, result)
+ require.Equal(t, 1, result.ImageCount)
+ require.Equal(t, 3, result.Usage.InputTokens)
+ require.Equal(t, 4, result.Usage.OutputTokens)
+ require.Equal(t, 2, result.Usage.ImageOutputTokens)
+}
+
func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -901,6 +1116,23 @@ func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testi
require.JSONEq(t, `{"images":1}`, string(usageRaw))
}
+func TestCollectOpenAIImagesFromResponsesBody_MultilineSSE(t *testing.T) { // A response.completed event split across two data: lines must be collected as one final image with its metadata.
+ body := []byte(
+ "data: {\"type\":\"response.completed\",\n" + // the JSON object continues on the next data: line
+ "data: \"response\":{\"created_at\":1710000010,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
+ "data: [DONE]\n\n",
+ )
+
+ results, createdAt, usageRaw, firstMeta, foundFinal, err := collectOpenAIImagesFromResponsesBody(body)
+ require.NoError(t, err)
+ require.True(t, foundFinal)
+ require.Equal(t, int64(1710000010), createdAt)
+ require.Len(t, results, 1)
+ require.Equal(t, "ZmluYWw=", results[0].Result)
+ require.Equal(t, "png", firstMeta.OutputFormat)
+ require.JSONEq(t, `{"images":1}`, string(usageRaw)) // usageRaw is expected to equal the tool_usage.image_gen object from the event
+}
+
func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
@@ -957,3 +1189,116 @@ func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFa
require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
require.NotContains(t, rec.Body.String(), "event: error")
}
+
+func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesMultilineSSE(t *testing.T) { // OAuth streaming: a response.completed event split across data: lines must produce a completed image event downstream and correct usage.
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"b64_json"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+
+ svc := &OpenAIGatewayService{}
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+
+ svc.httpUpstream = &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"text/event-stream"},
+ "X-Request-Id": []string{"req_img_stream_multiline_oauth"},
+ },
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"response.completed\",\n" + // the JSON object continues on the next data: line
+ "data: \"response\":{\"created_at\":1710000011,\"usage\":{\"input_tokens\":6,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"TXVsdGlsaW5l\",\"output_format\":\"png\"}]}}\n\n" +
+ "data: [DONE]\n\n",
+ )),
+ },
+ }
+
+ account := &Account{
+ ID: 11,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "token-123",
+ },
+ }
+
+ result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.True(t, result.Stream)
+ require.Equal(t, 1, result.ImageCount)
+ require.Equal(t, 6, result.Usage.InputTokens)
+ require.Equal(t, 10, result.Usage.OutputTokens)
+ require.Equal(t, 5, result.Usage.ImageOutputTokens)
+ events := parseOpenAIImageTestSSEEvents(rec.Body.String()) // inspect what was written downstream
+ completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed")
+ require.True(t, ok)
+ require.Equal(t, "TXVsdGlsaW5l", gjson.Get(completed.Data, "b64_json").String())
+ require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw)
+ require.NotContains(t, rec.Body.String(), "event: error")
+}
+
+func TestOpenAIGatewayServiceForwardImages_OAuthStreamingDrainsAfterClientDisconnect(t *testing.T) { // OAuth streaming: when downstream writes start failing, the result must still report the completed image and its usage.
+ gin.SetMode(gin.TestMode)
+ body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`)
+
+ req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = req
+ c.Writer = &failingOpenAIImageWriter{ResponseWriter: c.Writer, failAfter: 1} // only the first downstream Write succeeds; later ones return an error
+
+ svc := &OpenAIGatewayService{
+ cfg: &config.Config{
+ Gateway: config.GatewayConfig{
+ ImageStreamDataIntervalTimeout: 1,
+ ImageStreamKeepaliveInterval: 0,
+ },
+ },
+ }
+ parsed, err := svc.ParseOpenAIImagesRequest(c, body)
+ require.NoError(t, err)
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"text/event-stream"},
+ "X-Request-Id": []string{"req_img_stream_disconnect_oauth"},
+ },
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\"}\n\n" +
+ "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000009,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" +
+ "data: [DONE]\n\n",
+ )),
+ },
+ }
+ svc.httpUpstream = upstream
+
+ account := &Account{
+ ID: 9,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "token-123",
+ },
+ }
+
+ result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
+ require.NoError(t, err) // the failed downstream writes must not surface as a forwarding error
+ require.NotNil(t, result)
+ require.True(t, result.Stream)
+ require.Equal(t, 1, result.ImageCount)
+ require.Equal(t, 5, result.Usage.InputTokens)
+ require.Equal(t, 9, result.Usage.OutputTokens)
+ require.Equal(t, 4, result.Usage.ImageOutputTokens)
+}
diff --git a/backend/internal/service/openai_model_alias.go b/backend/internal/service/openai_model_alias.go
new file mode 100644
index 00000000..2fa2c90e
--- /dev/null
+++ b/backend/internal/service/openai_model_alias.go
@@ -0,0 +1,137 @@
+package service
+
+import "strings"
+
+func lastOpenAIModelSegment(model string) string { // lastOpenAIModelSegment returns the whitespace-trimmed text after the last "/" in model, or the trimmed model itself when it has no "/".
+ model = strings.TrimSpace(model)
+ if model == "" {
+ return ""
+ }
+ if strings.Contains(model, "/") {
+ parts := strings.Split(model, "/")
+ model = parts[len(parts)-1] // keep only the final path-like segment, e.g. "openai/gpt-5.5" -> "gpt-5.5"
+ }
+ return strings.TrimSpace(model) // the segment itself may carry surrounding spaces
+}
+
+func canonicalizeOpenAIModelAliasSpelling(model string) string { // canonicalizeOpenAIModelAliasSpelling lower-cases the last model segment and rewrites loose spellings to a hyphenated form; it returns "" for names that neither start with "gpt-" nor contain "codex".
+ model = strings.ToLower(lastOpenAIModelSegment(model))
+ if model == "" {
+ return ""
+ }
+
+ normalized := strings.ReplaceAll(model, "_", "-")
+ normalized = strings.Join(strings.Fields(normalized), "-") // runs of whitespace become a single "-"
+ for strings.Contains(normalized, "--") { // repeat until no doubled hyphen is left
+ normalized = strings.ReplaceAll(normalized, "--", "-")
+ }
+
+ if strings.HasPrefix(normalized, "gpt5") { // "gpt5..." -> "gpt-5..."
+ normalized = "gpt-5" + strings.TrimPrefix(normalized, "gpt5")
+ }
+ if !strings.HasPrefix(normalized, "gpt-") && !strings.Contains(normalized, "codex") {
+ return "" // not a name this helper handles; callers treat "" as no canonical form
+ }
+
+ replacements := []struct { // fixed rewrites for spellings that omit a hyphen; applied in order
+ from string
+ to string
+ }{
+ {"gpt-5.4mini", "gpt-5.4-mini"},
+ {"gpt-5.4nano", "gpt-5.4-nano"},
+ {"gpt-5.3-codexspark", "gpt-5.3-codex-spark"},
+ {"gpt-5.3codexspark", "gpt-5.3-codex-spark"},
+ {"gpt-5.3codex", "gpt-5.3-codex"},
+ }
+ for _, replacement := range replacements {
+ normalized = strings.ReplaceAll(normalized, replacement.from, replacement.to)
+ }
+ return normalized
+}
+
+func normalizeKnownOpenAICodexModel(model string) string { // normalizeKnownOpenAICodexModel maps a model name to one of a fixed set of model identifiers, or returns "" when nothing matches.
+ normalized := canonicalizeOpenAIModelAliasSpelling(model)
+ if normalized == "" {
+ return ""
+ }
+
+ if mapped := getNormalizedCodexModel(normalized); mapped != "" { // first try the lookup helper on the canonical spelling
+ return mapped
+ }
+ if strings.HasSuffix(normalized, "-openai-compact") { // retry the lookup with the "-openai-compact" suffix removed
+ if mapped := getNormalizedCodexModel(strings.TrimSuffix(normalized, "-openai-compact")); mapped != "" {
+ return mapped
+ }
+ }
+
+ switch { // substring fallback; cases are tested top to bottom, so more specific names must precede their prefixes
+ case strings.Contains(normalized, "gpt-5.5"):
+ return "gpt-5.5"
+ case strings.Contains(normalized, "gpt-5.4-mini"):
+ return "gpt-5.4-mini"
+ case strings.Contains(normalized, "gpt-5.4-nano"):
+ return "gpt-5.4-nano"
+ case strings.Contains(normalized, "gpt-5.4"):
+ return "gpt-5.4"
+ case strings.Contains(normalized, "gpt-5.2"):
+ return "gpt-5.2"
+ case strings.Contains(normalized, "gpt-5.3-codex-spark"):
+ return "gpt-5.3-codex-spark"
+ case strings.Contains(normalized, "gpt-5.3-codex"):
+ return "gpt-5.3-codex"
+ case strings.Contains(normalized, "gpt-5.3"):
+ return "gpt-5.3-codex"
+ case strings.Contains(normalized, "codex"): // any other codex spelling
+ return "gpt-5.3-codex"
+ case strings.Contains(normalized, "gpt-5"): // any other gpt-5 spelling
+ return "gpt-5.4"
+ default:
+ return ""
+ }
+}
+
+func appendUsageBillingModelCandidate(candidates []string, seen map[string]struct{}, model string) []string { // appendUsageBillingModelCandidate appends model, its canonical spelling and its normalized name to candidates, skipping blanks and case-insensitive duplicates recorded in seen.
+ trimmed := strings.TrimSpace(model)
+ if trimmed == "" {
+ return candidates
+ }
+ add := func(candidate string) { // appends candidate once; seen is keyed by the lower-cased name
+ candidate = strings.TrimSpace(candidate)
+ if candidate == "" {
+ return
+ }
+ key := strings.ToLower(candidate)
+ if _, ok := seen[key]; ok {
+ return
+ }
+ seen[key] = struct{}{}
+ candidates = append(candidates, candidate) // the original casing is what gets stored
+ }
+
+ add(trimmed) // the name as given comes first
+ if canonical := canonicalizeOpenAIModelAliasSpelling(trimmed); canonical != "" {
+ add(canonical)
+ }
+ if normalized := normalizeKnownOpenAICodexModel(trimmed); normalized != "" {
+ add(normalized)
+ }
+ return candidates
+}
+
+func usageBillingModelCandidates(primary string, alternates ...string) []string { // usageBillingModelCandidates returns the de-duplicated candidate names for primary followed by those for each alternate, in argument order.
+ seen := make(map[string]struct{}, 1+len(alternates)) // shared across all inputs so a name is listed only once
+ candidates := appendUsageBillingModelCandidate(nil, seen, primary)
+ for _, alternate := range alternates {
+ candidates = appendUsageBillingModelCandidate(candidates, seen, alternate)
+ }
+ return candidates // nil when every input is blank
+}
+
+func firstUsageBillingModel(candidates []string) string { // firstUsageBillingModel returns the first candidate that is non-empty after trimming whitespace, or "" when there is none.
+ for _, candidate := range candidates {
+ if trimmed := strings.TrimSpace(candidate); trimmed != "" {
+ return trimmed // returned in trimmed form
+ }
+ }
+ return ""
+}
diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go
index 5c3e1ae0..f087ac32 100644
--- a/backend/internal/service/openai_model_mapping_test.go
+++ b/backend/internal/service/openai_model_mapping_test.go
@@ -94,6 +94,15 @@ func TestResolveOpenAIForwardModel(t *testing.T) {
defaultMappedModel: "gpt-5.4",
expectedModel: "gpt-5.5",
},
+ {
+ name: "preserves compact-spelled gpt5.5 instead of group default",
+ account: &Account{
+ Credentials: map[string]any{},
+ },
+ requestedModel: "gpt5.5",
+ defaultMappedModel: "gpt-5.4",
+ expectedModel: "gpt5.5",
+ },
{
name: "preserves openai namespaced gpt-5.5 instead of group default",
account: &Account{
diff --git a/backend/internal/service/openai_sse_data.go b/backend/internal/service/openai_sse_data.go
new file mode 100644
index 00000000..61b813b6
--- /dev/null
+++ b/backend/internal/service/openai_sse_data.go
@@ -0,0 +1,70 @@
+package service
+
+import (
+ "strings"
+
+ "github.com/tidwall/gjson"
+)
+
+type openAISSEDataAccumulator struct { // Buffers the data-line payloads seen since the last flush so that several lines can be emitted together.
+ lines []string // payloads returned by extractOpenAISSEDataLine, in arrival order
+}
+
+func (a *openAISSEDataAccumulator) AddLine(line string, fn func([]byte)) { // AddLine consumes one raw line: a data line is buffered, a whitespace-only line flushes the buffer to fn, and any other line is ignored.
+ if fn == nil {
+ return // without a callback nothing is buffered or emitted
+ }
+ trimmedLine := strings.TrimRight(line, "\r\n") // drop trailing CR/LF so CRLF-terminated input is handled
+ if data, ok := extractOpenAISSEDataLine(trimmedLine); ok {
+ a.lines = append(a.lines, data)
+ return
+ }
+ if strings.TrimSpace(trimmedLine) == "" {
+ a.Flush(fn) // whitespace-only line: emit whatever has been buffered
+ }
+}
+
+func (a *openAISSEDataAccumulator) Flush(fn func([]byte)) { // Flush passes the buffered payloads to fn and empties the buffer; it does nothing when fn is nil or the buffer is empty.
+ if fn == nil || len(a.lines) == 0 {
+ return
+ }
+ emitOpenAISSEDataPayloads(a.lines, fn)
+ a.lines = a.lines[:0] // reset the length but keep the backing array for the next event
+}
+
+func forEachOpenAISSEDataPayload(body string, fn func([]byte)) { // forEachOpenAISSEDataPayload splits body on "\n", feeds every line to an accumulator, and calls fn for each payload it emits.
+ if fn == nil || strings.TrimSpace(body) == "" {
+ return
+ }
+ var acc openAISSEDataAccumulator
+ for _, line := range strings.Split(body, "\n") {
+ acc.AddLine(line, fn)
+ }
+ acc.Flush(fn) // emit data lines at the end of body that were not followed by a blank line
+}
+
+func emitOpenAISSEDataPayloads(lines []string, fn func([]byte)) { // emitOpenAISSEDataPayloads emits lines as one payload when their newline-joined text is valid JSON, otherwise as one payload per line.
+ if fn == nil || len(lines) == 0 {
+ return
+ }
+ if len(lines) == 1 {
+ emitOpenAISSEDataPayload(lines[0], fn) // single line: no joining or validation needed
+ return
+ }
+ joined := strings.Join(lines, "\n")
+ if gjson.Valid(joined) { // the lines together form one JSON document
+ emitOpenAISSEDataPayload(joined, fn)
+ return
+ }
+ for _, line := range lines { // fallback: treat each line as its own payload
+ emitOpenAISSEDataPayload(line, fn)
+ }
+}
+
+func emitOpenAISSEDataPayload(data string, fn func([]byte)) { // emitOpenAISSEDataPayload calls fn with the trimmed data as bytes, skipping empty payloads and the "[DONE]" marker.
+ data = strings.TrimSpace(data)
+ if data == "" || data == "[DONE]" {
+ return
+ }
+ fn([]byte(data)) // fn is not nil-checked here; the callers in this file check it before calling
+}
diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go
index 201073e0..cb045ae7 100644
--- a/backend/internal/service/openai_ws_forwarder.go
+++ b/backend/internal/service/openai_ws_forwarder.go
@@ -1990,6 +1990,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
}
usage := &OpenAIUsage{}
+ imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int
responseID := ""
var finalResponse []byte
@@ -2171,6 +2172,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
if openAIWSEventShouldParseUsage(eventType) {
parseOpenAIWSResponseUsageFromCompletedEvent(message, usage)
}
+ imageCounter.AddSSEData(message)
if eventType == "error" {
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
@@ -2343,6 +2345,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
Usage: *usage,
Model: originalModel,
UpstreamModel: mappedModel,
+ ImageCount: imageCounter.Count(),
ServiceTier: extractOpenAIServiceTier(reqBody),
ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
Stream: reqStream,
@@ -2449,6 +2452,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
promptCacheKey string
previousResponseID string
originalModel string
+ imageBillingModel string
+ imageSizeTier string
payloadBytes int
}
@@ -2546,6 +2551,19 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
normalized = next
}
+ imageIntent := IsImageGenerationIntent(openAIResponsesEndpoint, originalModel, normalized)
+ if imageIntent && !GroupAllowsImageGeneration(apiKeyGroup(getAPIKeyFromContext(c))) {
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, ImageGenerationPermissionMessage(), nil)
+ }
+ imageBillingModel := ""
+ imageSizeTier := ""
+ if imageIntent {
+ var imageCfgErr error
+ imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(normalized, originalModel)
+ if imageCfgErr != nil {
+ return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, imageCfgErr.Error(), imageCfgErr)
+ }
+ }
// Apply OpenAI Fast Policy on the response.create frame using the same
// evaluator/normalize/scope rules as the HTTP entrypoints. This is the
@@ -2591,6 +2609,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
promptCacheKey: promptCacheKey,
previousResponseID: previousResponseID,
originalModel: originalModel,
+ imageBillingModel: imageBillingModel,
+ imageSizeTier: imageSizeTier,
payloadBytes: len(normalized),
}, nil
}
@@ -2792,7 +2812,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return payload, nil
}
- sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) {
+ sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string, imageBillingModel string, imageSizeTier string) (*OpenAIForwardResult, error) {
if lease == nil {
return nil, errors.New("upstream websocket lease is nil")
}
@@ -2817,6 +2837,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
responseID := ""
usage := OpenAIUsage{}
+ imageCounter := newOpenAIImageOutputCounter()
var firstTokenMs *int
reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true)
turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
@@ -2938,6 +2959,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
if openAIWSEventShouldParseUsage(eventType) {
parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage)
}
+ imageCounter.AddSSEData(upstreamMessage)
if !clientDisconnected {
if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) {
@@ -2997,7 +3019,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
clientDisconnected,
)
}
- return &OpenAIForwardResult{
+ imageCount := imageCounter.Count()
+ result := &OpenAIForwardResult{
RequestID: responseID,
Usage: usage,
Model: originalModel,
@@ -3009,13 +3032,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
ResponseHeaders: lease.HandshakeHeaders(),
Duration: time.Since(turnStart),
FirstTokenMs: firstTokenMs,
- }, nil
+ }
+ if imageCount > 0 {
+ result.ImageCount = imageCount
+ result.ImageSize = imageSizeTier
+ result.BillingModel = imageBillingModel
+ }
+ return result, nil
}
}
}
currentPayload := firstPayload.payloadRaw
currentOriginalModel := firstPayload.originalModel
+ currentImageBillingModel := firstPayload.imageBillingModel
+ currentImageSizeTier := firstPayload.imageSizeTier
currentPayloadBytes := firstPayload.payloadBytes
isStrictAffinityTurn := func(payload []byte) bool {
if !storeDisabled {
@@ -3460,7 +3491,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
)
}
- result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel)
+ result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, currentImageBillingModel, currentImageSizeTier)
if relayErr != nil {
lastTurnClean = false
if recoverIngressPrevResponseNotFound(relayErr, turn, connID) {
@@ -3582,6 +3613,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
currentPayload = nextPayload.payloadRaw
currentOriginalModel = nextPayload.originalModel
+ currentImageBillingModel = nextPayload.imageBillingModel
+ currentImageSizeTier = nextPayload.imageSizeTier
currentPayloadBytes = nextPayload.payloadBytes
storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account)
if !storeDisabled {
diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go
index 7a76c385..cd816533 100644
--- a/backend/internal/service/openai_ws_forwarder_success_test.go
+++ b/backend/internal/service/openai_ws_forwarder_success_test.go
@@ -171,6 +171,127 @@ func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) {
require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String())
}
+func TestOpenAIGatewayService_Forward_WSv2_ImageGenerationCountsOutputs(t *testing.T) { // Forward over a WebSocket upstream: an image_generation_call seen in two events with the same id must be reported as one image with size tier and billing model set.
+ gin.SetMode(gin.TestMode)
+
+ upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // fake upstream: reads one request, then writes two events
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Errorf("upgrade websocket failed: %v", err)
+ return
+ }
+ defer func() {
+ _ = conn.Close()
+ }()
+
+ var request map[string]any
+ if err := conn.ReadJSON(&request); err != nil {
+ t.Errorf("read ws request failed: %v", err)
+ return
+ }
+
+ if err := conn.WriteJSON(map[string]any{ // first event: the finished image item on its own
+ "type": "response.output_item.done",
+ "item": map[string]any{
+ "id": "ig_ws_1",
+ "type": "image_generation_call",
+ "result": "final-image",
+ },
+ }); err != nil {
+ t.Errorf("write response.output_item.done failed: %v", err)
+ return
+ }
+ if err := conn.WriteJSON(map[string]any{ // second event: the completed response listing the same item id again
+ "type": "response.completed",
+ "response": map[string]any{
+ "id": "resp_ws_image_1",
+ "model": "gpt-5.4",
+ "output": []any{
+ map[string]any{
+ "id": "ig_ws_1",
+ "type": "image_generation_call",
+ "result": "final-image",
+ },
+ },
+ "usage": map[string]any{
+ "input_tokens": 9,
+ "output_tokens": 4,
+ },
+ },
+ }); err != nil {
+ t.Errorf("write response.completed failed: %v", err)
+ return
+ }
+ }))
+ defer wsServer.Close()
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ groupID := int64(1010)
+ c.Set("api_key", &APIKey{ // the API key's group has image generation enabled
+ GroupID: &groupID,
+ Group: &Group{
+ ID: groupID,
+ AllowImageGeneration: true,
+ },
+ })
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
+ cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
+ cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
+ cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
+ cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
+ cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ account := &Account{
+ ID: 10,
+ Name: "openai-ws-image",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": wsServer.URL, // point the account at the fake upstream above
+ },
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+
+ body := []byte(`{"model":"gpt-5.4","stream":false,"input":"draw","tools":[{"type":"image_generation","model":"gpt-image-2","size":"1024x1024"}],"tool_choice":{"type":"image_generation"}}`)
+ result, err := svc.Forward(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "resp_ws_image_1", result.RequestID)
+ require.Equal(t, 1, result.ImageCount) // item ig_ws_1 appears in both events but is expected to count once
+ require.Equal(t, "1K", result.ImageSize) // expected tier for the tool's "1024x1024" size
+ require.Equal(t, "gpt-image-2", result.BillingModel) // expected to match the image tool's model, not the request model
+ require.Equal(t, 9, result.Usage.InputTokens)
+ require.Equal(t, 4, result.Usage.OutputTokens)
+ require.True(t, result.OpenAIWSMode)
+ require.Equal(t, "resp_ws_image_1", gjson.GetBytes(rec.Body.Bytes(), "id").String())
+}
+
func requestToJSONString(payload map[string]any) string {
if len(payload) == 0 {
return "{}"
diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go
index 91a02901..8a033710 100644
--- a/backend/internal/service/pricing_service.go
+++ b/backend/internal/service/pricing_service.go
@@ -625,6 +625,9 @@ func normalizeModelNameForPricing(model string) string {
}
model = strings.TrimLeft(model, "/")
+ if canonical := canonicalizeOpenAIModelAliasSpelling(model); canonical != "" {
+ return canonical
+ }
return model
}
diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go
index e2bd7cf3..3c3e2c5b 100644
--- a/backend/internal/service/pricing_service_test.go
+++ b/backend/internal/service/pricing_service_test.go
@@ -98,6 +98,19 @@ func TestGetModelPricing_Gpt54UsesStaticFallbackWhenRemoteMissing(t *testing.T)
require.InDelta(t, 1.5, got.LongContextOutputCostMultiplier, 1e-12)
}
+// TestGetModelPricing_OpenAICompactAliasUsesStaticFallback checks that the
+// provider-prefixed, compact spelling "openai/gpt5.5" still resolves to a
+// non-nil price when the loaded pricing table has no entry for that model.
+func TestGetModelPricing_OpenAICompactAliasUsesStaticFallback(t *testing.T) {
+ // The only loaded entry is an unrelated model whose input price (1.25e-6)
+ // differs from the expected 2.5e-6, so accidentally matching this entry
+ // would fail the assertions below.
+ svc := &PricingService{
+ pricingData: map[string]*LiteLLMModelPricing{
+ "gpt-5.1-codex": {InputCostPerToken: 1.25e-6},
+ },
+ }
+
+ // Per the test name, the price is expected to come from the static
+ // fallback after the alias spelling has been normalized.
+ got := svc.GetModelPricing("openai/gpt5.5")
+ require.NotNil(t, got)
+ require.InDelta(t, 2.5e-6, got.InputCostPerToken, 1e-12)
+ require.InDelta(t, 1.5e-5, got.OutputCostPerToken, 1e-12)
+}
+
func TestGetModelPricing_Gpt54MiniUsesDedicatedStaticFallbackWhenRemoteMissing(t *testing.T) {
svc := &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
diff --git a/backend/migrations/134_image_generation_group_controls.sql b/backend/migrations/134_image_generation_group_controls.sql
new file mode 100644
index 00000000..37941c00
--- /dev/null
+++ b/backend/migrations/134_image_generation_group_controls.sql
@@ -0,0 +1,26 @@
+-- Image generation capability and image rate-multiplier mode controls.
+-- Compatibility principles:
+-- 1. Leave existing image_price_1k/2k/4k untouched so configured groups keep their final image price.
+-- 2. Existing openai/gemini/antigravity groups stay image-capable so upgrading does not interrupt image traffic.
+-- 3. Existing groups share the current effective group multiplier by default, preserving the historical billing formula.
+
+ALTER TABLE groups
+    ADD COLUMN IF NOT EXISTS allow_image_generation BOOLEAN NOT NULL DEFAULT false;
+
+ALTER TABLE groups
+    ADD COLUMN IF NOT EXISTS image_rate_independent BOOLEAN NOT NULL DEFAULT false;
+
+ALTER TABLE groups
+    ADD COLUMN IF NOT EXISTS image_rate_multiplier DECIMAL(10,4) NOT NULL DEFAULT 1.0;
+
+UPDATE groups
+SET allow_image_generation = true
+WHERE platform IN ('openai', 'gemini', 'antigravity');
+
+-- No backfill for image_rate_independent / image_rate_multiplier: the column
+-- defaults above already give every existing row false / 1.0, and an UPDATE
+-- without WHERE would reset configured values wherever the columns already exist.
+
+COMMENT ON COLUMN groups.allow_image_generation IS '是否允许该分组使用图片生成能力';
+COMMENT ON COLUMN groups.image_rate_independent IS '图片生成是否使用独立倍率;false 表示共享分组有效倍率';
+COMMENT ON COLUMN groups.image_rate_multiplier IS '图片生成独立倍率,仅 image_rate_independent=true 时生效';
diff --git a/deploy/.env.example b/deploy/.env.example
index e1eb8256..28205f7c 100644
--- a/deploy/.env.example
+++ b/deploy/.env.example
@@ -285,6 +285,25 @@ GATEWAY_SCHEDULING_OUTBOX_BACKLOG_REBUILD_ROWS=10000
# 全量重建周期(秒)
GATEWAY_SCHEDULING_FULL_REBUILD_INTERVAL_SECONDS=300
+# -----------------------------------------------------------------------------
+# Image Generation Stream & Concurrency (Optional)
+# 图片生成流式与并发隔离配置(可选)
+# -----------------------------------------------------------------------------
+# 图片流式上游数据间隔超时(秒)。0 表示禁用;非 0 时必须为 60-1800。
+GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=900
+# 图片流式 keepalive 间隔(秒)。0 表示禁用;非 0 时必须为 5-60。
+GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=10
+# 是否启用进程级图片生成并发限制。默认 false,保持历史行为。
+GATEWAY_IMAGE_CONCURRENCY_ENABLED=false
+# 当前进程允许同时处理的图片生成请求数。0 表示不限制。
+GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=0
+# 图片并发超限策略:reject 直接返回 429;wait 等待空闲槽位。
+GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=reject
+# wait 模式下等待空闲图片槽位的最长时间(秒)。
+GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=30
+# wait 模式下当前进程允许排队等待的最大图片请求数。0 表示不允许等待队列。
+GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=100
+
# -----------------------------------------------------------------------------
# Dashboard Aggregation (Optional)
# -----------------------------------------------------------------------------
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index dfc363b5..1670699b 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -340,6 +340,30 @@ gateway:
# Stream keepalive interval (seconds), 0=disable
# 流式 keepalive 间隔(秒),0=禁用
stream_keepalive_interval: 10
+ # Image stream data interval timeout (seconds), 0=disable; independent from ordinary text streams
+ # 图片流数据间隔超时(秒),0=禁用;独立于普通文本流式
+ image_stream_data_interval_timeout: 900
+ # Image stream keepalive interval (seconds), 0=disable; independent from ordinary text streams
+ # 图片流式 keepalive 间隔(秒),0=禁用;独立于普通文本流式
+ image_stream_keepalive_interval: 10
+ # Image generation independent concurrency limiter (process-local, default disabled)
+ # 图片生成独立并发限制(进程级,默认关闭;多实例总上限约为实例数×该值)
+ image_concurrency:
+ # Enable image-only concurrency protection; false keeps existing behavior unchanged
+ # 是否启用图片独立并发保护;false 保持现有行为不变
+ enabled: false
+ # Max concurrent image generation requests in this process, 0=unlimited
+ # 当前进程允许同时处理的图片生成请求数,0=不限制
+ max_concurrent_requests: 0
+ # Overflow mode when the image concurrency limit is full: reject/wait
+ # 图片并发满时的处理方式:reject=立即拒绝,wait=等待槽位
+ overflow_mode: "reject"
+ # Wait timeout for overflow_mode=wait (seconds), 0=do not wait
+ # wait 模式等待图片并发槽位的超时时间(秒),0=不等待
+ wait_timeout_seconds: 30
+ # Max image requests waiting in this process when overflow_mode=wait, 0=unlimited
+ # wait 模式当前进程允许排队等待的图片请求数,0=不限制
+ max_waiting_requests: 100
# SSE max line size in bytes (default: 40MB)
# SSE 单行最大字节数(默认 40MB)
max_line_size: 41943040
diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml
index 7793e424..b7a805b5 100644
--- a/deploy/docker-compose.dev.yml
+++ b/deploy/docker-compose.dev.yml
@@ -40,6 +40,13 @@ services:
- JWT_SECRET=${JWT_SECRET:-}
- TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-}
- TZ=${TZ:-Asia/Shanghai}
+ - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900}
+ - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10}
+ - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false}
+ - GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0}
+ - GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject}
+ - GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30}
+ - GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100}
depends_on:
postgres:
condition: service_healthy
diff --git a/deploy/docker-compose.local.yml b/deploy/docker-compose.local.yml
index 5aea78fb..51a80227 100644
--- a/deploy/docker-compose.local.yml
+++ b/deploy/docker-compose.local.yml
@@ -146,6 +146,17 @@ services:
# Proxy for accessing GitHub (online updates + pricing data)
# Examples: http://host:port, socks5://host:port
- UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-}
+
+ # =======================================================================
+ # Image Generation Stream & Concurrency
+ # =======================================================================
+ - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900}
+ - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10}
+ - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false}
+ - GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0}
+ - GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject}
+ - GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30}
+ - GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100}
depends_on:
postgres:
condition: service_healthy
diff --git a/deploy/docker-compose.standalone.yml b/deploy/docker-compose.standalone.yml
index df0ccfcc..438d0a8a 100644
--- a/deploy/docker-compose.standalone.yml
+++ b/deploy/docker-compose.standalone.yml
@@ -93,6 +93,17 @@ services:
# SECURITY: This repo does not embed third-party client_secret.
- GEMINI_CLI_OAUTH_CLIENT_SECRET=${GEMINI_CLI_OAUTH_CLIENT_SECRET:-}
- ANTIGRAVITY_OAUTH_CLIENT_SECRET=${ANTIGRAVITY_OAUTH_CLIENT_SECRET:-}
+
+ # =======================================================================
+ # Image Generation Stream & Concurrency
+ # =======================================================================
+ - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900}
+ - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10}
+ - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false}
+ - GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0}
+ - GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject}
+ - GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30}
+ - GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100}
healthcheck:
test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"]
interval: 30s
diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml
index 3a714260..1d639ea4 100644
--- a/deploy/docker-compose.yml
+++ b/deploy/docker-compose.yml
@@ -142,6 +142,17 @@ services:
# Proxy for accessing GitHub (online updates + pricing data)
# Examples: http://host:port, socks5://host:port
- UPDATE_PROXY_URL=${UPDATE_PROXY_URL:-}
+
+ # =======================================================================
+ # Image Generation Stream & Concurrency
+ # =======================================================================
+ - GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT=${GATEWAY_IMAGE_STREAM_DATA_INTERVAL_TIMEOUT:-900}
+ - GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL=${GATEWAY_IMAGE_STREAM_KEEPALIVE_INTERVAL:-10}
+ - GATEWAY_IMAGE_CONCURRENCY_ENABLED=${GATEWAY_IMAGE_CONCURRENCY_ENABLED:-false}
+ - GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_CONCURRENT_REQUESTS:-0}
+ - GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE=${GATEWAY_IMAGE_CONCURRENCY_OVERFLOW_MODE:-reject}
+ - GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS=${GATEWAY_IMAGE_CONCURRENCY_WAIT_TIMEOUT_SECONDS:-30}
+ - GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS=${GATEWAY_IMAGE_CONCURRENCY_MAX_WAITING_REQUESTS:-100}
depends_on:
postgres:
condition: service_healthy
diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue
index adcb3cc6..629e6aa2 100644
--- a/frontend/src/components/admin/usage/UsageTable.vue
+++ b/frontend/src/components/admin/usage/UsageTable.vue
@@ -291,9 +291,23 @@
+
+
{{ t("admin.groups.imagePricing.description") }}
++ {{ t("admin.groups.imagePricing.modeHint") }} +
+{{ t("admin.groups.imagePricing.description") }}
++ {{ t("admin.groups.imagePricing.modeHint") }} +
+