diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml
index 05dd1d1a..fd0c7a41 100644
--- a/.github/workflows/security-scan.yml
+++ b/.github/workflows/security-scan.yml
@@ -32,7 +32,7 @@ jobs:
working-directory: backend
run: |
go install github.com/securego/gosec/v2/cmd/gosec@latest
- gosec -severity high -confidence high ./...
+ gosec -conf .gosec.json -severity high -confidence high ./...
frontend-security:
runs-on: ubuntu-latest
diff --git a/backend/.gosec.json b/backend/.gosec.json
new file mode 100644
index 00000000..b34e140c
--- /dev/null
+++ b/backend/.gosec.json
@@ -0,0 +1,5 @@
+{
+ "global": {
+ "exclude": "G704"
+ }
+}
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index 62d7b53f..f788a87d 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.74.9
+0.1.83.2
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index be17fb01..a0f8807a 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -184,7 +184,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig)
- soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider)
+ soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig)
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index d7d31f08..3c8d4870 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -669,6 +669,7 @@ var (
{Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
{Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16},
+ {Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "api_key_id", Type: field.TypeInt64},
{Name: "account_id", Type: field.TypeInt64},
@@ -684,31 +685,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[27]},
+ Columns: []*schema.Column{UsageLogsColumns[28]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[28]},
+ Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[29]},
+ Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[30]},
+ Columns: []*schema.Column{UsageLogsColumns[31]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[31]},
+ Columns: []*schema.Column{UsageLogsColumns[32]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -717,32 +718,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[30]},
+ Columns: []*schema.Column{UsageLogsColumns[31]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[27]},
+ Columns: []*schema.Column{UsageLogsColumns[28]},
},
{
Name: "usagelog_account_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[28]},
+ Columns: []*schema.Column{UsageLogsColumns[29]},
},
{
Name: "usagelog_group_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[29]},
+ Columns: []*schema.Column{UsageLogsColumns[30]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[31]},
+ Columns: []*schema.Column{UsageLogsColumns[32]},
},
{
Name: "usagelog_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[26]},
+ Columns: []*schema.Column{UsageLogsColumns[27]},
},
{
Name: "usagelog_model",
@@ -757,12 +758,12 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[26]},
+ Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[26]},
+ Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]},
},
},
}
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 2e32d228..678e98c4 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -15980,6 +15980,7 @@ type UsageLogMutation struct {
addimage_count *int
image_size *string
media_type *string
+ cache_ttl_overridden *bool
created_at *time.Time
clearedFields map[string]struct{}
user *int64
@@ -17655,6 +17656,42 @@ func (m *UsageLogMutation) ResetMediaType() {
delete(m.clearedFields, usagelog.FieldMediaType)
}
+// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
+func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
+ m.cache_ttl_overridden = &b
+}
+
+// CacheTTLOverridden returns the value of the "cache_ttl_overridden" field in the mutation.
+func (m *UsageLogMutation) CacheTTLOverridden() (r bool, exists bool) {
+ v := m.cache_ttl_overridden
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCacheTTLOverridden returns the old "cache_ttl_overridden" field's value of the UsageLog entity.
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageLogMutation) OldCacheTTLOverridden(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCacheTTLOverridden is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCacheTTLOverridden requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCacheTTLOverridden: %w", err)
+ }
+ return oldValue.CacheTTLOverridden, nil
+}
+
+// ResetCacheTTLOverridden resets all changes to the "cache_ttl_overridden" field.
+func (m *UsageLogMutation) ResetCacheTTLOverridden() {
+ m.cache_ttl_overridden = nil
+}
+
// SetCreatedAt sets the "created_at" field.
func (m *UsageLogMutation) SetCreatedAt(t time.Time) {
m.created_at = &t
@@ -17860,7 +17897,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
- fields := make([]string, 0, 31)
+ fields := make([]string, 0, 32)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@@ -17951,6 +17988,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.media_type != nil {
fields = append(fields, usagelog.FieldMediaType)
}
+ if m.cache_ttl_overridden != nil {
+ fields = append(fields, usagelog.FieldCacheTTLOverridden)
+ }
if m.created_at != nil {
fields = append(fields, usagelog.FieldCreatedAt)
}
@@ -18022,6 +18062,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.ImageSize()
case usagelog.FieldMediaType:
return m.MediaType()
+ case usagelog.FieldCacheTTLOverridden:
+ return m.CacheTTLOverridden()
case usagelog.FieldCreatedAt:
return m.CreatedAt()
}
@@ -18093,6 +18135,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldImageSize(ctx)
case usagelog.FieldMediaType:
return m.OldMediaType(ctx)
+ case usagelog.FieldCacheTTLOverridden:
+ return m.OldCacheTTLOverridden(ctx)
case usagelog.FieldCreatedAt:
return m.OldCreatedAt(ctx)
}
@@ -18314,6 +18358,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetMediaType(v)
return nil
+ case usagelog.FieldCacheTTLOverridden:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCacheTTLOverridden(v)
+ return nil
case usagelog.FieldCreatedAt:
v, ok := value.(time.Time)
if !ok {
@@ -18736,6 +18787,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldMediaType:
m.ResetMediaType()
return nil
+ case usagelog.FieldCacheTTLOverridden:
+ m.ResetCacheTTLOverridden()
+ return nil
case usagelog.FieldCreatedAt:
m.ResetCreatedAt()
return nil
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index 8da5f84c..5e980be0 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -821,8 +821,12 @@ func init() {
usagelogDescMediaType := usagelogFields[29].Descriptor()
// usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
+ // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
+ usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor()
+ // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
+ usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
- usagelogDescCreatedAt := usagelogFields[30].Descriptor()
+ usagelogDescCreatedAt := usagelogFields[31].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()
diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go
index 602f23f6..ffcae840 100644
--- a/backend/ent/schema/usage_log.go
+++ b/backend/ent/schema/usage_log.go
@@ -124,6 +124,10 @@ func (UsageLog) Fields() []ent.Field {
Optional().
Nillable(),
+ // Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
+ field.Bool("cache_ttl_overridden").
+ Default(false),
+
// 时间戳(只有 created_at,日志不可修改)
field.Time("created_at").
Default(time.Now).
diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go
index 63a14197..f6968d0d 100644
--- a/backend/ent/usagelog.go
+++ b/backend/ent/usagelog.go
@@ -82,6 +82,8 @@ type UsageLog struct {
ImageSize *string `json:"image_size,omitempty"`
// MediaType holds the value of the "media_type" field.
MediaType *string `json:"media_type,omitempty"`
+ // CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
+ CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
@@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
- case usagelog.FieldStream:
+ case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden:
values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64)
@@ -387,6 +389,12 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.MediaType = new(string)
*_m.MediaType = value.String
}
+ case usagelog.FieldCacheTTLOverridden:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
+ } else if value.Valid {
+ _m.CacheTTLOverridden = value.Bool
+ }
case usagelog.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
@@ -562,6 +570,9 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
+ builder.WriteString("cache_ttl_overridden=")
+ builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
+ builder.WriteString(", ")
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteByte(')')
diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go
index 3ea5d054..ba97b843 100644
--- a/backend/ent/usagelog/usagelog.go
+++ b/backend/ent/usagelog/usagelog.go
@@ -74,6 +74,8 @@ const (
FieldImageSize = "image_size"
// FieldMediaType holds the string denoting the media_type field in the database.
FieldMediaType = "media_type"
+ // FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
+ FieldCacheTTLOverridden = "cache_ttl_overridden"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// EdgeUser holds the string denoting the user edge name in mutations.
@@ -158,6 +160,7 @@ var Columns = []string{
FieldImageCount,
FieldImageSize,
FieldMediaType,
+ FieldCacheTTLOverridden,
FieldCreatedAt,
}
@@ -216,6 +219,8 @@ var (
ImageSizeValidator func(string) error
// MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
MediaTypeValidator func(string) error
+ // DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
+ DefaultCacheTTLOverridden bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
)
@@ -378,6 +383,11 @@ func ByMediaType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMediaType, opts...).ToFunc()
}
+// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
+func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
+}
+
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go
index 0a33dba2..af960335 100644
--- a/backend/ent/usagelog/where.go
+++ b/backend/ent/usagelog/where.go
@@ -205,6 +205,11 @@ func MediaType(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
}
+// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
+func CacheTTLOverridden(v bool) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
+}
+
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
@@ -1520,6 +1525,16 @@ func MediaTypeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
}
+// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
+func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
+}
+
+// CacheTTLOverriddenNEQ applies the NEQ predicate on the "cache_ttl_overridden" field.
+func CacheTTLOverriddenNEQ(v bool) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNEQ(FieldCacheTTLOverridden, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v))
diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go
index 668a0ede..e0285a5e 100644
--- a/backend/ent/usagelog_create.go
+++ b/backend/ent/usagelog_create.go
@@ -407,6 +407,20 @@ func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate {
return _c
}
+// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
+func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
+ _c.mutation.SetCacheTTLOverridden(v)
+ return _c
+}
+
+// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
+func (_c *UsageLogCreate) SetNillableCacheTTLOverridden(v *bool) *UsageLogCreate {
+ if v != nil {
+ _c.SetCacheTTLOverridden(*v)
+ }
+ return _c
+}
+
// SetCreatedAt sets the "created_at" field.
func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate {
_c.mutation.SetCreatedAt(v)
@@ -545,6 +559,10 @@ func (_c *UsageLogCreate) defaults() {
v := usagelog.DefaultImageCount
_c.mutation.SetImageCount(v)
}
+ if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
+ v := usagelog.DefaultCacheTTLOverridden
+ _c.mutation.SetCacheTTLOverridden(v)
+ }
if _, ok := _c.mutation.CreatedAt(); !ok {
v := usagelog.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v)
@@ -646,6 +664,9 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
}
}
+ if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
+ return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
+ }
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)}
}
@@ -785,6 +806,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
_node.MediaType = &value
}
+ if value, ok := _c.mutation.CacheTTLOverridden(); ok {
+ _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
+ _node.CacheTTLOverridden = value
+ }
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
@@ -1448,6 +1473,18 @@ func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert {
return u
}
+// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
+func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
+ u.Set(usagelog.FieldCacheTTLOverridden, v)
+ return u
+}
+
+// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
+func (u *UsageLogUpsert) UpdateCacheTTLOverridden() *UsageLogUpsert {
+ u.SetExcluded(usagelog.FieldCacheTTLOverridden)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -2102,6 +2139,20 @@ func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne {
})
}
+// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
+func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetCacheTTLOverridden(v)
+ })
+}
+
+// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
+func (u *UsageLogUpsertOne) UpdateCacheTTLOverridden() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateCacheTTLOverridden()
+ })
+}
+
// Exec executes the query.
func (u *UsageLogUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -2922,6 +2973,20 @@ func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk {
})
}
+// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
+func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetCacheTTLOverridden(v)
+ })
+}
+
+// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create.
+func (u *UsageLogUpsertBulk) UpdateCacheTTLOverridden() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateCacheTTLOverridden()
+ })
+}
+
// Exec executes the query.
func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go
index 22f2613f..b46e5b56 100644
--- a/backend/ent/usagelog_update.go
+++ b/backend/ent/usagelog_update.go
@@ -632,6 +632,20 @@ func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate {
return _u
}
+// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
+func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
+ _u.mutation.SetCacheTTLOverridden(v)
+ return _u
+}
+
+// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
+func (_u *UsageLogUpdate) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdate {
+ if v != nil {
+ _u.SetCacheTTLOverridden(*v)
+ }
+ return _u
+}
+
// SetUser sets the "user" edge to the User entity.
func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate {
return _u.SetUserID(v.ID)
@@ -925,6 +939,9 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.MediaTypeCleared() {
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
}
+ if value, ok := _u.mutation.CacheTTLOverridden(); ok {
+ _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
+ }
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
@@ -1690,6 +1707,20 @@ func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne {
return _u
}
+// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
+func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
+ _u.mutation.SetCacheTTLOverridden(v)
+ return _u
+}
+
+// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil.
+func (_u *UsageLogUpdateOne) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdateOne {
+ if v != nil {
+ _u.SetCacheTTLOverridden(*v)
+ }
+ return _u
+}
+
// SetUser sets the "user" edge to the User entity.
func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne {
return _u.SetUserID(v.ID)
@@ -2013,6 +2044,9 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.MediaTypeCleared() {
_spec.ClearField(usagelog.FieldMediaType, field.TypeString)
}
+ if value, ok := _u.mutation.CacheTTLOverridden(); ok {
+ _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
+ }
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index b9f31ba9..330ae0c1 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -162,6 +162,8 @@ type TokenRefreshConfig struct {
MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
+ // 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
+ SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
}
type PricingConfig struct {
@@ -269,17 +271,30 @@ type SoraConfig struct {
// SoraClientConfig 直连 Sora 客户端配置
type SoraClientConfig struct {
- BaseURL string `mapstructure:"base_url"`
- TimeoutSeconds int `mapstructure:"timeout_seconds"`
- MaxRetries int `mapstructure:"max_retries"`
- PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
- MaxPollAttempts int `mapstructure:"max_poll_attempts"`
- RecentTaskLimit int `mapstructure:"recent_task_limit"`
- RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
- Debug bool `mapstructure:"debug"`
- Headers map[string]string `mapstructure:"headers"`
- UserAgent string `mapstructure:"user_agent"`
- DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
+ BaseURL string `mapstructure:"base_url"`
+ TimeoutSeconds int `mapstructure:"timeout_seconds"`
+ MaxRetries int `mapstructure:"max_retries"`
+ CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
+ PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
+ MaxPollAttempts int `mapstructure:"max_poll_attempts"`
+ RecentTaskLimit int `mapstructure:"recent_task_limit"`
+ RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
+ Debug bool `mapstructure:"debug"`
+ UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
+ Headers map[string]string `mapstructure:"headers"`
+ UserAgent string `mapstructure:"user_agent"`
+ DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
+ CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
+}
+
+// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置
+type SoraCurlCFFISidecarConfig struct {
+ Enabled bool `mapstructure:"enabled"`
+ BaseURL string `mapstructure:"base_url"`
+ Impersonate string `mapstructure:"impersonate"`
+ TimeoutSeconds int `mapstructure:"timeout_seconds"`
+ SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
+ SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
}
// SoraStorageConfig 媒体存储配置
@@ -1111,14 +1126,22 @@ func setDefaults() {
viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
viper.SetDefault("sora.client.timeout_seconds", 120)
viper.SetDefault("sora.client.max_retries", 3)
+ viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900)
viper.SetDefault("sora.client.poll_interval_seconds", 2)
viper.SetDefault("sora.client.max_poll_attempts", 600)
viper.SetDefault("sora.client.recent_task_limit", 50)
viper.SetDefault("sora.client.recent_task_limit_max", 200)
viper.SetDefault("sora.client.debug", false)
+ viper.SetDefault("sora.client.use_openai_token_provider", false)
viper.SetDefault("sora.client.headers", map[string]string{})
viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
viper.SetDefault("sora.client.disable_tls_fingerprint", false)
+ viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true)
+ viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080")
+ viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131")
+ viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60)
+ viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true)
+ viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600)
viper.SetDefault("sora.storage.type", "local")
viper.SetDefault("sora.storage.local_path", "")
@@ -1137,6 +1160,7 @@ func setDefaults() {
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
+ viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
// Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
@@ -1505,6 +1529,9 @@ func (c *Config) Validate() error {
if c.Sora.Client.MaxRetries < 0 {
return fmt.Errorf("sora.client.max_retries must be non-negative")
}
+ if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 {
+ return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative")
+ }
if c.Sora.Client.PollIntervalSeconds < 0 {
return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
}
@@ -1521,6 +1548,18 @@ func (c *Config) Validate() error {
c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
}
+ if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 {
+ return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative")
+ }
+ if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 {
+ return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative")
+ }
+ if !c.Sora.Client.CurlCFFISidecar.Enabled {
+ return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true")
+ }
+ if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" {
+ return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required")
+ }
if c.Sora.Storage.MaxConcurrentDownloads < 0 {
return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
}
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index a3c65c41..dcc60879 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -1024,3 +1024,91 @@ func TestValidateConfigErrors(t *testing.T) {
})
}
}
+
+func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ if !cfg.Sora.Client.CurlCFFISidecar.Enabled {
+ t.Fatalf("Sora curl_cffi sidecar should be enabled by default")
+ }
+ if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 {
+ t.Fatalf("Sora cloudflare challenge cooldown should be positive by default")
+ }
+ if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" {
+ t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default")
+ }
+ if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" {
+ t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default")
+ }
+ if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled {
+ t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default")
+ }
+ if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 {
+ t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default")
+ }
+}
+
+func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ cfg.Sora.Client.CurlCFFISidecar.Enabled = false
+ err = cfg.Validate()
+ if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") {
+ t.Fatalf("Validate() error = %v, want sidecar enabled error", err)
+ }
+}
+
+func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ cfg.Sora.Client.CurlCFFISidecar.BaseURL = " "
+ err = cfg.Validate()
+ if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") {
+ t.Fatalf("Validate() error = %v, want sidecar base_url required error", err)
+ }
+}
+
+func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1
+ err = cfg.Validate()
+ if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") {
+ t.Fatalf("Validate() error = %v, want sidecar session ttl error", err)
+ }
+}
+
+func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1
+ err = cfg.Validate()
+ if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") {
+ t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
+ }
+}
diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go
index b5d1dd0a..34397696 100644
--- a/backend/internal/handler/admin/account_data.go
+++ b/backend/internal/handler/admin/account_data.go
@@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
pageSize := dataPageCap
var out []service.Account
for {
- items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search)
+ items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
if err != nil {
return nil, err
}
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 4a9185eb..1aa0cf2b 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -200,7 +200,12 @@ func (h *AccountHandler) List(c *gin.Context) {
search = search[:100]
}
- accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
+ var groupID int64
+ if groupIDStr := c.Query("group"); groupIDStr != "" {
+ groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
+ }
+
+ accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -1433,6 +1438,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return
}
+ // Handle Sora accounts
+ if account.Platform == service.PlatformSora {
+ response.Success(c, service.DefaultSoraModels(nil))
+ return
+ }
+
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {
@@ -1542,7 +1553,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 {
- allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
+ allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0)
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go
index 20a25222..aeb4097f 100644
--- a/backend/internal/handler/admin/admin_basic_handlers_test.go
+++ b/backend/internal/handler/admin/admin_basic_handlers_test.go
@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
+ router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality)
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
router.ServeHTTP(rec, req)
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index cbbfe942..9f3dcf80 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
return s.apiKeys, int64(len(s.apiKeys)), nil
}
-func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
+func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil
}
@@ -327,6 +327,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
}
+func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) {
+ return &service.ProxyQualityCheckResult{
+ ProxyID: id,
+ Score: 95,
+ Grade: "A",
+ Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项",
+ PassedCount: 5,
+ WarnCount: 0,
+ FailedCount: 0,
+ ChallengeCount: 0,
+ CheckedAt: time.Now().Unix(),
+ Items: []service.ProxyQualityCheckItem{
+ {Target: "base_connectivity", Status: "pass", Message: "ok"},
+ {Target: "openai", Status: "pass", HTTPStatus: 401},
+ {Target: "anthropic", Status: "pass", HTTPStatus: 401},
+ {Target: "gemini", Status: "pass", HTTPStatus: 200},
+ {Target: "sora", Status: "pass", HTTPStatus: 401},
+ },
+ }, nil
+}
+
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
return s.redeems, int64(len(s.redeems)), nil
}
diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go
index ed86fea9..cf43f89e 100644
--- a/backend/internal/handler/admin/openai_oauth_handler.go
+++ b/backend/internal/handler/admin/openai_oauth_handler.go
@@ -2,6 +2,7 @@ package admin
import (
"strconv"
+ "strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
adminService service.AdminService
}
+func oauthPlatformFromPath(c *gin.Context) string {
+ if strings.Contains(c.FullPath(), "/admin/sora/") {
+ return service.PlatformSora
+ }
+ return service.PlatformOpenAI
+}
+
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{
@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
+ State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
}
@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
+ State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct {
- RefreshToken string `json:"refresh_token" binding:"required"`
+ RefreshToken string `json:"refresh_token"`
+ RT string `json:"rt"`
+ ClientID string `json:"client_id"`
ProxyID *int64 `json:"proxy_id"`
}
// RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token
+// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
+ refreshToken := strings.TrimSpace(req.RefreshToken)
+ if refreshToken == "" {
+ refreshToken = strings.TrimSpace(req.RT)
+ }
+ if refreshToken == "" {
+ response.BadRequest(c, "refresh_token is required")
+ return
+ }
var proxyURL string
if req.ProxyID != nil {
@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
}
}
- tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
+ tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
if err != nil {
response.ErrorFrom(c, err)
return
@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
response.Success(c, tokenInfo)
}
-// RefreshAccountToken refreshes token for a specific OpenAI account
+// ExchangeSoraSessionToken exchanges Sora session token to access token
+// POST /api/v1/admin/sora/st2at
+func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
+ var req struct {
+ SessionToken string `json:"session_token"`
+ ST string `json:"st"`
+ ProxyID *int64 `json:"proxy_id"`
+ }
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ sessionToken := strings.TrimSpace(req.SessionToken)
+ if sessionToken == "" {
+ sessionToken = strings.TrimSpace(req.ST)
+ }
+ if sessionToken == "" {
+ response.BadRequest(c, "session_token is required")
+ return
+ }
+
+ tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, tokenInfo)
+}
+
+// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
// POST /api/v1/admin/openai/accounts/:id/refresh
+// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
return
}
- // Ensure account is OpenAI platform
- if !account.IsOpenAI() {
- response.BadRequest(c, "Account is not an OpenAI account")
+ platform := oauthPlatformFromPath(c)
+ if account.Platform != platform {
+ response.BadRequest(c, "Account platform does not match OAuth endpoint")
return
}
@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
response.Success(c, dto.AccountFromService(updatedAccount))
}
-// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
+// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth
+// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
+ State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"`
@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
+ State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
// Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
+ platform := oauthPlatformFromPath(c)
+
// Use email as default name if not provided
name := req.Name
if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email
}
if name == "" {
- name = "OpenAI OAuth Account"
+ if platform == service.PlatformSora {
+ name = "Sora OAuth Account"
+ } else {
+ name = "OpenAI OAuth Account"
+ }
}
// Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name,
- Platform: "openai",
+ Platform: platform,
Type: "oauth",
Credentials: credentials,
ProxyID: req.ProxyID,
diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go
index a6758f69..5a9cd7a0 100644
--- a/backend/internal/handler/admin/proxy_handler.go
+++ b/backend/internal/handler/admin/proxy_handler.go
@@ -236,6 +236,24 @@ func (h *ProxyHandler) Test(c *gin.Context) {
response.Success(c, result)
}
+// CheckQuality handles checking proxy quality across common AI targets.
+// POST /api/v1/admin/proxies/:id/quality-check
+func (h *ProxyHandler) CheckQuality(c *gin.Context) {
+ proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid proxy ID")
+ return
+ }
+
+ result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, result)
+}
+
// GetStats handles getting proxy statistics
// GET /api/v1/admin/proxies/:id/stats
func (h *ProxyHandler) GetStats(c *gin.Context) {
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index 3c216d65..dbc7a8bc 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -214,6 +214,13 @@ func AccountFromServiceShallow(a *service.Account) *Account {
enabled := true
out.EnableSessionIDMasking = &enabled
}
+ // 缓存 TTL 强制替换
+ if a.IsCacheTTLOverrideEnabled() {
+ enabled := true
+ out.CacheTTLOverrideEnabled = &enabled
+ target := a.GetCacheTTLOverrideTarget()
+ out.CacheTTLOverrideTarget = &target
+ }
}
return out
@@ -296,6 +303,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
CountryCode: p.CountryCode,
Region: p.Region,
City: p.City,
+ QualityStatus: p.QualityStatus,
+ QualityScore: p.QualityScore,
+ QualityGrade: p.QualityGrade,
+ QualitySummary: p.QualitySummary,
+ QualityChecked: p.QualityChecked,
}
}
@@ -402,6 +414,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
ImageSize: l.ImageSize,
MediaType: l.MediaType,
UserAgent: l.UserAgent,
+ CacheTTLOverridden: l.CacheTTLOverridden,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey),
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index daac42bd..f2605ffc 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -156,6 +156,11 @@ type Account struct {
// 从 extra 字段提取,方便前端显示和编辑
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
+ // 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效)
+ // 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费
+ CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
+ CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
+
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -197,6 +202,11 @@ type ProxyWithAccountCount struct {
CountryCode string `json:"country_code,omitempty"`
Region string `json:"region,omitempty"`
City string `json:"city,omitempty"`
+ QualityStatus string `json:"quality_status,omitempty"`
+ QualityScore *int `json:"quality_score,omitempty"`
+ QualityGrade string `json:"quality_grade,omitempty"`
+ QualitySummary string `json:"quality_summary,omitempty"`
+ QualityChecked *int64 `json:"quality_checked,omitempty"`
}
type ProxyAccountSummary struct {
@@ -280,6 +290,9 @@ type UsageLog struct {
// User-Agent
UserAgent *string `json:"user_agent"`
+ // Cache TTL Override 标记
+ CacheTTLOverridden bool `json:"cache_ttl_overridden"`
+
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go
index 80932899..b958a133 100644
--- a/backend/internal/handler/sora_gateway_handler.go
+++ b/backend/internal/handler/sora_gateway_handler.go
@@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
+ "encoding/json"
"errors"
"fmt"
"io"
@@ -20,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
@@ -35,6 +37,7 @@ type SoraGatewayHandler struct {
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
streamMode string
+ soraTLSEnabled bool
soraMediaSigningKey string
soraMediaRoot string
}
@@ -50,6 +53,7 @@ func NewSoraGatewayHandler(
pingInterval := time.Duration(0)
maxAccountSwitches := 3
streamMode := "force"
+ soraTLSEnabled := true
signKey := ""
mediaRoot := "/app/data/sora"
if cfg != nil {
@@ -60,6 +64,7 @@ func NewSoraGatewayHandler(
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
streamMode = mode
}
+ soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
mediaRoot = root
@@ -72,6 +77,7 @@ func NewSoraGatewayHandler(
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
streamMode: strings.ToLower(streamMode),
+ soraTLSEnabled: soraTLSEnabled,
soraMediaSigningKey: signKey,
soraMediaRoot: mediaRoot,
}
@@ -212,6 +218,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
+ var lastFailoverBody []byte
+ var lastFailoverHeaders http.Header
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
@@ -224,11 +232,31 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
- h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
+ rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
+ fields := []zap.Field{
+ zap.Int("last_upstream_status", lastFailoverStatus),
+ }
+ if rayID != "" {
+ fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
+ }
+ if mitigated != "" {
+ fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
+ }
+ if contentType != "" {
+ fields = append(fields, zap.String("last_upstream_content_type", contentType))
+ }
+ reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
+ h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
return
}
account := selection.Account
setOpsSelectedAccount(c, account.ID, account.Platform)
+ proxyBound := account.ProxyID != nil
+ proxyID := int64(0)
+ if account.ProxyID != nil {
+ proxyID = *account.ProxyID
+ }
+ tlsFingerprintEnabled := h.soraTLSEnabled
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
@@ -239,10 +267,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
- reqLog.Warn("sora.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ reqLog.Warn("sora.account_wait_counter_increment_failed",
+ zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
+ zap.Error(err),
+ )
} else if !canWait {
reqLog.Info("sora.account_wait_queue_full",
zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
@@ -266,7 +303,13 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
&streamStarted,
)
if err != nil {
- reqLog.Warn("sora.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ reqLog.Warn("sora.account_slot_acquire_failed",
+ zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
+ zap.Error(err),
+ )
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
@@ -287,20 +330,67 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
- h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
+ lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
+ lastFailoverBody = failoverErr.ResponseBody
+ rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
+ fields := []zap.Field{
+ zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("switch_count", switchCount),
+ zap.Int("max_switches", maxAccountSwitches),
+ }
+ if rayID != "" {
+ fields = append(fields, zap.String("upstream_cf_ray", rayID))
+ }
+ if mitigated != "" {
+ fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
+ }
+ if contentType != "" {
+ fields = append(fields, zap.String("upstream_content_type", contentType))
+ }
+ reqLog.Warn("sora.upstream_failover_exhausted", fields...)
+ h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
+ lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
+ lastFailoverBody = failoverErr.ResponseBody
switchCount++
- reqLog.Warn("sora.upstream_failover_switching",
+ upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
+ rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
+ fields := []zap.Field{
zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.String("upstream_error_code", upstreamErrCode),
+ zap.String("upstream_error_message", upstreamErrMsg),
zap.Int("switch_count", switchCount),
zap.Int("max_switches", maxAccountSwitches),
- )
+ }
+ if rayID != "" {
+ fields = append(fields, zap.String("upstream_cf_ray", rayID))
+ }
+ if mitigated != "" {
+ fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
+ }
+ if contentType != "" {
+ fields = append(fields, zap.String("upstream_content_type", contentType))
+ }
+ reqLog.Warn("sora.upstream_failover_switching", fields...)
continue
}
- reqLog.Error("sora.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
+ reqLog.Error("sora.forward_failed",
+ zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
+ zap.Error(err),
+ )
return
}
@@ -331,6 +421,9 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
}(result, account, userAgent, clientIP)
reqLog.Debug("sora.request_completed",
zap.Int64("account_id", account.ID),
+ zap.Int64("proxy_id", proxyID),
+ zap.Bool("proxy_bound", proxyBound),
+ zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
zap.Int("switch_count", switchCount),
)
return
@@ -360,17 +453,41 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
-func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
- status, errType, errMsg := h.mapUpstreamError(statusCode)
+func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
+ status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
-func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
+func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
+ if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
+ baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
+ return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
+ }
+
+ upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
+ if strings.EqualFold(upstreamCode, "cf_shield_429") {
+ baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
+ return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
+ }
+ if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
+ switch statusCode {
+ case 401, 403, 404, 500, 502, 503, 504:
+ return http.StatusBadGateway, "upstream_error", upstreamMessage
+ case 429:
+ return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
+ }
+ }
+
switch statusCode {
case 401:
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
+ case 404:
+ if strings.EqualFold(upstreamCode, "unsupported_country_code") {
+ return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
+ }
+ return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
@@ -382,11 +499,67 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri
}
}
+func cloneHTTPHeaders(headers http.Header) http.Header {
+ if headers == nil {
+ return nil
+ }
+ return headers.Clone()
+}
+
+func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
+ if headers != nil {
+ mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
+ contentType = strings.TrimSpace(headers.Get("content-type"))
+ if contentType == "" {
+ contentType = strings.TrimSpace(headers.Get("Content-Type"))
+ }
+ }
+ rayID = soraerror.ExtractCloudflareRayID(headers, body)
+ return rayID, mitigated, contentType
+}
+
+func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
+ return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
+}
+
+func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
+ message = strings.TrimSpace(message)
+ if message == "" {
+ return false
+ }
+ if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
+ lower := strings.ToLower(message)
+ if strings.Contains(lower, "
Just a moment...`)
+
+ h := &SoraGatewayHandler{}
+ h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
+
+ lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
+ require.Len(t, lines, 2)
+ jsonStr := strings.TrimPrefix(lines[1], "data: ")
+
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
+
+ errorObj, ok := parsed["error"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "upstream_error", errorObj["type"])
+ msg, _ := errorObj["message"].(string)
+ require.Contains(t, msg, "Cloudflare challenge")
+ require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
+}
+
+func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ headers := http.Header{}
+ headers.Set("cf-ray", "9d03b68c086027a1-SEA")
+ body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
+
+ h := &SoraGatewayHandler{}
+ h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
+
+ lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
+ require.Len(t, lines, 2)
+ jsonStr := strings.TrimPrefix(lines[1], "data: ")
+
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
+
+ errorObj, ok := parsed["error"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "rate_limit_error", errorObj["type"])
+ msg, _ := errorObj["message"].(string)
+ require.Contains(t, msg, "Cloudflare shield")
+ require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
+}
+
+func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
+ headers := http.Header{}
+ headers.Set("cf-mitigated", "challenge")
+ headers.Set("content-type", "text/html")
+ body := []byte(``)
+
+ rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
+ require.Equal(t, "9cff2d62d83bb98d", rayID)
+ require.Equal(t, "challenge", mitigated)
+ require.Equal(t, "text/html", contentType)
+}
diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go
index eecee11e..423ad925 100644
--- a/backend/internal/pkg/claude/constants.go
+++ b/backend/internal/pkg/claude/constants.go
@@ -10,6 +10,7 @@ const (
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
BetaTokenCounting = "token-counting-2024-11-01"
+ BetaContext1M = "context-1m-2025-08-07"
)
// DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header
@@ -77,6 +78,12 @@ var DefaultModels = []Model{
DisplayName: "Claude Opus 4.6",
CreatedAt: "2026-02-06T00:00:00Z",
},
+ {
+ ID: "claude-sonnet-4-6",
+ Type: "model",
+ DisplayName: "Claude Sonnet 4.6",
+ CreatedAt: "2026-02-18T00:00:00Z",
+ },
{
ID: "claude-sonnet-4-5-20250929",
Type: "model",
diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go
index bb120b57..e3b931be 100644
--- a/backend/internal/pkg/openai/oauth.go
+++ b/backend/internal/pkg/openai/oauth.go
@@ -17,6 +17,8 @@ import (
const (
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
+ // OAuth Client ID for Sora mobile flow (aligned with sora2api)
+ SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index 58b824c9..3f77a57e 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -435,10 +435,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
}
func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
- return r.ListWithFilters(ctx, params, "", "", "", "")
+ return r.ListWithFilters(ctx, params, "", "", "", "", 0)
}
-func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
+func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
q := r.client.Account.Query()
if platform != "" {
@@ -458,6 +458,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
if search != "" {
q = q.Where(dbaccount.NameContainsFold(search))
}
+ if groupID > 0 {
+ q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID)))
+ }
total, err := q.Count(ctx)
if err != nil {
diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go
index a054b6d6..4f9d0152 100644
--- a/backend/internal/repository/account_repo_integration_test.go
+++ b/backend/internal/repository/account_repo_integration_test.go
@@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
tt.setup(client)
- accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
+ accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
s.Require().NoError(err)
s.Require().Len(accounts, tt.wantCount)
if tt.validate != nil {
@@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
s.Require().Equal(group.ID, got.Groups[0].ID)
- accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
+ accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0)
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(accounts, 1)
diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go
index 394d3a1a..088e7d7f 100644
--- a/backend/internal/repository/openai_oauth_service.go
+++ b/backend/internal/repository/openai_oauth_service.go
@@ -4,6 +4,7 @@ import (
"context"
"net/http"
"net/url"
+ "strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
+ return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
+}
+
+func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
+ if strings.TrimSpace(clientID) != "" {
+ return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
+ }
+
+ clientIDs := []string{
+ openai.ClientID,
+ openai.SoraClientID,
+ }
+ seen := make(map[string]struct{}, len(clientIDs))
+ var lastErr error
+ for _, clientID := range clientIDs {
+ clientID = strings.TrimSpace(clientID)
+ if clientID == "" {
+ continue
+ }
+ if _, ok := seen[clientID]; ok {
+ continue
+ }
+ seen[clientID] = struct{}{}
+
+ tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
+ if err == nil {
+ return tokenResp, nil
+ }
+ lastErr = err
+ }
+ if lastErr != nil {
+ return nil, lastErr
+ }
+ return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
+}
+
+func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
- formData.Set("client_id", openai.ClientID)
+ formData.Set("client_id", clientID)
formData.Set("scope", openai.RefreshScopes)
var tokenResp openai.TokenResponse
diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go
index f9df08c8..5938272a 100644
--- a/backend/internal/repository/openai_oauth_service_test.go
+++ b/backend/internal/repository/openai_oauth_service_test.go
@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
require.Equal(s.T(), "rt2", resp.RefreshToken)
}
+func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
+ var seenClientIDs []string
+ s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if err := r.ParseForm(); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ return
+ }
+ clientID := r.PostForm.Get("client_id")
+ seenClientIDs = append(seenClientIDs, clientID)
+ if clientID == openai.ClientID {
+ w.WriteHeader(http.StatusBadRequest)
+ _, _ = io.WriteString(w, "invalid_grant")
+ return
+ }
+ if clientID == openai.SoraClientID {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
+ return
+ }
+ w.WriteHeader(http.StatusBadRequest)
+ }))
+
+ resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
+ require.NoError(s.T(), err, "RefreshToken")
+ require.Equal(s.T(), "at-sora", resp.AccessToken)
+ require.Equal(s.T(), "rt-sora", resp.RefreshToken)
+ require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
+}
+
+func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
+ const customClientID = "custom-client-id"
+ var seenClientIDs []string
+ s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if err := r.ParseForm(); err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ return
+ }
+ clientID := r.PostForm.Get("client_id")
+ seenClientIDs = append(seenClientIDs, clientID)
+ if clientID != customClientID {
+ w.WriteHeader(http.StatusBadRequest)
+ return
+ }
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
+ }))
+
+ resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
+ require.NoError(s.T(), err, "RefreshTokenWithClientID")
+ require.Equal(s.T(), "at-custom", resp.AccessToken)
+ require.Equal(s.T(), "rt-custom", resp.RefreshToken)
+ require.Equal(s.T(), []string{customClientID}, seenClientIDs)
+}
+
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go
index 681b1664..0389a008 100644
--- a/backend/internal/repository/usage_log_repo.go
+++ b/backend/internal/repository/usage_log_repo.go
@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
-const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, created_at"
+const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{
@@ -132,6 +132,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
image_size,
media_type,
reasoning_effort,
+ cache_ttl_overridden,
created_at
) VALUES (
$1, $2, $3, $4, $5,
@@ -139,7 +140,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
- $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32
+ $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -192,6 +193,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
imageSize,
mediaType,
reasoningEffort,
+ log.CacheTTLOverridden,
createdAt,
}
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
@@ -2221,6 +2223,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
imageSize sql.NullString
mediaType sql.NullString
reasoningEffort sql.NullString
+ cacheTTLOverridden bool
createdAt time.Time
)
@@ -2257,6 +2260,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&imageSize,
&mediaType,
&reasoningEffort,
+ &cacheTTLOverridden,
&createdAt,
); err != nil {
return nil, err
@@ -2285,6 +2289,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
BillingType: int8(billingType),
Stream: stream,
ImageCount: imageCount,
+ CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: createdAt,
}
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index c574219b..d87d97b5 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -406,6 +406,7 @@ func TestAPIContracts(t *testing.T) {
"image_count": 0,
"image_size": null,
"media_type": null,
+ "cache_ttl_overridden": false,
"created_at": "2025-01-02T03:04:05Z",
"user_agent": null
}
@@ -945,7 +946,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination
return nil, nil, errors.New("not implemented")
}
-func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
+func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
diff --git a/backend/internal/server/middleware/cors.go b/backend/internal/server/middleware/cors.go
index 704b0907..03d5d025 100644
--- a/backend/internal/server/middleware/cors.go
+++ b/backend/internal/server/middleware/cors.go
@@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
}
allowedSet[origin] = struct{}{}
}
+ allowHeaders := []string{
+ "Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization",
+ "accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key",
+ }
+ // OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。
+ openAIProperties := []string{
+ "lang", "package-version", "os", "arch", "retry-count", "runtime",
+ "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout",
+ }
+ for _, prop := range openAIProperties {
+ allowHeaders = append(allowHeaders, "x-stainless-"+prop)
+ }
+ allowHeadersValue := strings.Join(allowHeaders, ", ")
return func(c *gin.Context) {
origin := strings.TrimSpace(c.GetHeader("Origin"))
@@ -68,12 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
if allowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
- c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
+ c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeadersValue)
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
c.Writer.Header().Set("Access-Control-Expose-Headers", "ETag")
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
}
-
// 处理预检请求
if c.Request.Method == http.MethodOptions {
if originAllowed {
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 57d54a54..4b4d97c3 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
// OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h)
+ // Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
+ registerSoraOAuthRoutes(admin, h)
// Gemini OAuth
registerGeminiOAuthRoutes(admin, h)
@@ -276,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
+func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ sora := admin.Group("/sora")
+ {
+ sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
+ sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
+ sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
+ sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
+ sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
+ sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
+ sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
+ }
+}
+
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini")
{
@@ -306,6 +321,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
+ proxies.POST("/:id/quality-check", h.Admin.Proxy.CheckQuality)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete)
diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go
index 32f34e0c..69881e70 100644
--- a/backend/internal/server/routes/gateway.go
+++ b/backend/internal/server/routes/gateway.go
@@ -1,6 +1,8 @@
package routes
import (
+ "net/http"
+
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -41,16 +43,15 @@ func RegisterGatewayRoutes(
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
- }
-
- // Sora Chat Completions
- soraGateway := r.Group("/v1")
- soraGateway.Use(soraBodyLimit)
- soraGateway.Use(clientRequestID)
- soraGateway.Use(opsErrorLogger)
- soraGateway.Use(gin.HandlerFunc(apiKeyAuth))
- {
- soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions)
+ // 明确阻止旧入口误用到 Sora,避免客户端把 OpenAI Chat Completions 当作 Sora 入口
+ gateway.POST("/chat/completions", func(c *gin.Context) {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": gin.H{
+ "type": "invalid_request_error",
+ "message": "For Sora, use /sora/v1/chat/completions. OpenAI should use /v1/responses.",
+ },
+ })
+ })
}
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 592c5139..bce3f98f 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -786,6 +786,38 @@ func (a *Account) IsSessionIDMaskingEnabled() bool {
return false
}
+// IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换
+// 仅适用于 Anthropic OAuth/SetupToken 类型账号
+// 启用后将所有 cache creation tokens 归入指定的 TTL 类型(5m 或 1h)
+func (a *Account) IsCacheTTLOverrideEnabled() bool {
+ if !a.IsAnthropicOAuthOrSetupToken() {
+ return false
+ }
+ if a.Extra == nil {
+ return false
+ }
+ if v, ok := a.Extra["cache_ttl_override_enabled"]; ok {
+ if enabled, ok := v.(bool); ok {
+ return enabled
+ }
+ }
+ return false
+}
+
+// GetCacheTTLOverrideTarget 获取缓存 TTL 强制替换的目标类型
+// 返回 "5m" 或 "1h",默认 "5m"
+func (a *Account) GetCacheTTLOverrideTarget() string {
+ if a.Extra == nil {
+ return "5m"
+ }
+ if v, ok := a.Extra["cache_ttl_override_target"]; ok {
+ if target, ok := v.(string); ok && (target == "5m" || target == "1h") {
+ return target
+ }
+ }
+ return "5m"
+}
+
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
// 返回 0 表示未启用
func (a *Account) GetWindowCostLimit() float64 {
diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go
index 3cddd2c7..b301049f 100644
--- a/backend/internal/service/account_service.go
+++ b/backend/internal/service/account_service.go
@@ -35,7 +35,7 @@ type AccountRepository interface {
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
- ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
+ ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go
index 414b3678..a466b68a 100644
--- a/backend/internal/service/account_service_delete_test.go
+++ b/backend/internal/service/account_service_delete_test.go
@@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination
panic("unexpected List call")
}
-func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
+func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index 093f7d4d..a507efb4 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -12,13 +12,17 @@ import (
"io"
"log"
"net/http"
+ "net/url"
"regexp"
"strings"
+ "sync"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -32,6 +36,10 @@ const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
+ soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
+ soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine"
+ soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap"
+ soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check"
)
// TestEvent represents a SSE event for account testing
@@ -39,6 +47,9 @@ type TestEvent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
+ Status string `json:"status,omitempty"`
+ Code string `json:"code,omitempty"`
+ Data any `json:"data,omitempty"`
Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"`
}
@@ -50,8 +61,13 @@ type AccountTestService struct {
antigravityGatewayService *AntigravityGatewayService
httpUpstream HTTPUpstream
cfg *config.Config
+ soraTestGuardMu sync.Mutex
+ soraTestLastRun map[int64]time.Time
+ soraTestCooldown time.Duration
}
+const defaultSoraTestCooldown = 10 * time.Second
+
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(
accountRepo AccountRepository,
@@ -66,6 +82,8 @@ func NewAccountTestService(
antigravityGatewayService: antigravityGatewayService,
httpUpstream: httpUpstream,
cfg: cfg,
+ soraTestLastRun: make(map[int64]time.Time),
+ soraTestCooldown: defaultSoraTestCooldown,
}
}
@@ -467,13 +485,129 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
return s.processGeminiStream(c, resp.Body)
}
+type soraProbeStep struct {
+ Name string `json:"name"`
+ Status string `json:"status"`
+ HTTPStatus int `json:"http_status,omitempty"`
+ ErrorCode string `json:"error_code,omitempty"`
+ Message string `json:"message,omitempty"`
+}
+
+type soraProbeSummary struct {
+ Status string `json:"status"`
+ Steps []soraProbeStep `json:"steps"`
+}
+
+type soraProbeRecorder struct {
+ steps []soraProbeStep
+}
+
+func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
+ r.steps = append(r.steps, soraProbeStep{
+ Name: name,
+ Status: status,
+ HTTPStatus: httpStatus,
+ ErrorCode: strings.TrimSpace(errorCode),
+ Message: strings.TrimSpace(message),
+ })
+}
+
+func (r *soraProbeRecorder) finalize() soraProbeSummary {
+ meSuccess := false
+ partial := false
+ for _, step := range r.steps {
+ if step.Name == "me" {
+ meSuccess = strings.EqualFold(step.Status, "success")
+ continue
+ }
+ if strings.EqualFold(step.Status, "failed") {
+ partial = true
+ }
+ }
+
+ status := "success"
+ if !meSuccess {
+ status = "failed"
+ } else if partial {
+ status = "partial_success"
+ }
+
+ return soraProbeSummary{
+ Status: status,
+ Steps: append([]soraProbeStep(nil), r.steps...),
+ }
+}
+
+func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
+ if rec == nil {
+ return
+ }
+ summary := rec.finalize()
+ code := ""
+ for _, step := range summary.Steps {
+ if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
+ code = step.ErrorCode
+ break
+ }
+ }
+ s.sendEvent(c, TestEvent{
+ Type: "sora_test_result",
+ Status: summary.Status,
+ Code: code,
+ Data: summary,
+ })
+}
+
+func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
+ if accountID <= 0 {
+ return 0, true
+ }
+ s.soraTestGuardMu.Lock()
+ defer s.soraTestGuardMu.Unlock()
+
+ if s.soraTestLastRun == nil {
+ s.soraTestLastRun = make(map[int64]time.Time)
+ }
+ cooldown := s.soraTestCooldown
+ if cooldown <= 0 {
+ cooldown = defaultSoraTestCooldown
+ }
+
+ now := time.Now()
+ if lastRun, ok := s.soraTestLastRun[accountID]; ok {
+ elapsed := now.Sub(lastRun)
+ if elapsed < cooldown {
+ return cooldown - elapsed, false
+ }
+ }
+ s.soraTestLastRun[accountID] = now
+ return 0, true
+}
+
+func ceilSeconds(d time.Duration) int {
+ if d <= 0 {
+ return 1
+ }
+ sec := int(d / time.Second)
+ if d%time.Second != 0 {
+ sec++
+ }
+ if sec < 1 {
+ sec = 1
+ }
+ return sec
+}
+
// testSoraAccountConnection 测试 Sora 账号的连接
// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token)
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
ctx := c.Request.Context()
+ recorder := &soraProbeRecorder{}
authToken := account.GetCredential("access_token")
if authToken == "" {
+ recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
+ s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "No access token available")
}
@@ -484,11 +618,20 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
+ if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
+ msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
+ recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
+ s.emitSoraProbeSummary(c, recorder)
+ return s.sendErrorAndEnd(c, msg)
+ }
+
// Send test_start event
s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
if err != nil {
+ recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
+ s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, "Failed to create request")
}
@@ -496,15 +639,21 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
req.Header.Set("Accept", "application/json")
+ req.Header.Set("Accept-Language", "en-US,en;q=0.9")
+ req.Header.Set("Origin", "https://sora.chatgpt.com")
+ req.Header.Set("Referer", "https://sora.chatgpt.com/")
// Get proxy URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
+ enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint()
- resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
if err != nil {
+ recorder.addStep("me", "failed", 0, "network_error", err.Error())
+ s.emitSoraProbeSummary(c, recorder)
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
@@ -512,8 +661,33 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
- return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, string(body)))
+ if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
+ recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
+ s.emitSoraProbeSummary(c, recorder)
+ s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
+ return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
+ }
+ upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
+ switch {
+ case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
+ recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
+ s.emitSoraProbeSummary(c, recorder)
+ return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
+ case strings.EqualFold(upstreamCode, "unsupported_country_code"):
+ recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
+ s.emitSoraProbeSummary(c, recorder)
+ return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
+ case strings.TrimSpace(upstreamMessage) != "":
+ recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
+ s.emitSoraProbeSummary(c, recorder)
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
+ default:
+ recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
+ s.emitSoraProbeSummary(c, recorder)
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
+ }
}
+ recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
// 解析 /me 响应,提取用户信息
var meResp map[string]any
@@ -531,10 +705,384 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *
s.sendEvent(c, TestEvent{Type: "content", Text: info})
}
+ // 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
+ subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
+ if err == nil {
+ subReq.Header.Set("Authorization", "Bearer "+authToken)
+ subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
+ subReq.Header.Set("Accept", "application/json")
+ subReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
+ subReq.Header.Set("Origin", "https://sora.chatgpt.com")
+ subReq.Header.Set("Referer", "https://sora.chatgpt.com/")
+
+ subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
+ if subErr != nil {
+ recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
+ s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
+ } else {
+ subBody, _ := io.ReadAll(subResp.Body)
+ _ = subResp.Body.Close()
+ if subResp.StatusCode == http.StatusOK {
+ recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
+ if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
+ s.sendEvent(c, TestEvent{Type: "content", Text: summary})
+ } else {
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
+ }
+ } else {
+ if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
+ recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
+ s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
+ s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
+ } else {
+ upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
+ recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
+ s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
+ }
+ }
+ }
+ }
+
+ // 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
+ s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder)
+
+ s.emitSoraProbeSummary(c, recorder)
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
+func (s *AccountTestService) testSora2Capabilities(
+ c *gin.Context,
+ ctx context.Context,
+ account *Account,
+ authToken string,
+ proxyURL string,
+ enableTLSFingerprint bool,
+ recorder *soraProbeRecorder,
+) {
+ inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
+ ctx,
+ account,
+ authToken,
+ soraInviteMineURL,
+ proxyURL,
+ enableTLSFingerprint,
+ )
+ if err != nil {
+ if recorder != nil {
+ recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
+ }
+ s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
+ return
+ }
+
+ if inviteStatus == http.StatusUnauthorized {
+ bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint(
+ ctx,
+ account,
+ authToken,
+ soraBootstrapURL,
+ proxyURL,
+ enableTLSFingerprint,
+ )
+ if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
+ if recorder != nil {
+ recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
+ }
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
+ inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
+ ctx,
+ account,
+ authToken,
+ soraInviteMineURL,
+ proxyURL,
+ enableTLSFingerprint,
+ )
+ if err != nil {
+ if recorder != nil {
+ recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
+ }
+ s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
+ return
+ }
+ } else if recorder != nil {
+ code := ""
+ msg := ""
+ if bootstrapErr != nil {
+ code = "network_error"
+ msg = bootstrapErr.Error()
+ }
+ recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
+ }
+ }
+
+ if inviteStatus != http.StatusOK {
+ if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
+ if recorder != nil {
+ recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
+ }
+ s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
+ s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
+ return
+ }
+ upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
+ if recorder != nil {
+ recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
+ }
+ s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
+ return
+ }
+ if recorder != nil {
+ recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
+ }
+
+ if summary := parseSoraInviteSummary(inviteBody); summary != "" {
+ s.sendEvent(c, TestEvent{Type: "content", Text: summary})
+ } else {
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"})
+ }
+
+ remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint(
+ ctx,
+ account,
+ authToken,
+ soraRemainingURL,
+ proxyURL,
+ enableTLSFingerprint,
+ )
+ if remainingErr != nil {
+ if recorder != nil {
+ recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
+ }
+ s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
+ return
+ }
+ if remainingStatus != http.StatusOK {
+ if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
+ if recorder != nil {
+ recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
+ }
+ s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
+ s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
+ return
+ }
+ upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
+ if recorder != nil {
+ recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
+ }
+ s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
+ return
+ }
+ if recorder != nil {
+ recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
+ }
+ if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
+ s.sendEvent(c, TestEvent{Type: "content", Text: summary})
+ } else {
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"})
+ }
+}
+
+func (s *AccountTestService) fetchSoraTestEndpoint(
+ ctx context.Context,
+ account *Account,
+ authToken string,
+ url string,
+ proxyURL string,
+ enableTLSFingerprint bool,
+) (int, http.Header, []byte, error) {
+ req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+ if err != nil {
+ return 0, nil, nil, err
+ }
+ req.Header.Set("Authorization", "Bearer "+authToken)
+ req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Accept-Language", "en-US,en;q=0.9")
+ req.Header.Set("Origin", "https://sora.chatgpt.com")
+ req.Header.Set("Referer", "https://sora.chatgpt.com/")
+
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableTLSFingerprint)
+ if err != nil {
+ return 0, nil, nil, err
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, readErr := io.ReadAll(resp.Body)
+ if readErr != nil {
+ return resp.StatusCode, resp.Header, nil, readErr
+ }
+ return resp.StatusCode, resp.Header, body, nil
+}
+
+func parseSoraSubscriptionSummary(body []byte) string {
+ var subResp struct {
+ Data []struct {
+ Plan struct {
+ ID string `json:"id"`
+ Title string `json:"title"`
+ } `json:"plan"`
+ EndTS string `json:"end_ts"`
+ } `json:"data"`
+ }
+ if err := json.Unmarshal(body, &subResp); err != nil {
+ return ""
+ }
+ if len(subResp.Data) == 0 {
+ return ""
+ }
+
+ first := subResp.Data[0]
+ parts := make([]string, 0, 3)
+ if first.Plan.Title != "" {
+ parts = append(parts, first.Plan.Title)
+ }
+ if first.Plan.ID != "" {
+ parts = append(parts, first.Plan.ID)
+ }
+ if first.EndTS != "" {
+ parts = append(parts, "end="+first.EndTS)
+ }
+ if len(parts) == 0 {
+ return ""
+ }
+ return "Subscription: " + strings.Join(parts, " | ")
+}
+
+func parseSoraInviteSummary(body []byte) string {
+ var inviteResp struct {
+ InviteCode string `json:"invite_code"`
+ RedeemedCount int64 `json:"redeemed_count"`
+ TotalCount int64 `json:"total_count"`
+ }
+ if err := json.Unmarshal(body, &inviteResp); err != nil {
+ return ""
+ }
+
+ parts := []string{"Sora2: supported"}
+ if inviteResp.InviteCode != "" {
+ parts = append(parts, "invite="+inviteResp.InviteCode)
+ }
+ if inviteResp.TotalCount > 0 {
+ parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount))
+ }
+ return strings.Join(parts, " | ")
+}
+
+func parseSoraRemainingSummary(body []byte) string {
+ var remainingResp struct {
+ RateLimitAndCreditBalance struct {
+ EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"`
+ RateLimitReached bool `json:"rate_limit_reached"`
+ AccessResetsInSeconds int64 `json:"access_resets_in_seconds"`
+ } `json:"rate_limit_and_credit_balance"`
+ }
+ if err := json.Unmarshal(body, &remainingResp); err != nil {
+ return ""
+ }
+ info := remainingResp.RateLimitAndCreditBalance
+ parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)}
+ if info.RateLimitReached {
+ parts = append(parts, "rate_limited=true")
+ }
+ if info.AccessResetsInSeconds > 0 {
+ parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds))
+ }
+ return strings.Join(parts, " | ")
+}
+
+func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool {
+ if s == nil || s.cfg == nil {
+ return true
+ }
+ return !s.cfg.Sora.Client.DisableTLSFingerprint
+}
+
+func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
+ return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
+}
+
+func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
+ return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
+}
+
+func extractCloudflareRayID(headers http.Header, body []byte) string {
+ return soraerror.ExtractCloudflareRayID(headers, body)
+}
+
+func extractSoraEgressIPHint(headers http.Header) string {
+ if headers == nil {
+ return "unknown"
+ }
+ candidates := []string{
+ "x-openai-public-ip",
+ "x-envoy-external-address",
+ "cf-connecting-ip",
+ "x-forwarded-for",
+ }
+ for _, key := range candidates {
+ if value := strings.TrimSpace(headers.Get(key)); value != "" {
+ return value
+ }
+ }
+ return "unknown"
+}
+
+func sanitizeProxyURLForLog(raw string) string {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return ""
+ }
+ u, err := url.Parse(raw)
+ if err != nil {
+ return ""
+ }
+ if u.User != nil {
+ u.User = nil
+ }
+ return u.String()
+}
+
+func endpointPathForLog(endpoint string) string {
+ parsed, err := url.Parse(strings.TrimSpace(endpoint))
+ if err != nil || parsed.Path == "" {
+ return endpoint
+ }
+ return parsed.Path
+}
+
+func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) {
+ accountID := int64(0)
+ platform := ""
+ proxyID := "none"
+ if account != nil {
+ accountID = account.ID
+ platform = account.Platform
+ if account.ProxyID != nil {
+ proxyID = fmt.Sprintf("%d", *account.ProxyID)
+ }
+ }
+ cfRay := extractCloudflareRayID(headers, body)
+ if cfRay == "" {
+ cfRay = "unknown"
+ }
+ log.Printf(
+ "[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s",
+ accountID,
+ platform,
+ endpoint,
+ endpointPathForLog(endpoint),
+ proxyID,
+ sanitizeProxyURLForLog(proxyURL),
+ cfRay,
+ extractSoraEgressIPHint(headers),
+ )
+}
+
+func truncateSoraErrorBody(body []byte, max int) string {
+ return soraerror.TruncateBody(body, max)
+}
+
// testAntigravityAccountConnection tests an Antigravity account's connection
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go
new file mode 100644
index 00000000..3dfac786
--- /dev/null
+++ b/backend/internal/service/account_test_service_sora_test.go
@@ -0,0 +1,319 @@
+package service
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type queuedHTTPUpstream struct {
+ responses []*http.Response
+ requests []*http.Request
+ tlsFlags []bool
+}
+
+func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
+ return nil, fmt.Errorf("unexpected Do call")
+}
+
+func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) {
+ u.requests = append(u.requests, req)
+ u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint)
+ if len(u.responses) == 0 {
+ return nil, fmt.Errorf("no mocked response")
+ }
+ resp := u.responses[0]
+ u.responses = u.responses[1:]
+ return resp, nil
+}
+
+func newJSONResponse(status int, body string) *http.Response {
+ return &http.Response{
+ StatusCode: status,
+ Header: make(http.Header),
+ Body: io.NopCloser(strings.NewReader(body)),
+ }
+}
+
+func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
+ resp := newJSONResponse(status, body)
+ resp.Header.Set(key, value)
+ return resp
+}
+
+func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
+ return c, rec
+}
+
+func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
+ newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
+ newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
+ newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
+ },
+ }
+ svc := &AccountTestService{
+ httpUpstream: upstream,
+ cfg: &config.Config{
+ Gateway: config.GatewayConfig{
+ TLSFingerprint: config.TLSFingerprintConfig{
+ Enabled: true,
+ },
+ },
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ DisableTLSFingerprint: false,
+ },
+ },
+ },
+ }
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c, rec := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c, account)
+
+ require.NoError(t, err)
+ require.Len(t, upstream.requests, 4)
+ require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
+ require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
+ require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String())
+ require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String())
+ require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
+ require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
+ require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags)
+
+ body := rec.Body.String()
+ require.Contains(t, body, `"type":"test_start"`)
+ require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
+ require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
+ require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
+ require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
+ require.Contains(t, body, `"type":"sora_test_result"`)
+ require.Contains(t, body, `"status":"success"`)
+ require.Contains(t, body, `"type":"test_complete","success":true`)
+}
+
+func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
+ newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
+ newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
+ newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
+ },
+ }
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c, rec := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c, account)
+
+ require.NoError(t, err)
+ require.Len(t, upstream.requests, 4)
+ body := rec.Body.String()
+ require.Contains(t, body, "Sora connection OK - User: demo-user")
+ require.Contains(t, body, "Subscription check returned 403")
+ require.Contains(t, body, "Sora2 invite check returned 401")
+ require.Contains(t, body, `"type":"sora_test_result"`)
+ require.Contains(t, body, `"status":"partial_success"`)
+ require.Contains(t, body, `"type":"test_complete","success":true`)
+}
+
+func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponseWithHeader(http.StatusForbidden, `Just a moment...`, "cf-ray", "9cff2d62d83bb98d"),
+ },
+ }
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c, rec := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c, account)
+
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "Cloudflare challenge")
+ require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
+ body := rec.Body.String()
+ require.Contains(t, body, `"type":"error"`)
+ require.Contains(t, body, "Cloudflare challenge")
+ require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
+}
+
+func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponseWithHeader(http.StatusTooManyRequests, `Just a moment...`, "cf-mitigated", "challenge"),
+ },
+ }
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c, rec := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c, account)
+
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "Cloudflare challenge")
+ require.Contains(t, err.Error(), "HTTP 429")
+ body := rec.Body.String()
+ require.Contains(t, body, "Cloudflare challenge")
+}
+
+func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
+ },
+ }
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c, rec := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c, account)
+
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "token_invalidated")
+ body := rec.Body.String()
+ require.Contains(t, body, `"type":"sora_test_result"`)
+ require.Contains(t, body, `"status":"failed"`)
+ require.Contains(t, body, "token_invalidated")
+ require.NotContains(t, body, `"type":"test_complete","success":true`)
+}
+
+func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
+ },
+ }
+ svc := &AccountTestService{
+ httpUpstream: upstream,
+ soraTestCooldown: time.Hour,
+ }
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c1, _ := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c1, account)
+ require.NoError(t, err)
+
+ c2, rec2 := newSoraTestContext()
+ err = svc.testSoraAccountConnection(c2, account)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "测试过于频繁")
+ body := rec2.Body.String()
+ require.Contains(t, body, `"type":"sora_test_result"`)
+ require.Contains(t, body, `"code":"test_rate_limited"`)
+ require.Contains(t, body, `"status":"failed"`)
+ require.NotContains(t, body, `"type":"test_complete","success":true`)
+}
+
+func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
+ upstream := &queuedHTTPUpstream{
+ responses: []*http.Response{
+ newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
+ newJSONResponse(http.StatusForbidden, `Just a moment...`),
+ newJSONResponse(http.StatusForbidden, `Just a moment...`),
+ },
+ }
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "test_token",
+ },
+ }
+
+ c, rec := newSoraTestContext()
+ err := svc.testSoraAccountConnection(c, account)
+
+ require.NoError(t, err)
+ body := rec.Body.String()
+ require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
+ require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)")
+ require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
+ require.Contains(t, body, `"type":"test_complete","success":true`)
+}
+
+func TestSanitizeProxyURLForLog(t *testing.T) {
+ require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
+ require.Equal(t, "", sanitizeProxyURLForLog(""))
+ require.Equal(t, "", sanitizeProxyURLForLog("://invalid"))
+}
+
+func TestExtractSoraEgressIPHint(t *testing.T) {
+ h := make(http.Header)
+ h.Set("x-openai-public-ip", "203.0.113.10")
+ require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))
+
+ h2 := make(http.Header)
+ h2.Set("x-envoy-external-address", "198.51.100.9")
+ require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))
+
+ require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
+ require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
+}
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index f5130527..8614f24a 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -4,11 +4,15 @@ import (
"context"
"errors"
"fmt"
+ "io"
+ "net/http"
"strings"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
)
// AdminService interface defines admin management operations
@@ -39,7 +43,7 @@ type AdminService interface {
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// Account management
- ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
+ ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
@@ -65,6 +69,7 @@ type AdminService interface {
GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
+ CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error)
// Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
@@ -288,6 +293,32 @@ type ProxyTestResult struct {
CountryCode string `json:"country_code,omitempty"`
}
+type ProxyQualityCheckResult struct {
+ ProxyID int64 `json:"proxy_id"`
+ Score int `json:"score"`
+ Grade string `json:"grade"`
+ Summary string `json:"summary"`
+ ExitIP string `json:"exit_ip,omitempty"`
+ Country string `json:"country,omitempty"`
+ CountryCode string `json:"country_code,omitempty"`
+ BaseLatencyMs int64 `json:"base_latency_ms,omitempty"`
+ PassedCount int `json:"passed_count"`
+ WarnCount int `json:"warn_count"`
+ FailedCount int `json:"failed_count"`
+ ChallengeCount int `json:"challenge_count"`
+ CheckedAt int64 `json:"checked_at"`
+ Items []ProxyQualityCheckItem `json:"items"`
+}
+
+type ProxyQualityCheckItem struct {
+ Target string `json:"target"`
+ Status string `json:"status"` // pass/warn/fail/challenge
+ HTTPStatus int `json:"http_status,omitempty"`
+ LatencyMs int64 `json:"latency_ms,omitempty"`
+ Message string `json:"message,omitempty"`
+ CFRay string `json:"cf_ray,omitempty"`
+}
+
// ProxyExitInfo represents proxy exit information from ip-api.com
type ProxyExitInfo struct {
IP string
@@ -302,6 +333,58 @@ type ProxyExitInfoProber interface {
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
}
+type proxyQualityTarget struct {
+ Target string
+ URL string
+ Method string
+ AllowedStatuses map[int]struct{}
+}
+
+var proxyQualityTargets = []proxyQualityTarget{
+ {
+ Target: "openai",
+ URL: "https://api.openai.com/v1/models",
+ Method: http.MethodGet,
+ AllowedStatuses: map[int]struct{}{
+ http.StatusUnauthorized: {},
+ },
+ },
+ {
+ Target: "anthropic",
+ URL: "https://api.anthropic.com/v1/messages",
+ Method: http.MethodGet,
+ AllowedStatuses: map[int]struct{}{
+ http.StatusUnauthorized: {},
+ http.StatusMethodNotAllowed: {},
+ http.StatusNotFound: {},
+ http.StatusBadRequest: {},
+ },
+ },
+ {
+ Target: "gemini",
+ URL: "https://generativelanguage.googleapis.com/$discovery/rest?version=v1beta",
+ Method: http.MethodGet,
+ AllowedStatuses: map[int]struct{}{
+ http.StatusOK: {},
+ },
+ },
+ {
+ Target: "sora",
+ URL: "https://sora.chatgpt.com/backend/me",
+ Method: http.MethodGet,
+ AllowedStatuses: map[int]struct{}{
+ http.StatusUnauthorized: {},
+ },
+ },
+}
+
+const (
+ proxyQualityRequestTimeout = 15 * time.Second
+ proxyQualityResponseHeaderTimeout = 10 * time.Second
+ proxyQualityMaxBodyBytes = int64(8 * 1024)
+ proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
+)
+
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
userRepo UserRepository
@@ -1054,9 +1137,9 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
}
// Account management implementations
-func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
+func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
- accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
+ accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID)
if err != nil {
return nil, 0, err
}
@@ -1690,6 +1773,270 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
}, nil
}
+func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) {
+ proxy, err := s.proxyRepo.GetByID(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+
+ result := &ProxyQualityCheckResult{
+ ProxyID: id,
+ Score: 100,
+ Grade: "A",
+ CheckedAt: time.Now().Unix(),
+ Items: make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1),
+ }
+
+ proxyURL := proxy.URL()
+ if s.proxyProber == nil {
+ result.Items = append(result.Items, ProxyQualityCheckItem{
+ Target: "base_connectivity",
+ Status: "fail",
+ Message: "代理探测服务未配置",
+ })
+ result.FailedCount++
+ finalizeProxyQualityResult(result)
+ s.saveProxyQualitySnapshot(ctx, id, result, nil)
+ return result, nil
+ }
+
+ exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
+ if err != nil {
+ result.Items = append(result.Items, ProxyQualityCheckItem{
+ Target: "base_connectivity",
+ Status: "fail",
+ LatencyMs: latencyMs,
+ Message: err.Error(),
+ })
+ result.FailedCount++
+ finalizeProxyQualityResult(result)
+ s.saveProxyQualitySnapshot(ctx, id, result, nil)
+ return result, nil
+ }
+
+ result.ExitIP = exitInfo.IP
+ result.Country = exitInfo.Country
+ result.CountryCode = exitInfo.CountryCode
+ result.BaseLatencyMs = latencyMs
+ result.Items = append(result.Items, ProxyQualityCheckItem{
+ Target: "base_connectivity",
+ Status: "pass",
+ LatencyMs: latencyMs,
+ Message: "代理出口连通正常",
+ })
+ result.PassedCount++
+
+ client, err := httpclient.GetClient(httpclient.Options{
+ ProxyURL: proxyURL,
+ Timeout: proxyQualityRequestTimeout,
+ ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout,
+ ProxyStrict: true,
+ })
+ if err != nil {
+ result.Items = append(result.Items, ProxyQualityCheckItem{
+ Target: "http_client",
+ Status: "fail",
+ Message: fmt.Sprintf("创建检测客户端失败: %v", err),
+ })
+ result.FailedCount++
+ finalizeProxyQualityResult(result)
+ s.saveProxyQualitySnapshot(ctx, id, result, exitInfo)
+ return result, nil
+ }
+
+ for _, target := range proxyQualityTargets {
+ item := runProxyQualityTarget(ctx, client, target)
+ result.Items = append(result.Items, item)
+ switch item.Status {
+ case "pass":
+ result.PassedCount++
+ case "warn":
+ result.WarnCount++
+ case "challenge":
+ result.ChallengeCount++
+ default:
+ result.FailedCount++
+ }
+ }
+
+ finalizeProxyQualityResult(result)
+ s.saveProxyQualitySnapshot(ctx, id, result, exitInfo)
+ return result, nil
+}
+
+func runProxyQualityTarget(ctx context.Context, client *http.Client, target proxyQualityTarget) ProxyQualityCheckItem {
+ item := ProxyQualityCheckItem{
+ Target: target.Target,
+ }
+
+ req, err := http.NewRequestWithContext(ctx, target.Method, target.URL, nil)
+ if err != nil {
+ item.Status = "fail"
+ item.Message = fmt.Sprintf("构建请求失败: %v", err)
+ return item
+ }
+ req.Header.Set("Accept", "application/json,text/html,*/*")
+ req.Header.Set("User-Agent", proxyQualityClientUserAgent)
+
+ start := time.Now()
+ resp, err := client.Do(req)
+ if err != nil {
+ item.Status = "fail"
+ item.LatencyMs = time.Since(start).Milliseconds()
+ item.Message = fmt.Sprintf("请求失败: %v", err)
+ return item
+ }
+ defer func() { _ = resp.Body.Close() }()
+ item.LatencyMs = time.Since(start).Milliseconds()
+ item.HTTPStatus = resp.StatusCode
+
+ body, readErr := io.ReadAll(io.LimitReader(resp.Body, proxyQualityMaxBodyBytes+1))
+ if readErr != nil {
+ item.Status = "fail"
+ item.Message = fmt.Sprintf("读取响应失败: %v", readErr)
+ return item
+ }
+ if int64(len(body)) > proxyQualityMaxBodyBytes {
+ body = body[:proxyQualityMaxBodyBytes]
+ }
+
+ if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
+ item.Status = "challenge"
+ item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body)
+ item.Message = "Sora 命中 Cloudflare challenge"
+ return item
+ }
+
+ if _, ok := target.AllowedStatuses[resp.StatusCode]; ok {
+ if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices {
+ item.Status = "pass"
+ item.Message = fmt.Sprintf("HTTP %d", resp.StatusCode)
+ } else {
+ item.Status = "warn"
+ item.Message = fmt.Sprintf("HTTP %d(目标可达,但鉴权或方法受限)", resp.StatusCode)
+ }
+ return item
+ }
+
+ if resp.StatusCode == http.StatusTooManyRequests {
+ item.Status = "warn"
+ item.Message = "目标返回 429,可能存在频控"
+ return item
+ }
+
+ item.Status = "fail"
+ item.Message = fmt.Sprintf("非预期状态码: %d", resp.StatusCode)
+ return item
+}
+
+func finalizeProxyQualityResult(result *ProxyQualityCheckResult) {
+ if result == nil {
+ return
+ }
+ score := 100 - result.WarnCount*10 - result.FailedCount*22 - result.ChallengeCount*30
+ if score < 0 {
+ score = 0
+ }
+ result.Score = score
+ result.Grade = proxyQualityGrade(score)
+ result.Summary = fmt.Sprintf(
+ "通过 %d 项,告警 %d 项,失败 %d 项,挑战 %d 项",
+ result.PassedCount,
+ result.WarnCount,
+ result.FailedCount,
+ result.ChallengeCount,
+ )
+}
+
+func proxyQualityGrade(score int) string {
+ switch {
+ case score >= 90:
+ return "A"
+ case score >= 75:
+ return "B"
+ case score >= 60:
+ return "C"
+ case score >= 40:
+ return "D"
+ default:
+ return "F"
+ }
+}
+
+func proxyQualityOverallStatus(result *ProxyQualityCheckResult) string {
+ if result == nil {
+ return ""
+ }
+ if result.ChallengeCount > 0 {
+ return "challenge"
+ }
+ if result.FailedCount > 0 {
+ return "failed"
+ }
+ if result.WarnCount > 0 {
+ return "warn"
+ }
+ if result.PassedCount > 0 {
+ return "healthy"
+ }
+ return "failed"
+}
+
+func proxyQualityFirstCFRay(result *ProxyQualityCheckResult) string {
+ if result == nil {
+ return ""
+ }
+ for _, item := range result.Items {
+ if item.CFRay != "" {
+ return item.CFRay
+ }
+ }
+ return ""
+}
+
+func proxyQualityBaseConnectivityPass(result *ProxyQualityCheckResult) bool {
+ if result == nil {
+ return false
+ }
+ for _, item := range result.Items {
+ if item.Target == "base_connectivity" {
+ return item.Status == "pass"
+ }
+ }
+ return false
+}
+
+func (s *adminServiceImpl) saveProxyQualitySnapshot(ctx context.Context, proxyID int64, result *ProxyQualityCheckResult, exitInfo *ProxyExitInfo) {
+ if result == nil {
+ return
+ }
+ score := result.Score
+ checkedAt := result.CheckedAt
+ info := &ProxyLatencyInfo{
+ Success: proxyQualityBaseConnectivityPass(result),
+ Message: result.Summary,
+ QualityStatus: proxyQualityOverallStatus(result),
+ QualityScore: &score,
+ QualityGrade: result.Grade,
+ QualitySummary: result.Summary,
+ QualityCheckedAt: &checkedAt,
+ QualityCFRay: proxyQualityFirstCFRay(result),
+ UpdatedAt: time.Now(),
+ }
+ if result.BaseLatencyMs > 0 {
+ latency := result.BaseLatencyMs
+ info.LatencyMs = &latency
+ }
+ if exitInfo != nil {
+ info.IPAddress = exitInfo.IP
+ info.Country = exitInfo.Country
+ info.CountryCode = exitInfo.CountryCode
+ info.Region = exitInfo.Region
+ info.City = exitInfo.City
+ }
+ s.saveProxyLatency(ctx, proxyID, info)
+}
+
func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
if s.proxyProber == nil || proxy == nil {
return
@@ -1800,6 +2147,11 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro
proxies[i].CountryCode = info.CountryCode
proxies[i].Region = info.Region
proxies[i].City = info.City
+ proxies[i].QualityStatus = info.QualityStatus
+ proxies[i].QualityScore = info.QualityScore
+ proxies[i].QualityGrade = info.QualityGrade
+ proxies[i].QualitySummary = info.QualitySummary
+ proxies[i].QualityChecked = info.QualityCheckedAt
}
}
@@ -1807,7 +2159,27 @@ func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64,
if s.proxyLatencyCache == nil || info == nil {
return
}
- if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
+
+ merged := *info
+ if latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, []int64{proxyID}); err == nil {
+ if existing := latencies[proxyID]; existing != nil {
+ if merged.QualityCheckedAt == nil &&
+ merged.QualityScore == nil &&
+ merged.QualityGrade == "" &&
+ merged.QualityStatus == "" &&
+ merged.QualitySummary == "" &&
+ merged.QualityCFRay == "" {
+ merged.QualityStatus = existing.QualityStatus
+ merged.QualityScore = existing.QualityScore
+ merged.QualityGrade = existing.QualityGrade
+ merged.QualitySummary = existing.QualitySummary
+ merged.QualityCheckedAt = existing.QualityCheckedAt
+ merged.QualityCFRay = existing.QualityCFRay
+ }
+ }
+ }
+
+ if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, &merged); err != nil {
logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err)
}
}
diff --git a/backend/internal/service/admin_service_proxy_quality_test.go b/backend/internal/service/admin_service_proxy_quality_test.go
new file mode 100644
index 00000000..5a43cd9c
--- /dev/null
+++ b/backend/internal/service/admin_service_proxy_quality_test.go
@@ -0,0 +1,95 @@
+package service
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
+ result := &ProxyQualityCheckResult{
+ PassedCount: 2,
+ WarnCount: 1,
+ FailedCount: 1,
+ ChallengeCount: 1,
+ }
+
+ finalizeProxyQualityResult(result)
+
+ require.Equal(t, 38, result.Score)
+ require.Equal(t, "F", result.Grade)
+ require.Contains(t, result.Summary, "通过 2 项")
+ require.Contains(t, result.Summary, "告警 1 项")
+ require.Contains(t, result.Summary, "失败 1 项")
+ require.Contains(t, result.Summary, "挑战 1 项")
+}
+
+func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "text/html")
+ w.Header().Set("cf-ray", "test-ray-123")
+ w.WriteHeader(http.StatusForbidden)
+ _, _ = w.Write([]byte("Just a moment..."))
+ }))
+ defer server.Close()
+
+ target := proxyQualityTarget{
+ Target: "sora",
+ URL: server.URL,
+ Method: http.MethodGet,
+ AllowedStatuses: map[int]struct{}{
+ http.StatusUnauthorized: {},
+ },
+ }
+
+ item := runProxyQualityTarget(context.Background(), server.Client(), target)
+ require.Equal(t, "challenge", item.Status)
+ require.Equal(t, http.StatusForbidden, item.HTTPStatus)
+ require.Equal(t, "test-ray-123", item.CFRay)
+}
+
+func TestRunProxyQualityTarget_AllowedStatusPass(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte(`{"models":[]}`))
+ }))
+ defer server.Close()
+
+ target := proxyQualityTarget{
+ Target: "gemini",
+ URL: server.URL,
+ Method: http.MethodGet,
+ AllowedStatuses: map[int]struct{}{
+ http.StatusOK: {},
+ },
+ }
+
+ item := runProxyQualityTarget(context.Background(), server.Client(), target)
+ require.Equal(t, "pass", item.Status)
+ require.Equal(t, http.StatusOK, item.HTTPStatus)
+}
+
+func TestRunProxyQualityTarget_AllowedStatusWarnForUnauthorized(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusUnauthorized)
+ _, _ = w.Write([]byte(`{"error":"unauthorized"}`))
+ }))
+ defer server.Close()
+
+ target := proxyQualityTarget{
+ Target: "openai",
+ URL: server.URL,
+ Method: http.MethodGet,
+ AllowedStatuses: map[int]struct{}{
+ http.StatusUnauthorized: {},
+ },
+ }
+
+ item := runProxyQualityTarget(context.Background(), server.Client(), target)
+ require.Equal(t, "warn", item.Status)
+ require.Equal(t, http.StatusUnauthorized, item.HTTPStatus)
+ require.Contains(t, item.Message, "目标可达")
+}
diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go
index d661b710..ff58fd01 100644
--- a/backend/internal/service/admin_service_search_test.go
+++ b/backend/internal/service/admin_service_search_test.go
@@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct {
listWithFiltersErr error
}
-func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
+func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform
@@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{accountRepo: repo}
- accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
+ accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0)
require.NoError(t, err)
require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go
index ed33d992..cf87b282 100644
--- a/backend/internal/service/antigravity_gateway_service.go
+++ b/backend/internal/service/antigravity_gateway_service.go
@@ -4117,6 +4117,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs
if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 {
usage.CacheCreationInputTokens = int(v)
}
+ // 解析嵌套的 cache_creation 对象中的 5m/1h 明细
+ if cc, ok := u["cache_creation"].(map[string]any); ok {
+ if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
+ usage.CacheCreation5mTokens = int(v)
+ }
+ if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
+ usage.CacheCreation1hTokens = int(v)
+ }
+ }
}
// extractClaudeUsage 从非流式 Claude 响应提取 usage
@@ -4139,6 +4148,15 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage
if v, ok := u["cache_creation_input_tokens"].(float64); ok {
usage.CacheCreationInputTokens = int(v)
}
+ // 解析嵌套的 cache_creation 对象中的 5m/1h 明细
+ if cc, ok := u["cache_creation"].(map[string]any); ok {
+ if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok {
+ usage.CacheCreation5mTokens = int(v)
+ }
+ if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok {
+ usage.CacheCreation1hTokens = int(v)
+ }
+ }
}
return usage
}
diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go
index e6660399..f100be0b 100644
--- a/backend/internal/service/billing_service.go
+++ b/backend/internal/service/billing_service.go
@@ -31,8 +31,8 @@ type ModelPricing struct {
OutputPricePerToken float64 // 每token输出价格 (USD)
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
- CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退
- CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退
+ CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
+ CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
}
@@ -172,12 +172,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
if s.pricingService != nil {
litellmPricing := s.pricingService.GetModelPricing(model)
if litellmPricing != nil {
+ // 启用 5m/1h 分类计费的条件:
+ // 1. 存在 1h 价格
+ // 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费)
+ price5m := litellmPricing.CacheCreationInputTokenCost
+ price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
+ enableBreakdown := price1h > 0 && price1h > price5m
return &ModelPricing{
InputPricePerToken: litellmPricing.InputCostPerToken,
OutputPricePerToken: litellmPricing.OutputCostPerToken,
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
- SupportsCacheBreakdown: false,
+ CacheCreation5mPrice: price5m,
+ CacheCreation1hPrice: price1h,
+ SupportsCacheBreakdown: enableBreakdown,
}, nil
}
}
@@ -209,9 +217,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul
// 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
- // 支持详细缓存分类的模型(5分钟/1小时缓存)
- breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice +
- float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice
+ // 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token)
+ if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
+ // API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
+ breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
+ } else {
+ breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
+ float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
+ }
} else {
// 标准缓存创建价格(per-token)
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
@@ -280,10 +293,12 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage
// 范围内部分:正常计费
inRangeTokens := UsageTokens{
- InputTokens: inRangeInputTokens,
- OutputTokens: tokens.OutputTokens, // 输出只算一次
- CacheCreationTokens: tokens.CacheCreationTokens,
- CacheReadTokens: inRangeCacheTokens,
+ InputTokens: inRangeInputTokens,
+ OutputTokens: tokens.OutputTokens, // 输出只算一次
+ CacheCreationTokens: tokens.CacheCreationTokens,
+ CacheReadTokens: inRangeCacheTokens,
+ CacheCreation5mTokens: tokens.CacheCreation5mTokens,
+ CacheCreation1hTokens: tokens.CacheCreation1hTokens,
}
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil {
diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go
index bd173b96..5eb278f6 100644
--- a/backend/internal/service/billing_service_test.go
+++ b/backend/internal/service/billing_service_test.go
@@ -399,8 +399,8 @@ func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) {
InputPricePerToken: 3e-6,
OutputPricePerToken: 15e-6,
SupportsCacheBreakdown: true,
- CacheCreation5mPrice: 4.0, // per million tokens
- CacheCreation1hPrice: 5.0, // per million tokens
+ CacheCreation5mPrice: 4e-6, // per token
+ CacheCreation1hPrice: 5e-6, // per token
},
},
}
@@ -414,8 +414,8 @@ func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) {
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
- expected5m := float64(100000) / 1_000_000 * 4.0
- expected1h := float64(50000) / 1_000_000 * 5.0
+ expected5m := float64(tokens.CacheCreation5mTokens) * 4e-6
+ expected1h := float64(tokens.CacheCreation1hTokens) * 5e-6
require.InDelta(t, expected5m+expected1h, cost.CacheCreationCost, 1e-10)
}
diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go
index dd58c183..d7108c8d 100644
--- a/backend/internal/service/gateway_beta_test.go
+++ b/backend/internal/service/gateway_beta_test.go
@@ -21,3 +21,72 @@ func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) {
)
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got)
}
+
+func TestStripBetaToken(t *testing.T) {
+ tests := []struct {
+ name string
+ header string
+ token string
+ want string
+ }{
+ {
+ name: "token in middle",
+ header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14",
+ token: "context-1m-2025-08-07",
+ want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
+ },
+ {
+ name: "token at start",
+ header: "context-1m-2025-08-07,oauth-2025-04-20,interleaved-thinking-2025-05-14",
+ token: "context-1m-2025-08-07",
+ want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
+ },
+ {
+ name: "token at end",
+ header: "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07",
+ token: "context-1m-2025-08-07",
+ want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
+ },
+ {
+ name: "token not present",
+ header: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
+ token: "context-1m-2025-08-07",
+ want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
+ },
+ {
+ name: "empty header",
+ header: "",
+ token: "context-1m-2025-08-07",
+ want: "",
+ },
+ {
+ name: "with spaces",
+ header: "oauth-2025-04-20, context-1m-2025-08-07 , interleaved-thinking-2025-05-14",
+ token: "context-1m-2025-08-07",
+ want: "oauth-2025-04-20,interleaved-thinking-2025-05-14",
+ },
+ {
+ name: "only token",
+ header: "context-1m-2025-08-07",
+ token: "context-1m-2025-08-07",
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := stripBetaToken(tt.header, tt.token)
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) {
+ required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}
+ incoming := "context-1m-2025-08-07,foo-beta,oauth-2025-04-20"
+ drop := map[string]struct{}{"context-1m-2025-08-07": {}}
+
+ got := mergeAnthropicBetaDropping(required, incoming, drop)
+ require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got)
+ require.NotContains(t, got, "context-1m-2025-08-07")
+}
diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go
index c7104fde..70d5068b 100644
--- a/backend/internal/service/gateway_multiplatform_test.go
+++ b/backend/internal/service/gateway_multiplatform_test.go
@@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error
func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
-func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
+func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 0502d352..063a5ae6 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -349,6 +349,8 @@ type ClaudeUsage struct {
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
+ CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象)
+ CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象)
}
// ForwardResult 转发结果
@@ -373,9 +375,10 @@ type ForwardResult struct {
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
- ResponseBody []byte // 上游响应体,用于错误透传规则匹配
- ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
- RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换
+ ResponseBody []byte // 上游响应体,用于错误透传规则匹配
+ ResponseHeaders http.Header // 上游响应头,用于透传 cf-ray/cf-mitigated/content-type 等诊断信息
+ ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true
+ RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换
}
func (e *UpstreamFailoverError) Error() string {
@@ -3580,12 +3583,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
- drop := map[string]struct{}{claude.BetaClaudeCode: {}}
+ drop := map[string]struct{}{claude.BetaClaudeCode: {}, claude.BetaContext1M: {}}
req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
} else {
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
clientBetaHeader := req.Header.Get("anthropic-beta")
- req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader))
+ req.Header.Set("anthropic-beta", stripBetaToken(s.getBetaHeader(modelID, clientBetaHeader), claude.BetaContext1M))
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
@@ -3739,6 +3742,23 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str
return strings.Join(out, ",")
}
+// stripBetaToken removes a single beta token from a comma-separated header value.
+// It short-circuits when the token is not present to avoid unnecessary allocations.
+func stripBetaToken(header, token string) string {
+ if !strings.Contains(header, token) {
+ return header
+ }
+ out := make([]string, 0, 8)
+ for _, p := range strings.Split(header, ",") {
+ p = strings.TrimSpace(p)
+ if p == "" || p == token {
+ continue
+ }
+ out = append(out, p)
+ }
+ return strings.Join(out, ",")
+}
+
// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials.
@@ -4305,6 +4325,23 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
}
+ // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类
+ if account.IsCacheTTLOverrideEnabled() {
+ overrideTarget := account.GetCacheTTLOverrideTarget()
+ if eventType == "message_start" {
+ if msg, ok := event["message"].(map[string]any); ok {
+ if u, ok := msg["usage"].(map[string]any); ok {
+ rewriteCacheCreationJSON(u, overrideTarget)
+ }
+ }
+ }
+ if eventType == "message_delta" {
+ if u, ok := event["usage"].(map[string]any); ok {
+ rewriteCacheCreationJSON(u, overrideTarget)
+ }
+ }
+ }
+
if needModelReplace {
if msg, ok := event["message"].(map[string]any); ok {
if model, ok := msg["model"].(string); ok && model == mappedModel {
@@ -4432,6 +4469,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
usage.InputTokens = msgStart.Message.Usage.InputTokens
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
+
+ // 解析嵌套的 cache_creation 对象中的 5m/1h 明细
+ cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens")
+ cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens")
+ if cc5m.Exists() || cc1h.Exists() {
+ usage.CacheCreation5mTokens = int(cc5m.Int())
+ usage.CacheCreation1hTokens = int(cc1h.Int())
+ }
}
// 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API)
@@ -4460,6 +4505,68 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
if msgDelta.Usage.CacheReadInputTokens > 0 {
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
}
+
+ // 解析嵌套的 cache_creation 对象中的 5m/1h 明细
+ cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens")
+ cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens")
+ if cc5m.Exists() && cc5m.Int() > 0 {
+ usage.CacheCreation5mTokens = int(cc5m.Int())
+ }
+ if cc1h.Exists() && cc1h.Int() > 0 {
+ usage.CacheCreation1hTokens = int(cc1h.Int())
+ }
+ }
+}
+
+// applyCacheTTLOverride 将所有 cache creation tokens 归入指定的 TTL 类型。
+// target 为 "5m" 或 "1h"。返回 true 表示发生了变更。
+func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool {
+ // Fallback: 如果只有聚合字段但无 5m/1h 明细,将聚合字段归入 5m 默认类别
+ if usage.CacheCreation5mTokens == 0 && usage.CacheCreation1hTokens == 0 && usage.CacheCreationInputTokens > 0 {
+ usage.CacheCreation5mTokens = usage.CacheCreationInputTokens
+ }
+
+ total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
+ if total == 0 {
+ return false
+ }
+ switch target {
+ case "1h":
+ if usage.CacheCreation1hTokens == total {
+ return false // 已经全是 1h
+ }
+ usage.CacheCreation1hTokens = total
+ usage.CacheCreation5mTokens = 0
+ default: // "5m"
+ if usage.CacheCreation5mTokens == total {
+ return false // 已经全是 5m
+ }
+ usage.CacheCreation5mTokens = total
+ usage.CacheCreation1hTokens = 0
+ }
+ return true
+}
+
+// rewriteCacheCreationJSON 在 JSON usage 对象中重写 cache_creation 嵌套对象的 TTL 分类。
+// usageObj 是 usage JSON 对象(map[string]any)。
+func rewriteCacheCreationJSON(usageObj map[string]any, target string) {
+ ccObj, ok := usageObj["cache_creation"].(map[string]any)
+ if !ok {
+ return
+ }
+ v5m, _ := ccObj["ephemeral_5m_input_tokens"].(float64)
+ v1h, _ := ccObj["ephemeral_1h_input_tokens"].(float64)
+ total := v5m + v1h
+ if total == 0 {
+ return
+ }
+ switch target {
+ case "1h":
+ ccObj["ephemeral_1h_input_tokens"] = total
+ ccObj["ephemeral_5m_input_tokens"] = float64(0)
+ default: // "5m"
+ ccObj["ephemeral_5m_input_tokens"] = total
+ ccObj["ephemeral_1h_input_tokens"] = float64(0)
}
}
@@ -4491,6 +4598,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
return nil, fmt.Errorf("parse response: %w", err)
}
+ // 解析嵌套的 cache_creation 对象中的 5m/1h 明细
+ cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens")
+ cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens")
+ if cc5m.Exists() || cc1h.Exists() {
+ response.Usage.CacheCreation5mTokens = int(cc5m.Int())
+ response.Usage.CacheCreation1hTokens = int(cc1h.Int())
+ }
+
// 兼容 Kimi cached_tokens → cache_read_input_tokens
if response.Usage.CacheReadInputTokens == 0 {
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
@@ -4502,6 +4617,20 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
}
}
+ // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类
+ if account.IsCacheTTLOverrideEnabled() {
+ overrideTarget := account.GetCacheTTLOverrideTarget()
+ if applyCacheTTLOverride(&response.Usage, overrideTarget) {
+ // 同步更新 body JSON 中的嵌套 cache_creation 对象
+ if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil {
+ body = newBody
+ }
+ if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens", response.Usage.CacheCreation1hTokens); err == nil {
+ body = newBody
+ }
+ }
+ }
+
// 如果有模型映射,替换响应中的model字段
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
@@ -4570,6 +4699,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
result.Usage.InputTokens = 0
}
+ // Cache TTL Override: 确保计费时 token 分类与账号设置一致
+ cacheTTLOverridden := false
+ if account.IsCacheTTLOverrideEnabled() {
+ applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
+ cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
+ }
+
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
@@ -4617,10 +4753,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} else {
// Token 计费
tokens := UsageTokens{
- InputTokens: result.Usage.InputTokens,
- OutputTokens: result.Usage.OutputTokens,
- CacheCreationTokens: result.Usage.CacheCreationInputTokens,
- CacheReadTokens: result.Usage.CacheReadInputTokens,
+ InputTokens: result.Usage.InputTokens,
+ OutputTokens: result.Usage.OutputTokens,
+ CacheCreationTokens: result.Usage.CacheCreationInputTokens,
+ CacheReadTokens: result.Usage.CacheReadInputTokens,
+ CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
+ CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
@@ -4658,6 +4796,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
+ CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
+ CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
@@ -4673,6 +4813,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
ImageCount: result.ImageCount,
ImageSize: imageSize,
MediaType: mediaType,
+ CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: time.Now(),
}
@@ -4773,6 +4914,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
result.Usage.InputTokens = 0
}
+ // Cache TTL Override: 确保计费时 token 分类与账号设置一致
+ cacheTTLOverridden := false
+ if account.IsCacheTTLOverrideEnabled() {
+ applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget())
+ cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
+ }
+
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
@@ -4803,10 +4951,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
} else {
// Token 计费(使用长上下文计费方法)
tokens := UsageTokens{
- InputTokens: result.Usage.InputTokens,
- OutputTokens: result.Usage.OutputTokens,
- CacheCreationTokens: result.Usage.CacheCreationInputTokens,
- CacheReadTokens: result.Usage.CacheReadInputTokens,
+ InputTokens: result.Usage.InputTokens,
+ OutputTokens: result.Usage.OutputTokens,
+ CacheCreationTokens: result.Usage.CacheCreationInputTokens,
+ CacheReadTokens: result.Usage.CacheReadInputTokens,
+ CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
+ CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
}
var err error
cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier)
@@ -4840,6 +4990,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
+ CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
+ CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
@@ -4854,6 +5006,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: imageSize,
+ CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: time.Now(),
}
@@ -5170,7 +5323,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
incomingBeta := req.Header.Get("anthropic-beta")
requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting}
- req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta))
+ drop := map[string]struct{}{claude.BetaContext1M: {}}
+ req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop))
} else {
clientBetaHeader := req.Header.Get("anthropic-beta")
if clientBetaHeader == "" {
@@ -5180,7 +5334,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if !strings.Contains(beta, claude.BetaTokenCounting) {
beta = beta + "," + claude.BetaTokenCounting
}
- req.Header.Set("anthropic-beta", beta)
+ req.Header.Set("anthropic-beta", stripBetaToken(beta, claude.BetaContext1M))
}
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go
index 50b998a3..cd690cbd 100644
--- a/backend/internal/service/gateway_streaming_test.go
+++ b/backend/internal/service/gateway_streaming_test.go
@@ -79,6 +79,22 @@ func TestParseSSEUsage_DeltaOverwritesWithNonZero(t *testing.T) {
require.Equal(t, 60, usage.CacheReadInputTokens)
}
+func TestParseSSEUsage_DeltaDoesNotResetCacheCreationBreakdown(t *testing.T) {
+ svc := newMinimalGatewayService()
+ usage := &ClaudeUsage{}
+
+ // 先在 message_start 中写入非零 5m/1h 明细
+ svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation":{"ephemeral_5m_input_tokens":30,"ephemeral_1h_input_tokens":70}}}}`, usage)
+ require.Equal(t, 30, usage.CacheCreation5mTokens)
+ require.Equal(t, 70, usage.CacheCreation1hTokens)
+
+ // 后续 delta 带默认 0,不应覆盖已有非零值
+ svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":12,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0}}}`, usage)
+ require.Equal(t, 30, usage.CacheCreation5mTokens, "delta 的 0 值不应重置 5m 明细")
+ require.Equal(t, 70, usage.CacheCreation1hTokens, "delta 的 0 值不应重置 1h 明细")
+ require.Equal(t, 12, usage.OutputTokens)
+}
+
func TestParseSSEUsage_InvalidJSON(t *testing.T) {
svc := newMinimalGatewayService()
usage := &ClaudeUsage{}
diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go
index 2d596f33..86bc9476 100644
--- a/backend/internal/service/gemini_multiplatform_test.go
+++ b/backend/internal/service/gemini_multiplatform_test.go
@@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
-func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
+func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go
index e247e654..6f6261d8 100644
--- a/backend/internal/service/oauth_service.go
+++ b/backend/internal/service/oauth_service.go
@@ -14,6 +14,7 @@ import (
type OpenAIOAuthClient interface {
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
+ RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error)
}
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go
index 5764788a..16befb82 100644
--- a/backend/internal/service/openai_codex_transform.go
+++ b/backend/internal/service/openai_codex_transform.go
@@ -99,13 +99,19 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran
result.Modified = true
}
- if _, ok := reqBody["max_output_tokens"]; ok {
- delete(reqBody, "max_output_tokens")
- result.Modified = true
- }
- if _, ok := reqBody["max_completion_tokens"]; ok {
- delete(reqBody, "max_completion_tokens")
- result.Modified = true
+ // Strip parameters unsupported by codex models via the Responses API.
+ for _, key := range []string{
+ "max_output_tokens",
+ "max_completion_tokens",
+ "temperature",
+ "top_p",
+ "frequency_penalty",
+ "presence_penalty",
+ } {
+ if _, ok := reqBody[key]; ok {
+ delete(reqBody, key)
+ result.Modified = true
+ }
}
if normalizeCodexTools(reqBody) {
diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go
index ca7470b9..087ad4ec 100644
--- a/backend/internal/service/openai_oauth_service.go
+++ b/backend/internal/service/openai_oauth_service.go
@@ -2,13 +2,20 @@ package service
import (
"context"
+ "crypto/subtle"
+ "encoding/json"
+ "io"
"net/http"
+ "net/url"
+ "strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
+var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
+
// OpenAIOAuthService handles OpenAI OAuth authentication flows
type OpenAIOAuthService struct {
sessionStore *openai.SessionStore
@@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
type OpenAIExchangeCodeInput struct {
SessionID string
Code string
+ State string
RedirectURI string
ProxyID *int64
}
@@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
if !ok {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired")
}
+ if input.State == "" {
+ return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required")
+ }
+ if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 {
+ return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state")
+ }
// Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL
proxyURL := session.ProxyURL
@@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch
// RefreshToken refreshes an OpenAI OAuth token
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
- tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
+ return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
+}
+
+// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id.
+func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) {
+ tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err != nil {
return nil, err
}
@@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
return tokenInfo, nil
}
-// RefreshAccountToken refreshes token for an OpenAI account
-func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
- if !account.IsOpenAI() {
- return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account")
+// ExchangeSoraSessionToken exchanges Sora session_token to access_token.
+func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) {
+ if strings.TrimSpace(sessionToken) == "" {
+ return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required")
}
- refreshToken := account.GetOpenAIRefreshToken()
+ proxyURL, err := s.resolveProxyURL(ctx, proxyID)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil)
+ if err != nil {
+ return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err)
+ }
+ req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken))
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Origin", "https://sora.chatgpt.com")
+ req.Header.Set("Referer", "https://sora.chatgpt.com/")
+ req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
+
+ client := newOpenAIOAuthHTTPClient(proxyURL)
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ if resp.StatusCode != http.StatusOK {
+ return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
+ }
+
+ var sessionResp struct {
+ AccessToken string `json:"accessToken"`
+ Expires string `json:"expires"`
+ User struct {
+ Email string `json:"email"`
+ Name string `json:"name"`
+ } `json:"user"`
+ }
+ if err := json.Unmarshal(body, &sessionResp); err != nil {
+ return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err)
+ }
+ if strings.TrimSpace(sessionResp.AccessToken) == "" {
+ return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token")
+ }
+
+ expiresAt := time.Now().Add(time.Hour).Unix()
+ if strings.TrimSpace(sessionResp.Expires) != "" {
+ if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil {
+ expiresAt = parsed.Unix()
+ }
+ }
+ expiresIn := expiresAt - time.Now().Unix()
+ if expiresIn < 0 {
+ expiresIn = 0
+ }
+
+ return &OpenAITokenInfo{
+ AccessToken: strings.TrimSpace(sessionResp.AccessToken),
+ ExpiresIn: expiresIn,
+ ExpiresAt: expiresAt,
+ Email: strings.TrimSpace(sessionResp.User.Email),
+ }, nil
+}
+
+// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account
+func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
+ if account.Platform != PlatformOpenAI && account.Platform != PlatformSora {
+ return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account")
+ }
+ if account.Type != AccountTypeOAuth {
+ return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account")
+ }
+
+ refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" {
return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available")
}
@@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
}
}
- return s.RefreshToken(ctx, refreshToken, proxyURL)
+ clientID := account.GetCredential("client_id")
+ return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
}
// BuildAccountCredentials builds credentials map from token info
@@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo)
func (s *OpenAIOAuthService) Stop() {
s.sessionStore.Stop()
}
+
+func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) {
+ if proxyID == nil {
+ return "", nil
+ }
+ proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
+ if err != nil {
+ return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err)
+ }
+ if proxy == nil {
+ return "", nil
+ }
+ return proxy.URL(), nil
+}
+
+func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
+ transport := &http.Transport{}
+ if strings.TrimSpace(proxyURL) != "" {
+ if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
+ transport.Proxy = http.ProxyURL(parsed)
+ }
+ }
+ return &http.Client{
+ Timeout: 120 * time.Second,
+ Transport: transport,
+ }
+}
diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go
new file mode 100644
index 00000000..fb76f6c1
--- /dev/null
+++ b/backend/internal/service/openai_oauth_service_sora_session_test.go
@@ -0,0 +1,69 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/stretchr/testify/require"
+)
+
+type openaiOAuthClientNoopStub struct{}
+
+func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+ require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token")
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`))
+ }))
+ defer server.Close()
+
+ origin := openAISoraSessionAuthURL
+ openAISoraSessionAuthURL = server.URL
+ defer func() { openAISoraSessionAuthURL = origin }()
+
+ svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
+ defer svc.Stop()
+
+ info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
+ require.NoError(t, err)
+ require.NotNil(t, info)
+ require.Equal(t, "at-token", info.AccessToken)
+ require.Equal(t, "demo@example.com", info.Email)
+ require.Greater(t, info.ExpiresAt, int64(0))
+}
+
+func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`))
+ }))
+ defer server.Close()
+
+ origin := openAISoraSessionAuthURL
+ openAISoraSessionAuthURL = server.URL
+ defer func() { openAISoraSessionAuthURL = origin }()
+
+ svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{})
+ defer svc.Stop()
+
+ _, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "missing access token")
+}
diff --git a/backend/internal/service/openai_oauth_service_state_test.go b/backend/internal/service/openai_oauth_service_state_test.go
new file mode 100644
index 00000000..0a2a195f
--- /dev/null
+++ b/backend/internal/service/openai_oauth_service_state_test.go
@@ -0,0 +1,102 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/stretchr/testify/require"
+)
+
+type openaiOAuthClientStateStub struct {
+ exchangeCalled int32
+}
+
+func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
+ atomic.AddInt32(&s.exchangeCalled, 1)
+ return &openai.TokenResponse{
+ AccessToken: "at",
+ RefreshToken: "rt",
+ ExpiresIn: 3600,
+ }, nil
+}
+
+func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
+ return s.RefreshToken(ctx, refreshToken, proxyURL)
+}
+
+func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) {
+ client := &openaiOAuthClientStateStub{}
+ svc := NewOpenAIOAuthService(nil, client)
+ defer svc.Stop()
+
+ svc.sessionStore.Set("sid", &openai.OAuthSession{
+ State: "expected-state",
+ CodeVerifier: "verifier",
+ RedirectURI: openai.DefaultRedirectURI,
+ CreatedAt: time.Now(),
+ })
+
+ _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
+ SessionID: "sid",
+ Code: "auth-code",
+ })
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "oauth state is required")
+ require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
+}
+
+func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) {
+ client := &openaiOAuthClientStateStub{}
+ svc := NewOpenAIOAuthService(nil, client)
+ defer svc.Stop()
+
+ svc.sessionStore.Set("sid", &openai.OAuthSession{
+ State: "expected-state",
+ CodeVerifier: "verifier",
+ RedirectURI: openai.DefaultRedirectURI,
+ CreatedAt: time.Now(),
+ })
+
+ _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
+ SessionID: "sid",
+ Code: "auth-code",
+ State: "wrong-state",
+ })
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "invalid oauth state")
+ require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
+}
+
+func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) {
+ client := &openaiOAuthClientStateStub{}
+ svc := NewOpenAIOAuthService(nil, client)
+ defer svc.Stop()
+
+ svc.sessionStore.Set("sid", &openai.OAuthSession{
+ State: "expected-state",
+ CodeVerifier: "verifier",
+ RedirectURI: openai.DefaultRedirectURI,
+ CreatedAt: time.Now(),
+ })
+
+ info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{
+ SessionID: "sid",
+ Code: "auth-code",
+ State: "expected-state",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, info)
+ require.Equal(t, "at", info.AccessToken)
+ require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled))
+
+ _, ok := svc.sessionStore.Get("sid")
+ require.False(t, ok)
+}
diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go
index 3842f0a4..a8a6b96c 100644
--- a/backend/internal/service/openai_token_provider.go
+++ b/backend/internal/service/openai_token_provider.go
@@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
- if p.openAIOAuthService == nil {
+ if account.Platform == PlatformSora {
+ slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID)
+ // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
+ refreshFailed = true
+ } else if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1)
refreshFailed = true // 无法刷新,标记失败
@@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
// 仅在 expires_at 已过期/接近过期时才执行无锁刷新
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
- if p.openAIOAuthService == nil {
+ if account.Platform == PlatformSora {
+ slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID)
+ // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。
+ refreshFailed = true
+ } else if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1)
refreshFailed = true
diff --git a/backend/internal/service/ops_concurrency.go b/backend/internal/service/ops_concurrency.go
index f6541d08..92b37e73 100644
--- a/backend/internal/service/ops_concurrency.go
+++ b/backend/internal/service/ops_concurrency.go
@@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
Page: page,
PageSize: opsAccountsPageSize,
- }, platformFilter, "", "", "")
+ }, platformFilter, "", "", "", 0)
if err != nil {
return nil, err
}
diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go
index ee979d84..41e8b5eb 100644
--- a/backend/internal/service/pricing_service.go
+++ b/backend/internal/service/pricing_service.go
@@ -28,14 +28,15 @@ var (
// LiteLLMModelPricing LiteLLM价格数据结构
// 只保留我们需要的字段,使用指针来处理可能缺失的值
type LiteLLMModelPricing struct {
- InputCostPerToken float64 `json:"input_cost_per_token"`
- OutputCostPerToken float64 `json:"output_cost_per_token"`
- CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
- CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
- LiteLLMProvider string `json:"litellm_provider"`
- Mode string `json:"mode"`
- SupportsPromptCaching bool `json:"supports_prompt_caching"`
- OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
+ InputCostPerToken float64 `json:"input_cost_per_token"`
+ OutputCostPerToken float64 `json:"output_cost_per_token"`
+ CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
+ CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"`
+ CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
+ LiteLLMProvider string `json:"litellm_provider"`
+ Mode string `json:"mode"`
+ SupportsPromptCaching bool `json:"supports_prompt_caching"`
+ OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
}
// PricingRemoteClient 远程价格数据获取接口
@@ -46,14 +47,15 @@ type PricingRemoteClient interface {
// LiteLLMRawEntry 用于解析原始JSON数据
type LiteLLMRawEntry struct {
- InputCostPerToken *float64 `json:"input_cost_per_token"`
- OutputCostPerToken *float64 `json:"output_cost_per_token"`
- CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
- CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
- LiteLLMProvider string `json:"litellm_provider"`
- Mode string `json:"mode"`
- SupportsPromptCaching bool `json:"supports_prompt_caching"`
- OutputCostPerImage *float64 `json:"output_cost_per_image"`
+ InputCostPerToken *float64 `json:"input_cost_per_token"`
+ OutputCostPerToken *float64 `json:"output_cost_per_token"`
+ CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
+ CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"`
+ CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
+ LiteLLMProvider string `json:"litellm_provider"`
+ Mode string `json:"mode"`
+ SupportsPromptCaching bool `json:"supports_prompt_caching"`
+ OutputCostPerImage *float64 `json:"output_cost_per_image"`
}
// PricingService 动态价格服务
@@ -319,6 +321,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
if entry.CacheCreationInputTokenCost != nil {
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
}
+ if entry.CacheCreationInputTokenCostAbove1hr != nil {
+ pricing.CacheCreationInputTokenCostAbove1hr = *entry.CacheCreationInputTokenCostAbove1hr
+ }
if entry.CacheReadInputTokenCost != nil {
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
}
diff --git a/backend/internal/service/proxy.go b/backend/internal/service/proxy.go
index 7eb7728f..fc449091 100644
--- a/backend/internal/service/proxy.go
+++ b/backend/internal/service/proxy.go
@@ -40,6 +40,11 @@ type ProxyWithAccountCount struct {
CountryCode string
Region string
City string
+ QualityStatus string
+ QualityScore *int
+ QualityGrade string
+ QualitySummary string
+ QualityChecked *int64
}
type ProxyAccountSummary struct {
diff --git a/backend/internal/service/proxy_latency_cache.go b/backend/internal/service/proxy_latency_cache.go
index 4a1cc77b..f54bff88 100644
--- a/backend/internal/service/proxy_latency_cache.go
+++ b/backend/internal/service/proxy_latency_cache.go
@@ -6,15 +6,21 @@ import (
)
type ProxyLatencyInfo struct {
- Success bool `json:"success"`
- LatencyMs *int64 `json:"latency_ms,omitempty"`
- Message string `json:"message,omitempty"`
- IPAddress string `json:"ip_address,omitempty"`
- Country string `json:"country,omitempty"`
- CountryCode string `json:"country_code,omitempty"`
- Region string `json:"region,omitempty"`
- City string `json:"city,omitempty"`
- UpdatedAt time.Time `json:"updated_at"`
+ Success bool `json:"success"`
+ LatencyMs *int64 `json:"latency_ms,omitempty"`
+ Message string `json:"message,omitempty"`
+ IPAddress string `json:"ip_address,omitempty"`
+ Country string `json:"country,omitempty"`
+ CountryCode string `json:"country_code,omitempty"`
+ Region string `json:"region,omitempty"`
+ City string `json:"city,omitempty"`
+ QualityStatus string `json:"quality_status,omitempty"`
+ QualityScore *int `json:"quality_score,omitempty"`
+ QualityGrade string `json:"quality_grade,omitempty"`
+ QualitySummary string `json:"quality_summary,omitempty"`
+ QualityCheckedAt *int64 `json:"quality_checked_at,omitempty"`
+ QualityCFRay string `json:"quality_cf_ray,omitempty"`
+ UpdatedAt time.Time `json:"updated_at"`
}
type ProxyLatencyCache interface {
diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go
index 12c48ab8..b1d767fc 100644
--- a/backend/internal/service/ratelimit_service.go
+++ b/backend/internal/service/ratelimit_service.go
@@ -381,10 +381,31 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
}
}
- // 2. 尝试从响应头解析重置时间(Anthropic)
+ // 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口
+ if result := calculateAnthropic429ResetTime(headers); result != nil {
+ if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
+ slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
+ return
+ }
+
+ // 更新 session window:优先使用 5h-reset 头精确计算,否则从 resetAt 反推
+ windowEnd := result.resetAt
+ if result.fiveHourReset != nil {
+ windowEnd = *result.fiveHourReset
+ }
+ windowStart := windowEnd.Add(-5 * time.Hour)
+ if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
+ slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
+ }
+
+ slog.Info("anthropic_account_rate_limited", "account_id", account.ID, "reset_at", result.resetAt, "reset_in", time.Until(result.resetAt).Truncate(time.Second))
+ return
+ }
+
+ // 3. 尝试从响应头解析重置时间(Anthropic 聚合头,向后兼容)
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
- // 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
+ // 4. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
if resetTimestamp == "" {
switch account.Platform {
case PlatformOpenAI:
@@ -497,6 +518,112 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim
return nil
}
+// anthropic429Result holds the parsed Anthropic 429 rate-limit information.
+type anthropic429Result struct {
+ resetAt time.Time // The correct reset time to use for SetRateLimited
+ fiveHourReset *time.Time // 5h window reset timestamp (for session window calculation), nil if not available
+}
+
+// calculateAnthropic429ResetTime parses Anthropic's per-window rate-limit headers
+// to determine which window (5h or 7d) actually triggered the 429.
+//
+// Headers used:
+// - anthropic-ratelimit-unified-5h-utilization / anthropic-ratelimit-unified-5h-surpassed-threshold
+// - anthropic-ratelimit-unified-5h-reset
+// - anthropic-ratelimit-unified-7d-utilization / anthropic-ratelimit-unified-7d-surpassed-threshold
+// - anthropic-ratelimit-unified-7d-reset
+//
+// Returns nil when the per-window headers are absent (caller should fall back to
+// the aggregated anthropic-ratelimit-unified-reset header).
+func calculateAnthropic429ResetTime(headers http.Header) *anthropic429Result {
+ reset5hStr := headers.Get("anthropic-ratelimit-unified-5h-reset")
+ reset7dStr := headers.Get("anthropic-ratelimit-unified-7d-reset")
+
+ if reset5hStr == "" && reset7dStr == "" {
+ return nil
+ }
+
+ var reset5h, reset7d *time.Time
+ if ts, err := strconv.ParseInt(reset5hStr, 10, 64); err == nil {
+ t := time.Unix(ts, 0)
+ reset5h = &t
+ }
+ if ts, err := strconv.ParseInt(reset7dStr, 10, 64); err == nil {
+ t := time.Unix(ts, 0)
+ reset7d = &t
+ }
+
+ is5hExceeded := isAnthropicWindowExceeded(headers, "5h")
+ is7dExceeded := isAnthropicWindowExceeded(headers, "7d")
+
+ slog.Info("anthropic_429_window_analysis",
+ "is_5h_exceeded", is5hExceeded,
+ "is_7d_exceeded", is7dExceeded,
+ "reset_5h", reset5hStr,
+ "reset_7d", reset7dStr,
+ )
+
+ // Select the correct reset time based on which window(s) are exceeded.
+ var chosen *time.Time
+ switch {
+ case is5hExceeded && is7dExceeded:
+ // Both exceeded → prefer 7d (longer cooldown), fall back to 5h
+ chosen = reset7d
+ if chosen == nil {
+ chosen = reset5h
+ }
+ case is5hExceeded:
+ chosen = reset5h
+ case is7dExceeded:
+ chosen = reset7d
+ default:
+ // Neither flag clearly exceeded — pick the sooner reset as best guess
+ chosen = pickSooner(reset5h, reset7d)
+ }
+
+ if chosen == nil {
+ return nil
+ }
+ return &anthropic429Result{resetAt: *chosen, fiveHourReset: reset5h}
+}
+
+// isAnthropicWindowExceeded checks whether a given Anthropic rate-limit window
+// (e.g. "5h" or "7d") has been exceeded, using utilization and surpassed-threshold headers.
+func isAnthropicWindowExceeded(headers http.Header, window string) bool {
+ prefix := "anthropic-ratelimit-unified-" + window + "-"
+
+ // Check surpassed-threshold first (most explicit signal)
+ if st := headers.Get(prefix + "surpassed-threshold"); strings.EqualFold(st, "true") {
+ return true
+ }
+
+ // Fall back to utilization >= 1.0
+ if utilStr := headers.Get(prefix + "utilization"); utilStr != "" {
+ if util, err := strconv.ParseFloat(utilStr, 64); err == nil && util >= 1.0-1e-9 {
+ // Use a small epsilon to handle floating point: treat 0.9999999... as >= 1.0
+ return true
+ }
+ }
+
+ return false
+}
+
+// pickSooner returns whichever of the two time pointers is earlier.
+// If only one is non-nil, it is returned. If both are nil, returns nil.
+func pickSooner(a, b *time.Time) *time.Time {
+ switch {
+ case a != nil && b != nil:
+ if a.Before(*b) {
+ return a
+ }
+ return b
+ case a != nil:
+ return a
+ default:
+ return b
+ }
+}
+
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
// OpenAI 的 usage_limit_reached 错误格式:
//
diff --git a/backend/internal/service/ratelimit_service_anthropic_test.go b/backend/internal/service/ratelimit_service_anthropic_test.go
new file mode 100644
index 00000000..eaeaf30e
--- /dev/null
+++ b/backend/internal/service/ratelimit_service_anthropic_test.go
@@ -0,0 +1,202 @@
+package service
+
+import (
+ "net/http"
+ "testing"
+ "time"
+)
+
+func TestCalculateAnthropic429ResetTime_Only5hExceeded(t *testing.T) {
+ headers := http.Header{}
+ headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.02")
+ headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
+ headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.32")
+ headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
+
+ result := calculateAnthropic429ResetTime(headers)
+ assertAnthropicResult(t, result, 1770998400)
+
+ if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
+ t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
+ }
+}
+
+func TestCalculateAnthropic429ResetTime_Only7dExceeded(t *testing.T) {
+ headers := http.Header{}
+ headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.50")
+ headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
+ headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.05")
+ headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
+
+ result := calculateAnthropic429ResetTime(headers)
+ assertAnthropicResult(t, result, 1771549200)
+
+ // fiveHourReset should still be populated for session window calculation
+ if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) {
+ t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset)
+ }
+}
+
+func TestCalculateAnthropic429ResetTime_BothExceeded(t *testing.T) {
+ headers := http.Header{}
+ headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.10")
+ headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
+ headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.02")
+ headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
+
+ result := calculateAnthropic429ResetTime(headers)
+ assertAnthropicResult(t, result, 1771549200)
+}
+
+func TestCalculateAnthropic429ResetTime_NoPerWindowHeaders(t *testing.T) {
+ headers := http.Header{}
+ headers.Set("anthropic-ratelimit-unified-reset", "1771549200")
+
+ result := calculateAnthropic429ResetTime(headers)
+ if result != nil {
+ t.Errorf("expected nil result when no per-window headers, got resetAt=%v", result.resetAt)
+ }
+}
+
+func TestCalculateAnthropic429ResetTime_NoHeaders(t *testing.T) {
+ result := calculateAnthropic429ResetTime(http.Header{})
+ if result != nil {
+ t.Errorf("expected nil result for empty headers, got resetAt=%v", result.resetAt)
+ }
+}
+
+func TestCalculateAnthropic429ResetTime_SurpassedThreshold(t *testing.T) {
+ headers := http.Header{}
+ headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true")
+ headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
+ headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "false")
+ headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
+
+ result := calculateAnthropic429ResetTime(headers)
+ assertAnthropicResult(t, result, 1770998400)
+}
+
+func TestCalculateAnthropic429ResetTime_UtilizationExactlyOne(t *testing.T) {
+ headers := http.Header{}
+ headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.0")
+ headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
+ headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.5")
+ headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
+
+ result := calculateAnthropic429ResetTime(headers)
+ assertAnthropicResult(t, result, 1770998400)
+}
+
+func TestCalculateAnthropic429ResetTime_NeitherExceeded_UsesShorter(t *testing.T) {
+ headers := http.Header{}
+ headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.95")
+ headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") // sooner
+ headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.80")
+ headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") // later
+
+ result := calculateAnthropic429ResetTime(headers)
+ assertAnthropicResult(t, result, 1770998400)
+}
+
+func TestCalculateAnthropic429ResetTime_Only5hResetHeader(t *testing.T) {
+ headers := http.Header{}
+ headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.05")
+ headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400")
+
+ result := calculateAnthropic429ResetTime(headers)
+ assertAnthropicResult(t, result, 1770998400)
+}
+
+func TestCalculateAnthropic429ResetTime_Only7dResetHeader(t *testing.T) {
+ headers := http.Header{}
+ headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.03")
+ headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200")
+
+ result := calculateAnthropic429ResetTime(headers)
+ assertAnthropicResult(t, result, 1771549200)
+
+ if result.fiveHourReset != nil {
+ t.Errorf("expected fiveHourReset=nil when no 5h headers, got %v", result.fiveHourReset)
+ }
+}
+
+func TestIsAnthropicWindowExceeded(t *testing.T) {
+ tests := []struct {
+ name string
+ headers http.Header
+ window string
+ expected bool
+ }{
+ {
+ name: "utilization above 1.0",
+ headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.02"),
+ window: "5h",
+ expected: true,
+ },
+ {
+ name: "utilization exactly 1.0",
+ headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.0"),
+ window: "5h",
+ expected: true,
+ },
+ {
+ name: "utilization below 1.0",
+ headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "0.99"),
+ window: "5h",
+ expected: false,
+ },
+ {
+ name: "surpassed-threshold true",
+ headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "true"),
+ window: "7d",
+ expected: true,
+ },
+ {
+ name: "surpassed-threshold True (case insensitive)",
+ headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "True"),
+ window: "7d",
+ expected: true,
+ },
+ {
+ name: "surpassed-threshold false",
+ headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "false"),
+ window: "7d",
+ expected: false,
+ },
+ {
+ name: "no headers",
+ headers: http.Header{},
+ window: "5h",
+ expected: false,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got := isAnthropicWindowExceeded(tc.headers, tc.window)
+ if got != tc.expected {
+ t.Errorf("expected %v, got %v", tc.expected, got)
+ }
+ })
+ }
+}
+
+// assertAnthropicResult is a test helper that verifies the result is non-nil and
+// has the expected resetAt unix timestamp.
+func assertAnthropicResult(t *testing.T, result *anthropic429Result, wantUnix int64) {
+ t.Helper()
+ if result == nil {
+ t.Fatal("expected non-nil result")
+ return // unreachable, but satisfies staticcheck SA5011
+ }
+ want := time.Unix(wantUnix, 0)
+ if !result.resetAt.Equal(want) {
+ t.Errorf("expected resetAt=%v, got %v", want, result.resetAt)
+ }
+}
+
+func makeHeader(key, value string) http.Header {
+ h := http.Header{}
+ h.Set(key, value)
+ return h
+}
diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go
index de097d5e..7cecfa03 100644
--- a/backend/internal/service/sora_client.go
+++ b/backend/internal/service/sora_client.go
@@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "hash/fnv"
"io"
"log"
"math/rand"
@@ -17,12 +18,16 @@ import (
"net/textproto"
"net/url"
"path"
+ "sort"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/Wei-Shaw/sub2api/internal/util/logredact"
+ "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/google/uuid"
"github.com/tidwall/gjson"
"golang.org/x/crypto/sha3"
@@ -34,6 +39,11 @@ const (
soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)"
)
+var (
+ soraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session"
+ soraOAuthTokenURL = "https://auth.openai.com/oauth/token"
+)
+
const (
soraPowMaxIteration = 500000
)
@@ -86,9 +96,20 @@ var soraDesktopUserAgents = []string{
"Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
}
+var soraMobileUserAgents = []string{
+ "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)",
+ "Sora/1.2026.007 (Android 14; SM-G998B; build 2600700)",
+ "Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)",
+ "Sora/1.2026.007 (Android 14; Pixel 7; build 2600700)",
+ "Sora/1.2026.007 (Android 15; 2211133C; build 2600700)",
+ "Sora/1.2026.007 (Android 14; SM-S918B; build 2600700)",
+ "Sora/1.2026.007 (Android 15; OnePlus 12; build 2600700)",
+}
+
var soraRand = rand.New(rand.NewSource(time.Now().UnixNano()))
var soraRandMu sync.Mutex
var soraPerfStart = time.Now()
+var soraPowTokenGenerator = soraGetPowToken
// SoraClient 定义直连 Sora 的任务操作接口。
type SoraClient interface {
@@ -96,6 +117,18 @@ type SoraClient interface {
UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error)
CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error)
CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error)
+ CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error)
+ UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error)
+ GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error)
+ DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error)
+ UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error)
+ FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error)
+ SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error
+ DeleteCharacter(ctx context.Context, account *Account, characterID string) error
+ PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error)
+ DeletePost(ctx context.Context, account *Account, postID string) error
+ GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error)
+ EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error)
GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error)
GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error)
}
@@ -117,6 +150,17 @@ type SoraVideoRequest struct {
Size string
MediaID string
RemixTargetID string
+ CameoIDs []string
+}
+
+// SoraStoryboardRequest 分镜视频生成请求参数
+type SoraStoryboardRequest struct {
+ Prompt string
+ Orientation string
+ Frames int
+ Model string
+ Size string
+ MediaID string
}
// SoraImageTaskStatus 图片任务状态
@@ -130,11 +174,32 @@ type SoraImageTaskStatus struct {
// SoraVideoTaskStatus 视频任务状态
type SoraVideoTaskStatus struct {
- ID string
- Status string
- ProgressPct int
- URLs []string
- ErrorMsg string
+ ID string
+ Status string
+ ProgressPct int
+ URLs []string
+ GenerationID string
+ ErrorMsg string
+}
+
+// SoraCameoStatus 角色处理中间态
+type SoraCameoStatus struct {
+ Status string
+ StatusMessage string
+ DisplayNameHint string
+ UsernameHint string
+ ProfileAssetURL string
+ InstructionSetHint any
+ InstructionSet any
+}
+
+// SoraCharacterFinalizeRequest 角色定稿请求参数
+type SoraCharacterFinalizeRequest struct {
+ CameoID string
+ Username string
+ DisplayName string
+ ProfileAssetPointer string
+ InstructionSet any
}
// SoraUpstreamError 上游错误
@@ -157,26 +222,110 @@ func (e *SoraUpstreamError) Error() string {
// SoraDirectClient 直连 Sora 实现
type SoraDirectClient struct {
- cfg *config.Config
- httpUpstream HTTPUpstream
- tokenProvider *OpenAITokenProvider
+ cfg *config.Config
+ httpUpstream HTTPUpstream
+ tokenProvider *OpenAITokenProvider
+ accountRepo AccountRepository
+ soraAccountRepo SoraAccountRepository
+ baseURL string
+ challengeCooldownMu sync.RWMutex
+ challengeCooldowns map[string]soraChallengeCooldownEntry
+ sidecarSessionMu sync.RWMutex
+ sidecarSessions map[string]soraSidecarSessionEntry
+}
+
+type soraRequestTraceContextKey struct{}
+
+type soraRequestTrace struct {
+ ID string
+ ProxyKey string
+ UAHash string
}
// NewSoraDirectClient 创建 Sora 直连客户端
func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient {
- return &SoraDirectClient{
- cfg: cfg,
- httpUpstream: httpUpstream,
- tokenProvider: tokenProvider,
+ baseURL := ""
+ if cfg != nil {
+ rawBaseURL := strings.TrimRight(strings.TrimSpace(cfg.Sora.Client.BaseURL), "/")
+ baseURL = normalizeSoraBaseURL(rawBaseURL)
+ if rawBaseURL != "" && baseURL != rawBaseURL {
+ log.Printf("[SoraClient] normalized base_url from %s to %s", sanitizeSoraLogURL(rawBaseURL), sanitizeSoraLogURL(baseURL))
+ }
}
+ return &SoraDirectClient{
+ cfg: cfg,
+ httpUpstream: httpUpstream,
+ tokenProvider: tokenProvider,
+ baseURL: baseURL,
+ challengeCooldowns: make(map[string]soraChallengeCooldownEntry),
+ sidecarSessions: make(map[string]soraSidecarSessionEntry),
+ }
+}
+
+func (c *SoraDirectClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) {
+ if c == nil {
+ return
+ }
+ c.accountRepo = accountRepo
+ c.soraAccountRepo = soraAccountRepo
}
// Enabled 判断是否启用 Sora 直连
func (c *SoraDirectClient) Enabled() bool {
- if c == nil || c.cfg == nil {
+ if c == nil {
return false
}
- return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != ""
+ if strings.TrimSpace(c.baseURL) != "" {
+ return true
+ }
+ if c.cfg == nil {
+ return false
+ }
+ return strings.TrimSpace(normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)) != ""
+}
+
+// PreflightCheck 在创建任务前执行账号能力预检。
+// 当前仅对视频模型执行 /nf/check 预检,用于提前识别额度耗尽或能力缺失。
+func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error {
+ if modelCfg.Type != "video" {
+ return nil
+ }
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Accept", "application/json")
+ body, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false)
+ if err != nil {
+ var upstreamErr *SoraUpstreamError
+ if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound {
+ return &SoraUpstreamError{
+ StatusCode: http.StatusForbidden,
+ Message: "当前账号未开通 Sora2 能力或无可用配额",
+ Headers: upstreamErr.Headers,
+ Body: upstreamErr.Body,
+ }
+ }
+ return err
+ }
+
+ rateLimitReached := gjson.GetBytes(body, "rate_limit_and_credit_balance.rate_limit_reached").Bool()
+ remaining := gjson.GetBytes(body, "rate_limit_and_credit_balance.estimated_num_videos_remaining")
+ if rateLimitReached || (remaining.Exists() && remaining.Int() <= 0) {
+ msg := "当前账号 Sora2 可用配额不足"
+ if requestedModel != "" {
+ msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel)
+ }
+ return &SoraUpstreamError{
+ StatusCode: http.StatusTooManyRequests,
+ Message: msg,
+ Headers: http.Header{},
+ }
+ }
+ return nil
}
func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) {
@@ -187,6 +336,8 @@ func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, da
if err != nil {
return "", err
}
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
if filename == "" {
filename = "image.png"
}
@@ -213,10 +364,10 @@ func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, da
return "", err
}
- headers := c.buildBaseHeaders(token, c.defaultUserAgent())
+ headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", writer.FormDataContentType())
- respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/uploads"), headers, &body, false)
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/uploads"), headers, &body, false)
if err != nil {
return "", err
}
@@ -232,6 +383,9 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account
if err != nil {
return "", err
}
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
operation := "simple_compose"
inpaintItems := []map[string]any{}
if strings.TrimSpace(req.MediaID) != "" {
@@ -252,7 +406,7 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account
"n_frames": 1,
"inpaint_items": inpaintItems,
}
- headers := c.buildBaseHeaders(token, c.defaultUserAgent())
+ headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
@@ -261,13 +415,13 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account
if err != nil {
return "", err
}
- sentinel, err := c.generateSentinelToken(ctx, account, token)
+ sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
- respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true)
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
@@ -283,6 +437,9 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
if err != nil {
return "", err
}
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
orientation := req.Orientation
if orientation == "" {
orientation = "landscape"
@@ -320,9 +477,12 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
payload["remix_target_id"] = req.RemixTargetID
payload["cameo_ids"] = []string{}
payload["cameo_replacements"] = map[string]any{}
+ } else if len(req.CameoIDs) > 0 {
+ payload["cameo_ids"] = req.CameoIDs
+ payload["cameo_replacements"] = map[string]any{}
}
- headers := c.buildBaseHeaders(token, c.defaultUserAgent())
+ headers := c.buildBaseHeaders(token, userAgent)
headers.Set("Content-Type", "application/json")
headers.Set("Origin", "https://sora.chatgpt.com")
headers.Set("Referer", "https://sora.chatgpt.com/")
@@ -330,13 +490,13 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
if err != nil {
return "", err
}
- sentinel, err := c.generateSentinelToken(ctx, account, token)
+ sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
if err != nil {
return "", err
}
headers.Set("openai-sentinel-token", sentinel)
- respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true)
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
@@ -347,6 +507,469 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
return taskID, nil
}
+func (c *SoraDirectClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return "", err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
+ orientation := req.Orientation
+ if orientation == "" {
+ orientation = "landscape"
+ }
+ nFrames := req.Frames
+ if nFrames <= 0 {
+ nFrames = 450
+ }
+ model := req.Model
+ if model == "" {
+ model = "sy_8"
+ }
+ size := req.Size
+ if size == "" {
+ size = "small"
+ }
+
+ inpaintItems := []map[string]any{}
+ if strings.TrimSpace(req.MediaID) != "" {
+ inpaintItems = append(inpaintItems, map[string]any{
+ "kind": "upload",
+ "upload_id": req.MediaID,
+ })
+ }
+ payload := map[string]any{
+ "kind": "video",
+ "prompt": req.Prompt,
+ "title": "Draft your video",
+ "orientation": orientation,
+ "size": size,
+ "n_frames": nFrames,
+ "storyboard_id": nil,
+ "inpaint_items": inpaintItems,
+ "remix_target_id": nil,
+ "model": model,
+ "metadata": nil,
+ "style_id": nil,
+ "cameo_ids": nil,
+ "cameo_replacements": nil,
+ "audio_caption": nil,
+ "audio_transcript": nil,
+ "video_caption": nil,
+ }
+
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Content-Type", "application/json")
+ headers.Set("Origin", "https://sora.chatgpt.com")
+ headers.Set("Referer", "https://sora.chatgpt.com/")
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return "", err
+ }
+ sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
+ if err != nil {
+ return "", err
+ }
+ headers.Set("openai-sentinel-token", sentinel)
+
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create/storyboard"), headers, bytes.NewReader(body), true)
+ if err != nil {
+ return "", err
+ }
+ taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
+ if taskID == "" {
+ return "", errors.New("storyboard task response missing id")
+ }
+ return taskID, nil
+}
+
+func (c *SoraDirectClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
+ if len(data) == 0 {
+ return "", errors.New("empty video data")
+ }
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return "", err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+
+ var body bytes.Buffer
+ writer := multipart.NewWriter(&body)
+ partHeader := make(textproto.MIMEHeader)
+ partHeader.Set("Content-Disposition", `form-data; name="file"; filename="video.mp4"`)
+ partHeader.Set("Content-Type", "video/mp4")
+ part, err := writer.CreatePart(partHeader)
+ if err != nil {
+ return "", err
+ }
+ if _, err := part.Write(data); err != nil {
+ return "", err
+ }
+ if err := writer.WriteField("timestamps", "0,3"); err != nil {
+ return "", err
+ }
+ if err := writer.Close(); err != nil {
+ return "", err
+ }
+
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Content-Type", writer.FormDataContentType())
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/upload"), headers, &body, false)
+ if err != nil {
+ return "", err
+ }
+ cameoID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
+ if cameoID == "" {
+ return "", errors.New("character upload response missing id")
+ }
+ return cameoID, nil
+}
+
+func (c *SoraDirectClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return nil, err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ headers := c.buildBaseHeaders(token, userAgent)
+ respBody, _, err := c.doRequestWithProxy(
+ ctx,
+ account,
+ proxyURL,
+ http.MethodGet,
+ c.buildURL("/project_y/cameos/in_progress/"+strings.TrimSpace(cameoID)),
+ headers,
+ nil,
+ false,
+ )
+ if err != nil {
+ return nil, err
+ }
+ return &SoraCameoStatus{
+ Status: strings.TrimSpace(gjson.GetBytes(respBody, "status").String()),
+ StatusMessage: strings.TrimSpace(gjson.GetBytes(respBody, "status_message").String()),
+ DisplayNameHint: strings.TrimSpace(gjson.GetBytes(respBody, "display_name_hint").String()),
+ UsernameHint: strings.TrimSpace(gjson.GetBytes(respBody, "username_hint").String()),
+ ProfileAssetURL: strings.TrimSpace(gjson.GetBytes(respBody, "profile_asset_url").String()),
+ InstructionSetHint: gjson.GetBytes(respBody, "instruction_set_hint").Value(),
+ InstructionSet: gjson.GetBytes(respBody, "instruction_set").Value(),
+ }, nil
+}
+
+func (c *SoraDirectClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return nil, err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Accept", "image/*,*/*;q=0.8")
+
+ respBody, _, err := c.doRequestWithProxy(
+ ctx,
+ account,
+ proxyURL,
+ http.MethodGet,
+ strings.TrimSpace(imageURL),
+ headers,
+ nil,
+ false,
+ )
+ if err != nil {
+ return nil, err
+ }
+ return respBody, nil
+}
+
+func (c *SoraDirectClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
+ if len(data) == 0 {
+ return "", errors.New("empty character image")
+ }
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return "", err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+
+ var body bytes.Buffer
+ writer := multipart.NewWriter(&body)
+ partHeader := make(textproto.MIMEHeader)
+ partHeader.Set("Content-Disposition", `form-data; name="file"; filename="profile.webp"`)
+ partHeader.Set("Content-Type", "image/webp")
+ part, err := writer.CreatePart(partHeader)
+ if err != nil {
+ return "", err
+ }
+ if _, err := part.Write(data); err != nil {
+ return "", err
+ }
+ if err := writer.WriteField("use_case", "profile"); err != nil {
+ return "", err
+ }
+ if err := writer.Close(); err != nil {
+ return "", err
+ }
+
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Content-Type", writer.FormDataContentType())
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/file/upload"), headers, &body, false)
+ if err != nil {
+ return "", err
+ }
+ assetPointer := strings.TrimSpace(gjson.GetBytes(respBody, "asset_pointer").String())
+ if assetPointer == "" {
+ return "", errors.New("character image upload response missing asset_pointer")
+ }
+ return assetPointer, nil
+}
+
+func (c *SoraDirectClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return "", err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
+ payload := map[string]any{
+ "cameo_id": req.CameoID,
+ "username": req.Username,
+ "display_name": req.DisplayName,
+ "profile_asset_pointer": req.ProfileAssetPointer,
+ "instruction_set": nil,
+ "safety_instruction_set": nil,
+ }
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return "", err
+ }
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Content-Type", "application/json")
+ headers.Set("Origin", "https://sora.chatgpt.com")
+ headers.Set("Referer", "https://sora.chatgpt.com/")
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/finalize"), headers, bytes.NewReader(body), false)
+ if err != nil {
+ return "", err
+ }
+ characterID := strings.TrimSpace(gjson.GetBytes(respBody, "character.character_id").String())
+ if characterID == "" {
+ return "", errors.New("character finalize response missing character_id")
+ }
+ return characterID, nil
+}
+
+func (c *SoraDirectClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ payload := map[string]any{"visibility": "public"}
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return err
+ }
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Content-Type", "application/json")
+ headers.Set("Origin", "https://sora.chatgpt.com")
+ headers.Set("Referer", "https://sora.chatgpt.com/")
+ _, _, err = c.doRequestWithProxy(
+ ctx,
+ account,
+ proxyURL,
+ http.MethodPost,
+ c.buildURL("/project_y/cameos/by_id/"+strings.TrimSpace(cameoID)+"/update_v2"),
+ headers,
+ bytes.NewReader(body),
+ false,
+ )
+ return err
+}
+
+func (c *SoraDirectClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ headers := c.buildBaseHeaders(token, userAgent)
+ _, _, err = c.doRequestWithProxy(
+ ctx,
+ account,
+ proxyURL,
+ http.MethodDelete,
+ c.buildURL("/project_y/characters/"+strings.TrimSpace(characterID)),
+ headers,
+ nil,
+ false,
+ )
+ return err
+}
+
+func (c *SoraDirectClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return "", err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent)
+ payload := map[string]any{
+ "attachments_to_create": []map[string]any{
+ {
+ "generation_id": generationID,
+ "kind": "sora",
+ },
+ },
+ "post_text": "",
+ }
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return "", err
+ }
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Content-Type", "application/json")
+ headers.Set("Origin", "https://sora.chatgpt.com")
+ headers.Set("Referer", "https://sora.chatgpt.com/")
+ sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL)
+ if err != nil {
+ return "", err
+ }
+ headers.Set("openai-sentinel-token", sentinel)
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/post"), headers, bytes.NewReader(body), true)
+ if err != nil {
+ return "", err
+ }
+ postID := strings.TrimSpace(gjson.GetBytes(respBody, "post.id").String())
+ if postID == "" {
+ return "", errors.New("watermark-free publish response missing post.id")
+ }
+ return postID, nil
+}
+
+func (c *SoraDirectClient) DeletePost(ctx context.Context, account *Account, postID string) error {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ headers := c.buildBaseHeaders(token, userAgent)
+ _, _, err = c.doRequestWithProxy(
+ ctx,
+ account,
+ proxyURL,
+ http.MethodDelete,
+ c.buildURL("/project_y/post/"+strings.TrimSpace(postID)),
+ headers,
+ nil,
+ false,
+ )
+ return err
+}
+
+func (c *SoraDirectClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
+ parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/")
+ if parseURL == "" {
+ return "", errors.New("custom parse url is required")
+ }
+ if strings.TrimSpace(parseToken) == "" {
+ return "", errors.New("custom parse token is required")
+ }
+ shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID)
+ payload := map[string]any{
+ "url": shareURL,
+ "token": strings.TrimSpace(parseToken),
+ }
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return "", err
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body))
+ if err != nil {
+ return "", err
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ proxyURL := c.resolveProxyURL(account)
+ accountID := int64(0)
+ accountConcurrency := 0
+ if account != nil {
+ accountID = account.ID
+ accountConcurrency = account.Concurrency
+ }
+ var resp *http.Response
+ if c.httpUpstream != nil {
+ resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency)
+ } else {
+ resp, err = http.DefaultClient.Do(req)
+ }
+ if err != nil {
+ return "", err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
+ if err != nil {
+ return "", err
+ }
+ if resp.StatusCode != http.StatusOK {
+ return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256))
+ }
+ downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String())
+ if downloadLink == "" {
+ return "", errors.New("custom parse response missing download_link")
+ }
+ return downloadLink, nil
+}
+
+func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
+ token, err := c.getAccessToken(ctx, account)
+ if err != nil {
+ return "", err
+ }
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ if strings.TrimSpace(expansionLevel) == "" {
+ expansionLevel = "medium"
+ }
+ if durationS <= 0 {
+ durationS = 10
+ }
+
+ payload := map[string]any{
+ "prompt": prompt,
+ "expansion_level": expansionLevel,
+ "duration_s": durationS,
+ }
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return "", err
+ }
+
+ headers := c.buildBaseHeaders(token, userAgent)
+ headers.Set("Content-Type", "application/json")
+ headers.Set("Accept", "application/json")
+ headers.Set("Origin", "https://sora.chatgpt.com")
+ headers.Set("Referer", "https://sora.chatgpt.com/")
+
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false)
+ if err != nil {
+ return "", err
+ }
+ enhancedPrompt := strings.TrimSpace(gjson.GetBytes(respBody, "enhanced_prompt").String())
+ if enhancedPrompt == "" {
+ return "", errors.New("enhance_prompt response missing enhanced_prompt")
+ }
+ return enhancedPrompt, nil
+}
+
func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit())
if err != nil {
@@ -373,12 +996,14 @@ func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Ac
if err != nil {
return nil, false, err
}
- headers := c.buildBaseHeaders(token, c.defaultUserAgent())
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ headers := c.buildBaseHeaders(token, userAgent)
if limit <= 0 {
limit = 20
}
endpoint := fmt.Sprintf("/v2/recent_tasks?limit=%d", limit)
- respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL(endpoint), headers, nil, false)
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL(endpoint), headers, nil, false)
if err != nil {
return nil, false, err
}
@@ -435,9 +1060,11 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
if err != nil {
return nil, err
}
- headers := c.buildBaseHeaders(token, c.defaultUserAgent())
+ userAgent := c.taskUserAgent()
+ proxyURL := c.resolveProxyURL(account)
+ headers := c.buildBaseHeaders(token, userAgent)
- respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false)
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false)
if err != nil {
return nil, err
}
@@ -466,7 +1093,7 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
}
}
- respBody, _, err = c.doRequest(ctx, account, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false)
+ respBody, _, err = c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false)
if err != nil {
return nil, err
}
@@ -475,6 +1102,7 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
if draft.Get("task_id").String() != taskID {
return true
}
+ generationID := strings.TrimSpace(draft.Get("id").String())
kind := strings.TrimSpace(draft.Get("kind").String())
reason := strings.TrimSpace(draft.Get("reason_str").String())
if reason == "" {
@@ -491,15 +1119,17 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
msg = "Content violates guardrails"
}
draftFound = &SoraVideoTaskStatus{
- ID: taskID,
- Status: "failed",
- ErrorMsg: msg,
+ ID: taskID,
+ Status: "failed",
+ GenerationID: generationID,
+ ErrorMsg: msg,
}
} else {
draftFound = &SoraVideoTaskStatus{
- ID: taskID,
- Status: "completed",
- URLs: []string{urlStr},
+ ID: taskID,
+ Status: "completed",
+ GenerationID: generationID,
+ URLs: []string{urlStr},
}
}
return false
@@ -512,9 +1142,10 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
}
func (c *SoraDirectClient) buildURL(endpoint string) string {
- base := ""
- if c != nil && c.cfg != nil {
- base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/")
+ base := strings.TrimRight(strings.TrimSpace(c.baseURL), "/")
+ if base == "" && c != nil && c.cfg != nil {
+ base = normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)
+ c.baseURL = base
}
if base == "" {
return endpoint
@@ -536,18 +1167,278 @@ func (c *SoraDirectClient) defaultUserAgent() string {
return ua
}
+func (c *SoraDirectClient) taskUserAgent() string {
+ if c != nil && c.cfg != nil {
+ if ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent); ua != "" {
+ return ua
+ }
+ }
+ if len(soraMobileUserAgents) > 0 {
+ return soraMobileUserAgents[soraRandInt(len(soraMobileUserAgents))]
+ }
+ if len(soraDesktopUserAgents) > 0 {
+ return soraDesktopUserAgents[soraRandInt(len(soraDesktopUserAgents))]
+ }
+ return soraDefaultUserAgent
+}
+
+func (c *SoraDirectClient) resolveProxyURL(account *Account) string {
+ if account == nil || account.ProxyID == nil || account.Proxy == nil {
+ return ""
+ }
+ return strings.TrimSpace(account.Proxy.URL())
+}
+
func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
- if c.tokenProvider != nil {
- return c.tokenProvider.GetAccessToken(ctx, account)
+
+ allowProvider := c.allowOpenAITokenProvider(account)
+ var providerErr error
+ if allowProvider && c.tokenProvider != nil {
+ token, err := c.tokenProvider.GetAccessToken(ctx, account)
+ if err == nil && strings.TrimSpace(token) != "" {
+ c.logTokenSource(account, "openai_token_provider")
+ return token, nil
+ }
+ providerErr = err
+ if err != nil && c.debugEnabled() {
+ c.debugLogf(
+ "token_provider_failed account_id=%d platform=%s err=%s",
+ account.ID,
+ account.Platform,
+ logredact.RedactText(err.Error()),
+ )
+ }
}
token := strings.TrimSpace(account.GetCredential("access_token"))
- if token == "" {
- return "", errors.New("access_token not found")
+ if token != "" {
+ expiresAt := account.GetCredentialAsTime("expires_at")
+ if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute {
+ refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring")
+ if refreshErr == nil && strings.TrimSpace(refreshed) != "" {
+ c.logTokenSource(account, "refresh_token_recovered")
+ return refreshed, nil
+ }
+ if refreshErr != nil && c.debugEnabled() {
+ c.debugLogf("token_refresh_before_use_failed account_id=%d err=%s", account.ID, logredact.RedactText(refreshErr.Error()))
+ }
+ }
+ c.logTokenSource(account, "account_credentials")
+ return token, nil
}
- return token, nil
+
+ recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing")
+ if recoverErr == nil && strings.TrimSpace(recovered) != "" {
+ c.logTokenSource(account, "session_or_refresh_recovered")
+ return recovered, nil
+ }
+ if recoverErr != nil && c.debugEnabled() {
+ c.debugLogf("token_recover_failed account_id=%d platform=%s err=%s", account.ID, account.Platform, logredact.RedactText(recoverErr.Error()))
+ }
+ if providerErr != nil {
+ return "", providerErr
+ }
+ if c.tokenProvider != nil && !allowProvider {
+ c.logTokenSource(account, "account_credentials(provider_disabled)")
+ }
+ return "", errors.New("access_token not found")
+}
+
+func (c *SoraDirectClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) {
+ if account == nil {
+ return "", errors.New("account is nil")
+ }
+
+ if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" {
+ accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken)
+ if err == nil && strings.TrimSpace(accessToken) != "" {
+ c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken)
+ c.logTokenRecover(account, "session_token", reason, true, nil)
+ return accessToken, nil
+ }
+ c.logTokenRecover(account, "session_token", reason, false, err)
+ }
+
+ refreshToken := strings.TrimSpace(account.GetCredential("refresh_token"))
+ if refreshToken == "" {
+ return "", errors.New("session_token/refresh_token not found")
+ }
+ accessToken, newRefreshToken, expiresAt, err := c.exchangeRefreshToken(ctx, account, refreshToken)
+ if err != nil {
+ c.logTokenRecover(account, "refresh_token", reason, false, err)
+ return "", err
+ }
+ if strings.TrimSpace(accessToken) == "" {
+ return "", errors.New("refreshed access_token is empty")
+ }
+ c.applyRecoveredToken(ctx, account, accessToken, newRefreshToken, expiresAt, "")
+ c.logTokenRecover(account, "refresh_token", reason, true, nil)
+ return accessToken, nil
+}
+
+func (c *SoraDirectClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) {
+ headers := http.Header{}
+ headers.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken)
+ headers.Set("Accept", "application/json")
+ headers.Set("Origin", "https://sora.chatgpt.com")
+ headers.Set("Referer", "https://sora.chatgpt.com/")
+ headers.Set("User-Agent", c.defaultUserAgent())
+ body, _, err := c.doRequest(ctx, account, http.MethodGet, soraSessionAuthURL, headers, nil, false)
+ if err != nil {
+ return "", "", err
+ }
+ accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String())
+ if accessToken == "" {
+ return "", "", errors.New("session exchange missing accessToken")
+ }
+ expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String())
+ return accessToken, expiresAt, nil
+}
+
+func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Account, refreshToken string) (string, string, string, error) {
+ clientIDs := []string{
+ strings.TrimSpace(account.GetCredential("client_id")),
+ openaioauth.SoraClientID,
+ openaioauth.ClientID,
+ }
+ tried := make(map[string]struct{}, len(clientIDs))
+ var lastErr error
+
+ for _, clientID := range clientIDs {
+ if clientID == "" {
+ continue
+ }
+ if _, ok := tried[clientID]; ok {
+ continue
+ }
+ tried[clientID] = struct{}{}
+
+ formData := url.Values{}
+ formData.Set("client_id", clientID)
+ formData.Set("grant_type", "refresh_token")
+ formData.Set("refresh_token", refreshToken)
+ formData.Set("redirect_uri", "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback")
+ headers := http.Header{}
+ headers.Set("Accept", "application/json")
+ headers.Set("Content-Type", "application/x-www-form-urlencoded")
+ headers.Set("User-Agent", c.defaultUserAgent())
+
+ respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, strings.NewReader(formData.Encode()), false)
+ if err != nil {
+ lastErr = err
+ if c.debugEnabled() {
+ c.debugLogf("refresh_token_exchange_failed account_id=%d client_id=%s err=%s", account.ID, clientID, logredact.RedactText(err.Error()))
+ }
+ continue
+ }
+ accessToken := strings.TrimSpace(gjson.GetBytes(respBody, "access_token").String())
+ if accessToken == "" {
+ lastErr = errors.New("oauth refresh response missing access_token")
+ continue
+ }
+ newRefreshToken := strings.TrimSpace(gjson.GetBytes(respBody, "refresh_token").String())
+ expiresIn := gjson.GetBytes(respBody, "expires_in").Int()
+ expiresAt := ""
+ if expiresIn > 0 {
+ expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339)
+ }
+ return accessToken, newRefreshToken, expiresAt, nil
+ }
+
+ if lastErr != nil {
+ return "", "", "", lastErr
+ }
+ return "", "", "", errors.New("no available client_id for refresh_token exchange")
+}
+
+func (c *SoraDirectClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) {
+ if account == nil {
+ return
+ }
+ if account.Credentials == nil {
+ account.Credentials = make(map[string]any)
+ }
+ if strings.TrimSpace(accessToken) != "" {
+ account.Credentials["access_token"] = accessToken
+ }
+ if strings.TrimSpace(refreshToken) != "" {
+ account.Credentials["refresh_token"] = refreshToken
+ }
+ if strings.TrimSpace(expiresAt) != "" {
+ account.Credentials["expires_at"] = expiresAt
+ }
+ if strings.TrimSpace(sessionToken) != "" {
+ account.Credentials["session_token"] = sessionToken
+ }
+
+ if c.accountRepo != nil {
+ if err := c.accountRepo.Update(ctx, account); err != nil {
+ if c.debugEnabled() {
+ c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
+ }
+ }
+ }
+ c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken)
+}
+
+func (c *SoraDirectClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) {
+ if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 {
+ return
+ }
+ updates := make(map[string]any)
+ if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" {
+ updates["access_token"] = accessToken
+ updates["refresh_token"] = refreshToken
+ }
+ if strings.TrimSpace(sessionToken) != "" {
+ updates["session_token"] = sessionToken
+ }
+ if len(updates) == 0 {
+ return
+ }
+ if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() {
+ c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error()))
+ }
+}
+
+func (c *SoraDirectClient) logTokenRecover(account *Account, source, reason string, success bool, err error) {
+ if !c.debugEnabled() || account == nil {
+ return
+ }
+ if success {
+ c.debugLogf("token_recover_success account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
+ return
+ }
+ if err == nil {
+ c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason)
+ return
+ }
+ c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s err=%s", account.ID, account.Platform, source, reason, logredact.RedactText(err.Error()))
+}
+
+func (c *SoraDirectClient) allowOpenAITokenProvider(account *Account) bool {
+ if c == nil || c.tokenProvider == nil {
+ return false
+ }
+ if account != nil && account.Platform == PlatformSora {
+ return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider
+ }
+ return true
+}
+
+func (c *SoraDirectClient) logTokenSource(account *Account, source string) {
+ if !c.debugEnabled() || account == nil {
+ return
+ }
+ c.debugLogf(
+ "token_selected account_id=%d platform=%s account_type=%s source=%s",
+ account.ID,
+ account.Platform,
+ account.Type,
+ source,
+ )
}
func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header {
@@ -570,9 +1461,30 @@ func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header
}
func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, method, urlStr string, headers http.Header, body io.Reader, allowRetry bool) ([]byte, http.Header, error) {
+ return c.doRequestWithProxy(ctx, account, c.resolveProxyURL(account), method, urlStr, headers, body, allowRetry)
+}
+
+func (c *SoraDirectClient) doRequestWithProxy(
+ ctx context.Context,
+ account *Account,
+ proxyURL string,
+ method,
+ urlStr string,
+ headers http.Header,
+ body io.Reader,
+ allowRetry bool,
+) ([]byte, http.Header, error) {
if strings.TrimSpace(urlStr) == "" {
return nil, nil, errors.New("empty upstream url")
}
+ proxyURL = strings.TrimSpace(proxyURL)
+ if proxyURL == "" {
+ proxyURL = c.resolveProxyURL(account)
+ }
+ if cooldownErr := c.checkCloudflareChallengeCooldown(account, proxyURL); cooldownErr != nil {
+ return nil, nil, cooldownErr
+ }
+ traceID, traceProxyKey, traceUAHash := c.requestTraceFields(ctx, proxyURL, headers.Get("User-Agent"))
timeout := 0
if c != nil && c.cfg != nil {
timeout = c.cfg.Sora.Client.TimeoutSeconds
@@ -600,7 +1512,29 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
}
attempts := maxRetries + 1
+ authRecovered := false
+ authRecoverExtraAttemptGranted := false
+ challengeRetried := false
+ sawCFChallenge := false
+ var lastErr error
for attempt := 1; attempt <= attempts; attempt++ {
+ if c.debugEnabled() {
+ c.debugLogf(
+ "request_start trace_id=%s method=%s url=%s attempt=%d/%d timeout_s=%d body_bytes=%d proxy_bound=%t proxy_key=%s ua_hash=%s headers=%s",
+ traceID,
+ method,
+ sanitizeSoraLogURL(urlStr),
+ attempt,
+ attempts,
+ timeout,
+ len(bodyBytes),
+ proxyURL != "",
+ traceProxyKey,
+ traceUAHash,
+ formatSoraHeaders(headers),
+ )
+ }
+
var reader io.Reader
if bodyBytes != nil {
reader = bytes.NewReader(bodyBytes)
@@ -612,13 +1546,24 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
req.Header = headers.Clone()
start := time.Now()
- proxyURL := ""
- if account != nil && account.ProxyID != nil && account.Proxy != nil {
- proxyURL = account.Proxy.URL()
- }
resp, err := c.doHTTP(req, proxyURL, account)
if err != nil {
+ lastErr = err
+ if c.debugEnabled() {
+ c.debugLogf(
+ "request_transport_error trace_id=%s method=%s url=%s attempt=%d/%d err=%s",
+ traceID,
+ method,
+ sanitizeSoraLogURL(urlStr),
+ attempt,
+ attempts,
+ logredact.RedactText(err.Error()),
+ )
+ }
if attempt < attempts && allowRetry {
+ if c.debugEnabled() {
+ c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=transport_error next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), attempt+1, attempts)
+ }
c.sleepRetry(attempt)
continue
}
@@ -632,24 +1577,119 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth
}
if c.cfg != nil && c.cfg.Sora.Client.Debug {
- log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start))
+ c.debugLogf(
+ "response_received trace_id=%s method=%s url=%s attempt=%d/%d status=%d cost=%s resp_bytes=%d resp_headers=%s",
+ traceID,
+ method,
+ sanitizeSoraLogURL(urlStr),
+ attempt,
+ attempts,
+ resp.StatusCode,
+ time.Since(start),
+ len(respBody),
+ formatSoraHeaders(resp.Header),
+ )
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
- upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody)
+ isCFChallenge := soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, respBody)
+ if isCFChallenge {
+ sawCFChallenge = true
+ c.recordCloudflareChallengeCooldown(account, proxyURL, resp.StatusCode, resp.Header, respBody)
+ if allowRetry && attempt < attempts && !challengeRetried {
+ challengeRetried = true
+ if c.debugEnabled() {
+ c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=cloudflare_challenge status=%d next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts)
+ }
+ c.sleepRetry(attempt)
+ continue
+ }
+ }
+ if !isCFChallenge && !authRecovered && shouldAttemptSoraTokenRecover(resp.StatusCode, urlStr) && account != nil {
+ if recovered, recoverErr := c.recoverAccessToken(ctx, account, fmt.Sprintf("upstream_status_%d", resp.StatusCode)); recoverErr == nil && strings.TrimSpace(recovered) != "" {
+ headers.Set("Authorization", "Bearer "+recovered)
+ authRecovered = true
+ if attempt == attempts && !authRecoverExtraAttemptGranted {
+ attempts++
+ authRecoverExtraAttemptGranted = true
+ }
+ if c.debugEnabled() {
+ c.debugLogf("request_retry_with_recovered_token trace_id=%s method=%s url=%s status=%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode)
+ }
+ continue
+ } else if recoverErr != nil && c.debugEnabled() {
+ c.debugLogf("request_recover_token_failed trace_id=%s method=%s url=%s status=%d err=%s", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, logredact.RedactText(recoverErr.Error()))
+ }
+ }
+ if c.debugEnabled() {
+ c.debugLogf(
+ "response_non_success trace_id=%s method=%s url=%s attempt=%d/%d status=%d body=%s",
+ traceID,
+ method,
+ sanitizeSoraLogURL(urlStr),
+ attempt,
+ attempts,
+ resp.StatusCode,
+ summarizeSoraResponseBody(respBody, 512),
+ )
+ }
+ upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody, urlStr)
+ lastErr = upstreamErr
+ if isCFChallenge {
+ return nil, resp.Header, upstreamErr
+ }
if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) {
+ if c.debugEnabled() {
+ c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=status_%d next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts)
+ }
c.sleepRetry(attempt)
continue
}
return nil, resp.Header, upstreamErr
}
+ if sawCFChallenge {
+ c.clearCloudflareChallengeCooldown(account, proxyURL)
+ }
return respBody, resp.Header, nil
}
+ if lastErr != nil {
+ return nil, nil, lastErr
+ }
return nil, nil, errors.New("upstream retries exhausted")
}
+func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool {
+ switch statusCode {
+ case http.StatusUnauthorized, http.StatusForbidden:
+ parsed, err := url.Parse(strings.TrimSpace(rawURL))
+ if err != nil {
+ return false
+ }
+ host := strings.ToLower(parsed.Hostname())
+ if host != "sora.chatgpt.com" && host != "chatgpt.com" {
+ return false
+ }
+ // 避免在 ST->AT 转换接口上递归触发 token 恢复导致死循环。
+ path := strings.ToLower(strings.TrimSpace(parsed.Path))
+ if path == "/api/auth/session" {
+ return false
+ }
+ return true
+ default:
+ return false
+ }
+}
+
func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
- enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint
+ if c != nil && c.cfg != nil && c.cfg.Sora.Client.CurlCFFISidecar.Enabled {
+ resp, err := c.doHTTPViaCurlCFFISidecar(req, proxyURL, account)
+ if err != nil {
+ return nil, err
+ }
+ return resp, nil
+ }
+
+ enableTLS := c == nil || c.cfg == nil || !c.cfg.Sora.Client.DisableTLSFingerprint
if c.httpUpstream != nil {
accountID := int64(0)
accountConcurrency := 0
@@ -670,9 +1710,14 @@ func (c *SoraDirectClient) sleepRetry(attempt int) {
time.Sleep(backoff)
}
-func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error {
+func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte, requestURL string) error {
msg := strings.TrimSpace(extractUpstreamErrorMessage(body))
msg = sanitizeUpstreamErrorMessage(msg)
+ if status == http.StatusNotFound && strings.Contains(strings.ToLower(msg), "not found") {
+ if hint := soraBaseURLNotFoundHint(requestURL); hint != "" {
+ msg = strings.TrimSpace(msg + " " + hint)
+ }
+ }
if msg == "" {
msg = truncateForLog(body, 256)
}
@@ -684,10 +1729,52 @@ func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, b
}
}
-func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) {
+func normalizeSoraBaseURL(raw string) string {
+ trimmed := strings.TrimRight(strings.TrimSpace(raw), "/")
+ if trimmed == "" {
+ return ""
+ }
+ parsed, err := url.Parse(trimmed)
+ if err != nil || parsed.Scheme == "" || parsed.Host == "" {
+ return trimmed
+ }
+ host := strings.ToLower(parsed.Hostname())
+ if host != "sora.chatgpt.com" && host != "chatgpt.com" {
+ return trimmed
+ }
+ pathVal := strings.TrimRight(strings.TrimSpace(parsed.Path), "/")
+ switch pathVal {
+ case "", "/":
+ parsed.Path = "/backend"
+ case "/backend-api":
+ parsed.Path = "/backend"
+ }
+ return strings.TrimRight(parsed.String(), "/")
+}
+
+func soraBaseURLNotFoundHint(requestURL string) string {
+ parsed, err := url.Parse(strings.TrimSpace(requestURL))
+ if err != nil || parsed.Host == "" {
+ return ""
+ }
+ host := strings.ToLower(parsed.Hostname())
+ if host != "sora.chatgpt.com" && host != "chatgpt.com" {
+ return ""
+ }
+ pathVal := strings.TrimSpace(parsed.Path)
+ if strings.HasPrefix(pathVal, "/backend/") || pathVal == "/backend" {
+ return ""
+ }
+ return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)"
+}
+
+func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken, userAgent, proxyURL string) (string, error) {
reqID := uuid.NewString()
- userAgent := soraRandChoice(soraDesktopUserAgents)
- powToken := soraGetPowToken(userAgent)
+ userAgent = strings.TrimSpace(userAgent)
+ if userAgent == "" {
+ userAgent = c.taskUserAgent()
+ }
+ powToken := soraPowTokenGenerator(userAgent)
payload := map[string]any{
"p": powToken,
"flow": soraSentinelFlow,
@@ -708,7 +1795,7 @@ func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *A
}
urlStr := soraChatGPTBaseURL + "/backend-api/sentinel/req"
- respBody, _, err := c.doRequest(ctx, account, http.MethodPost, urlStr, headers, bytes.NewReader(body), true)
+ respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, urlStr, headers, bytes.NewReader(body), true)
if err != nil {
return "", err
}
@@ -724,16 +1811,6 @@ func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *A
return sentinel, nil
}
-func soraRandChoice(items []string) string {
- if len(items) == 0 {
- return ""
- }
- soraRandMu.Lock()
- idx := soraRand.Intn(len(items))
- soraRandMu.Unlock()
- return items[idx]
-}
-
func soraGetPowToken(userAgent string) string {
configList := soraBuildPowConfig(userAgent)
seed := strconv.FormatFloat(soraRandFloat(), 'f', -1, 64)
@@ -748,14 +1825,26 @@ func soraRandFloat() float64 {
return soraRand.Float64()
}
+func soraRandInt(max int) int {
+ if max <= 1 {
+ return 0
+ }
+ soraRandMu.Lock()
+ defer soraRandMu.Unlock()
+ return soraRand.Intn(max)
+}
+
func soraBuildPowConfig(userAgent string) []any {
- screen := soraRandChoice([]string{
- strconv.Itoa(1920 + 1080),
- strconv.Itoa(2560 + 1440),
- strconv.Itoa(1920 + 1200),
- strconv.Itoa(2560 + 1600),
- })
- screenVal, _ := strconv.Atoi(screen)
+ userAgent = strings.TrimSpace(userAgent)
+ if userAgent == "" && len(soraDesktopUserAgents) > 0 {
+ userAgent = soraDesktopUserAgents[0]
+ }
+ screenVal := soraStableChoiceInt([]int{
+ 1920 + 1080,
+ 2560 + 1440,
+ 1920 + 1200,
+ 2560 + 1600,
+ }, userAgent+"|screen")
perfMs := float64(time.Since(soraPerfStart).Milliseconds())
wallMs := float64(time.Now().UnixNano()) / 1e6
diff := wallMs - perfMs
@@ -765,32 +1854,47 @@ func soraBuildPowConfig(userAgent string) []any {
4294705152,
0,
userAgent,
- soraRandChoice(soraPowScripts),
- soraRandChoice(soraPowDPL),
+ soraStableChoice(soraPowScripts, userAgent+"|script"),
+ soraStableChoice(soraPowDPL, userAgent+"|dpl"),
"en-US",
"en-US,es-US,en,es",
0,
- soraRandChoice(soraPowNavigatorKeys),
- soraRandChoice(soraPowDocumentKeys),
- soraRandChoice(soraPowWindowKeys),
+ soraStableChoice(soraPowNavigatorKeys, userAgent+"|navigator"),
+ soraStableChoice(soraPowDocumentKeys, userAgent+"|document"),
+ soraStableChoice(soraPowWindowKeys, userAgent+"|window"),
perfMs,
uuid.NewString(),
"",
- soraRandChoiceInt(soraPowCores),
+ soraStableChoiceInt(soraPowCores, userAgent+"|cores"),
diff,
}
}
-func soraRandChoiceInt(items []int) int {
+func soraStableChoice(items []string, seed string) string {
+ if len(items) == 0 {
+ return ""
+ }
+ idx := soraStableIndex(seed, len(items))
+ return items[idx]
+}
+
+func soraStableChoiceInt(items []int, seed string) int {
if len(items) == 0 {
return 0
}
- soraRandMu.Lock()
- idx := soraRand.Intn(len(items))
- soraRandMu.Unlock()
+ idx := soraStableIndex(seed, len(items))
return items[idx]
}
+func soraStableIndex(seed string, size int) int {
+ if size <= 0 {
+ return 0
+ }
+ h := fnv.New32a()
+ _, _ = h.Write([]byte(seed))
+ return int(h.Sum32() % uint32(size))
+}
+
func soraPowParseTime() string {
loc := time.FixedZone("EST", -5*3600)
return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (Eastern Standard Time)")
@@ -890,6 +1994,55 @@ func hexDecodeString(s string) ([]byte, error) {
return dst, err
}
+func (c *SoraDirectClient) withRequestTrace(ctx context.Context, account *Account, proxyURL, userAgent string) context.Context {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ if existing, ok := ctx.Value(soraRequestTraceContextKey{}).(*soraRequestTrace); ok && existing != nil && existing.ID != "" {
+ return ctx
+ }
+ accountID := int64(0)
+ if account != nil {
+ accountID = account.ID
+ }
+ seed := fmt.Sprintf("%d|%s|%s|%d", accountID, normalizeSoraProxyKey(proxyURL), strings.TrimSpace(userAgent), time.Now().UnixNano())
+ trace := &soraRequestTrace{
+ ID: "sora-" + soraHashForLog(seed),
+ ProxyKey: normalizeSoraProxyKey(proxyURL),
+ UAHash: soraHashForLog(strings.TrimSpace(userAgent)),
+ }
+ return context.WithValue(ctx, soraRequestTraceContextKey{}, trace)
+}
+
+func (c *SoraDirectClient) requestTraceFields(ctx context.Context, proxyURL, userAgent string) (string, string, string) {
+ proxyKey := normalizeSoraProxyKey(proxyURL)
+ uaHash := soraHashForLog(strings.TrimSpace(userAgent))
+ traceID := ""
+ if ctx != nil {
+ if trace, ok := ctx.Value(soraRequestTraceContextKey{}).(*soraRequestTrace); ok && trace != nil {
+ if strings.TrimSpace(trace.ID) != "" {
+ traceID = strings.TrimSpace(trace.ID)
+ }
+ if strings.TrimSpace(trace.ProxyKey) != "" {
+ proxyKey = strings.TrimSpace(trace.ProxyKey)
+ }
+ if strings.TrimSpace(trace.UAHash) != "" {
+ uaHash = strings.TrimSpace(trace.UAHash)
+ }
+ }
+ }
+ if traceID == "" {
+ traceID = "sora-" + soraHashForLog(fmt.Sprintf("%s|%d", proxyKey, time.Now().UnixNano()))
+ }
+ return traceID, proxyKey, uaHash
+}
+
+func soraHashForLog(raw string) string {
+ h := fnv.New32a()
+ _, _ = h.Write([]byte(raw))
+ return fmt.Sprintf("%08x", h.Sum32())
+}
+
func sanitizeSoraLogURL(raw string) string {
parsed, err := url.Parse(raw)
if err != nil {
@@ -901,3 +2054,70 @@ func sanitizeSoraLogURL(raw string) string {
parsed.RawQuery = q.Encode()
return parsed.String()
}
+
+func (c *SoraDirectClient) debugEnabled() bool {
+ return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug
+}
+
+func (c *SoraDirectClient) debugLogf(format string, args ...any) {
+ if !c.debugEnabled() {
+ return
+ }
+ log.Printf("[SoraClient] "+format, args...)
+}
+
+func formatSoraHeaders(headers http.Header) string {
+ if len(headers) == 0 {
+ return "{}"
+ }
+ keys := make([]string, 0, len(headers))
+ for key := range headers {
+ keys = append(keys, key)
+ }
+ sort.Strings(keys)
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ values := headers.Values(key)
+ if len(values) == 0 {
+ continue
+ }
+ val := strings.Join(values, ",")
+ if isSensitiveHeader(key) {
+ out[key] = "***"
+ continue
+ }
+ out[key] = truncateForLog([]byte(logredact.RedactText(val)), 160)
+ }
+ encoded, err := json.Marshal(out)
+ if err != nil {
+ return "{}"
+ }
+ return string(encoded)
+}
+
+func isSensitiveHeader(key string) bool {
+ k := strings.ToLower(strings.TrimSpace(key))
+ switch k {
+ case "authorization", "openai-sentinel-token", "cookie", "set-cookie", "x-api-key":
+ return true
+ default:
+ return false
+ }
+}
+
+func summarizeSoraResponseBody(body []byte, maxLen int) string {
+ if len(body) == 0 {
+ return ""
+ }
+ var text string
+ if json.Valid(body) {
+ text = logredact.RedactJSON(body)
+ } else {
+ text = logredact.RedactText(string(body))
+ }
+ text = strings.TrimSpace(text)
+ if maxLen <= 0 || len(text) <= maxLen {
+ return text
+ }
+ return text[:maxLen] + "...(truncated)"
+}
diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go
index a6bf71cd..cffe8a35 100644
--- a/backend/internal/service/sora_client_test.go
+++ b/backend/internal/service/sora_client_test.go
@@ -4,9 +4,16 @@ package service
import (
"context"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "io"
"net/http"
"net/http/httptest"
+ "strings"
+ "sync/atomic"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
@@ -85,3 +92,984 @@ func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) {
require.Equal(t, "completed", status.Status)
require.Equal(t, []string{"https://example.com/a.png"}, status.URLs)
}
+
+func TestNormalizeSoraBaseURL(t *testing.T) {
+ t.Parallel()
+ tests := []struct {
+ name string
+ raw string
+ want string
+ }{
+ {
+ name: "empty",
+ raw: "",
+ want: "",
+ },
+ {
+ name: "append_backend_for_sora_host",
+ raw: "https://sora.chatgpt.com",
+ want: "https://sora.chatgpt.com/backend",
+ },
+ {
+ name: "convert_backend_api_to_backend",
+ raw: "https://sora.chatgpt.com/backend-api",
+ want: "https://sora.chatgpt.com/backend",
+ },
+ {
+ name: "keep_backend",
+ raw: "https://sora.chatgpt.com/backend",
+ want: "https://sora.chatgpt.com/backend",
+ },
+ {
+ name: "keep_custom_host",
+ raw: "https://example.com/custom-path",
+ want: "https://example.com/custom-path",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := normalizeSoraBaseURL(tt.raw)
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com",
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, nil)
+ require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen"))
+}
+
+func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) {
+ t.Parallel()
+ client := NewSoraDirectClient(&config.Config{}, nil, nil)
+
+ err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen")
+ var upstreamErr *SoraUpstreamError
+ require.ErrorAs(t, err, &upstreamErr)
+ require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url")
+
+ errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen")
+ require.ErrorAs(t, errNoHint, &upstreamErr)
+ require.NotContains(t, upstreamErr.Message, "请检查 sora.client.base_url")
+}
+
+func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) {
+ t.Parallel()
+ headers := http.Header{}
+ headers.Set("Authorization", "Bearer secret-token")
+ headers.Set("openai-sentinel-token", "sentinel-secret")
+ headers.Set("X-Test", "ok")
+
+ out := formatSoraHeaders(headers)
+ require.Contains(t, out, `"Authorization":"***"`)
+ require.Contains(t, out, `Sentinel-Token":"***"`)
+ require.Contains(t, out, `"X-Test":"ok"`)
+ require.NotContains(t, out, "secret-token")
+ require.NotContains(t, out, "sentinel-secret")
+}
+
+func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) {
+ t.Parallel()
+ body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`)
+ out := summarizeSoraResponseBody(body, 512)
+ require.Contains(t, out, `"access_token":"***"`)
+ require.NotContains(t, out, "abc123")
+}
+
+func TestSummarizeSoraResponseBody_Truncates(t *testing.T) {
+ t.Parallel()
+ body := []byte(strings.Repeat("x", 100))
+ out := summarizeSoraResponseBody(body, 10)
+ require.Contains(t, out, "(truncated)")
+}
+
+func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) {
+ t.Parallel()
+ cache := newOpenAITokenCacheStub()
+ provider := NewOpenAITokenProvider(nil, cache, nil)
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, provider)
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "sora-credential-token",
+ },
+ }
+
+ token, err := client.getAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "sora-credential-token", token)
+ require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled))
+}
+
+func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) {
+ t.Parallel()
+ cache := newOpenAITokenCacheStub()
+ account := &Account{
+ ID: 2,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "sora-credential-token",
+ },
+ }
+ cache.tokens[OpenAITokenCacheKey(account)] = "provider-token"
+ provider := NewOpenAITokenProvider(nil, cache, nil)
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ UseOpenAITokenProvider: true,
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, provider)
+
+ token, err := client.getAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "provider-token", token)
+ require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0))
+}
+
+func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+ require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token")
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "accessToken": "session-access-token",
+ "expires": "2099-01-01T00:00:00Z",
+ })
+ }))
+ defer server.Close()
+
+ origin := soraSessionAuthURL
+ soraSessionAuthURL = server.URL
+ defer func() { soraSessionAuthURL = origin }()
+
+ client := NewSoraDirectClient(&config.Config{}, nil, nil)
+ account := &Account{
+ ID: 10,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "session_token": "session-token",
+ },
+ }
+
+ token, err := client.getAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "session-access-token", token)
+ require.Equal(t, "session-access-token", account.GetCredential("access_token"))
+}
+
+func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodPost, r.Method)
+ require.Equal(t, "/oauth/token", r.URL.Path)
+ require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
+ require.NoError(t, r.ParseForm())
+ require.Equal(t, "refresh_token", r.FormValue("grant_type"))
+ require.Equal(t, "refresh-token-old", r.FormValue("refresh_token"))
+ require.NotEmpty(t, r.FormValue("client_id"))
+ require.Equal(t, "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", r.FormValue("redirect_uri"))
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "access_token": "refresh-access-token",
+ "refresh_token": "refresh-token-new",
+ "expires_in": 3600,
+ })
+ }))
+ defer server.Close()
+
+ origin := soraOAuthTokenURL
+ soraOAuthTokenURL = server.URL + "/oauth/token"
+ defer func() { soraOAuthTokenURL = origin }()
+
+ client := NewSoraDirectClient(&config.Config{}, nil, nil)
+ account := &Account{
+ ID: 11,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "refresh_token": "refresh-token-old",
+ },
+ }
+
+ token, err := client.getAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "refresh-access-token", token)
+ require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token"))
+ require.NotNil(t, account.GetCredentialAsTime("expires_at"))
+}
+
+func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) {
+ t.Parallel()
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodGet, r.Method)
+ require.Equal(t, "/nf/check", r.URL.Path)
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "rate_limit_and_credit_balance": map[string]any{
+ "estimated_num_videos_remaining": 0,
+ "rate_limit_reached": true,
+ },
+ })
+ }))
+ defer server.Close()
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: server.URL,
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, nil)
+ account := &Account{
+ ID: 12,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "ok",
+ "expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339),
+ },
+ }
+ err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"})
+ require.Error(t, err)
+ var upstreamErr *SoraUpstreamError
+ require.ErrorAs(t, err, &upstreamErr)
+ require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode)
+}
+
+func TestShouldAttemptSoraTokenRecover(t *testing.T) {
+ t.Parallel()
+
+ require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen"))
+ require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen"))
+ require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session"))
+ require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token"))
+ require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen"))
+}
+
+type soraClientRequestCall struct {
+ Path string
+ UserAgent string
+ ProxyURL string
+}
+
+type soraClientRecordingUpstream struct {
+ calls []soraClientRequestCall
+}
+
+func (u *soraClientRecordingUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
+ return nil, errors.New("unexpected Do call")
+}
+
+func (u *soraClientRecordingUpstream) DoWithTLS(req *http.Request, proxyURL string, _ int64, _ int, _ bool) (*http.Response, error) {
+ u.calls = append(u.calls, soraClientRequestCall{
+ Path: req.URL.Path,
+ UserAgent: req.Header.Get("User-Agent"),
+ ProxyURL: proxyURL,
+ })
+ switch req.URL.Path {
+ case "/backend-api/sentinel/req":
+ return newSoraClientMockResponse(http.StatusOK, `{"token":"sentinel-token","turnstile":{"dx":"ok"}}`), nil
+ case "/backend/nf/create":
+ return newSoraClientMockResponse(http.StatusOK, `{"id":"task-123"}`), nil
+ case "/backend/nf/create/storyboard":
+ return newSoraClientMockResponse(http.StatusOK, `{"id":"storyboard-123"}`), nil
+ case "/backend/uploads":
+ return newSoraClientMockResponse(http.StatusOK, `{"id":"upload-123"}`), nil
+ case "/backend/nf/check":
+ return newSoraClientMockResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":1,"rate_limit_reached":false}}`), nil
+ case "/backend/characters/upload":
+ return newSoraClientMockResponse(http.StatusOK, `{"id":"cameo-123"}`), nil
+ case "/backend/project_y/cameos/in_progress/cameo-123":
+ return newSoraClientMockResponse(http.StatusOK, `{"status":"finalized","status_message":"Completed","username_hint":"foo.bar","display_name_hint":"Bar","profile_asset_url":"https://example.com/avatar.webp"}`), nil
+ case "/backend/project_y/file/upload":
+ return newSoraClientMockResponse(http.StatusOK, `{"asset_pointer":"asset-123"}`), nil
+ case "/backend/characters/finalize":
+ return newSoraClientMockResponse(http.StatusOK, `{"character":{"character_id":"character-123"}}`), nil
+ case "/backend/project_y/post":
+ return newSoraClientMockResponse(http.StatusOK, `{"post":{"id":"s_post"}}`), nil
+ default:
+ return newSoraClientMockResponse(http.StatusOK, `{"ok":true}`), nil
+ }
+}
+
+func newSoraClientMockResponse(statusCode int, body string) *http.Response {
+ return &http.Response{
+ StatusCode: statusCode,
+ Header: make(http.Header),
+ Body: io.NopCloser(strings.NewReader(body)),
+ }
+}
+
+func TestSoraDirectClient_TaskUserAgent_DefaultMobileFallback(t *testing.T) {
+ client := NewSoraDirectClient(&config.Config{}, nil, nil)
+ ua := client.taskUserAgent()
+ require.NotEmpty(t, ua)
+ allowed := append([]string{}, soraMobileUserAgents...)
+ allowed = append(allowed, soraDesktopUserAgents...)
+ require.Contains(t, allowed, ua)
+}
+
+func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAndCreate(t *testing.T) {
+ originPowTokenGenerator := soraPowTokenGenerator
+ soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
+ defer func() {
+ soraPowTokenGenerator = originPowTokenGenerator
+ }()
+
+ upstream := &soraClientRecordingUpstream{}
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, upstream, nil)
+ proxyID := int64(9)
+ account := &Account{
+ ID: 21,
+ Platform: PlatformSora,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ ProxyID: &proxyID,
+ Proxy: &Proxy{
+ Protocol: "http",
+ Host: "127.0.0.1",
+ Port: 8080,
+ },
+ Credentials: map[string]any{
+ "access_token": "access-token",
+ "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
+ },
+ }
+
+ taskID, err := client.CreateVideoTask(context.Background(), account, SoraVideoRequest{Prompt: "test"})
+ require.NoError(t, err)
+ require.Equal(t, "task-123", taskID)
+ require.Len(t, upstream.calls, 2)
+
+ sentinelCall := upstream.calls[0]
+ createCall := upstream.calls[1]
+ require.Equal(t, "/backend-api/sentinel/req", sentinelCall.Path)
+ require.Equal(t, "/backend/nf/create", createCall.Path)
+ require.Equal(t, "http://127.0.0.1:8080", sentinelCall.ProxyURL)
+ require.Equal(t, sentinelCall.ProxyURL, createCall.ProxyURL)
+ require.NotEmpty(t, sentinelCall.UserAgent)
+ require.Equal(t, sentinelCall.UserAgent, createCall.UserAgent)
+}
+
+func TestSoraDirectClient_UploadImage_UsesTaskUserAgentAndProxy(t *testing.T) {
+ upstream := &soraClientRecordingUpstream{}
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, upstream, nil)
+ proxyID := int64(3)
+ account := &Account{
+ ID: 31,
+ ProxyID: &proxyID,
+ Proxy: &Proxy{
+ Protocol: "http",
+ Host: "127.0.0.1",
+ Port: 8080,
+ },
+ Credentials: map[string]any{
+ "access_token": "access-token",
+ "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
+ },
+ }
+
+ uploadID, err := client.UploadImage(context.Background(), account, []byte("mock-image"), "a.png")
+ require.NoError(t, err)
+ require.Equal(t, "upload-123", uploadID)
+ require.Len(t, upstream.calls, 1)
+ require.Equal(t, "/backend/uploads", upstream.calls[0].Path)
+ require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
+ require.NotEmpty(t, upstream.calls[0].UserAgent)
+}
+
+func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) {
+ upstream := &soraClientRecordingUpstream{}
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, upstream, nil)
+ proxyID := int64(7)
+ account := &Account{
+ ID: 41,
+ ProxyID: &proxyID,
+ Proxy: &Proxy{
+ Protocol: "http",
+ Host: "127.0.0.1",
+ Port: 8080,
+ },
+ Credentials: map[string]any{
+ "access_token": "access-token",
+ "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
+ },
+ }
+
+ err := client.PreflightCheck(context.Background(), account, "sora2", SoraModelConfig{Type: "video"})
+ require.NoError(t, err)
+ require.Len(t, upstream.calls, 1)
+ require.Equal(t, "/backend/nf/check", upstream.calls[0].Path)
+ require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL)
+ require.NotEmpty(t, upstream.calls[0].UserAgent)
+}
+
+func TestSoraDirectClient_CreateStoryboardTask(t *testing.T) {
+ originPowTokenGenerator := soraPowTokenGenerator
+ soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
+ defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
+
+ upstream := &soraClientRecordingUpstream{}
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, upstream, nil)
+ account := &Account{
+ ID: 51,
+ Credentials: map[string]any{
+ "access_token": "access-token",
+ "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
+ },
+ }
+
+ taskID, err := client.CreateStoryboardTask(context.Background(), account, SoraStoryboardRequest{
+ Prompt: "Shot 1:\nduration: 5sec\nScene: cat",
+ })
+ require.NoError(t, err)
+ require.Equal(t, "storyboard-123", taskID)
+ require.Len(t, upstream.calls, 2)
+ require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
+ require.Equal(t, "/backend/nf/create/storyboard", upstream.calls[1].Path)
+}
+
+func TestSoraDirectClient_GetVideoTask_ReturnsGenerationID(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ switch r.URL.Path {
+ case "/nf/pending/v2":
+ _, _ = w.Write([]byte(`[]`))
+ case "/project_y/profile/drafts":
+ _, _ = w.Write([]byte(`{"items":[{"id":"gen_1","task_id":"task-1","kind":"video","downloadable_url":"https://example.com/v.mp4"}]}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer server.Close()
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: server.URL,
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, nil)
+ account := &Account{Credentials: map[string]any{"access_token": "token"}}
+
+ status, err := client.GetVideoTask(context.Background(), account, "task-1")
+ require.NoError(t, err)
+ require.Equal(t, "completed", status.Status)
+ require.Equal(t, "gen_1", status.GenerationID)
+ require.Equal(t, []string{"https://example.com/v.mp4"}, status.URLs)
+}
+
+func TestSoraDirectClient_PostVideoForWatermarkFree(t *testing.T) {
+ originPowTokenGenerator := soraPowTokenGenerator
+ soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" }
+ defer func() { soraPowTokenGenerator = originPowTokenGenerator }()
+
+ upstream := &soraClientRecordingUpstream{}
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, upstream, nil)
+ account := &Account{
+ ID: 52,
+ Credentials: map[string]any{
+ "access_token": "access-token",
+ "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
+ },
+ }
+
+ postID, err := client.PostVideoForWatermarkFree(context.Background(), account, "gen_1")
+ require.NoError(t, err)
+ require.Equal(t, "s_post", postID)
+ require.Len(t, upstream.calls, 2)
+ require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path)
+ require.Equal(t, "/backend/project_y/post", upstream.calls[1].Path)
+}
+
+type soraClientFallbackUpstream struct {
+ doWithTLSCalls int32
+ respBody string
+ respStatusCode int
+ err error
+}
+
+func (u *soraClientFallbackUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
+ return nil, errors.New("unexpected Do call")
+}
+
+func (u *soraClientFallbackUpstream) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
+ atomic.AddInt32(&u.doWithTLSCalls, 1)
+ if u.err != nil {
+ return nil, u.err
+ }
+ statusCode := u.respStatusCode
+ if statusCode <= 0 {
+ statusCode = http.StatusOK
+ }
+ body := u.respBody
+ if body == "" {
+ body = `{"ok":true}`
+ }
+ return newSoraClientMockResponse(statusCode, body), nil
+}
+
+func TestSoraDirectClient_DoHTTP_UsesCurlCFFISidecarWhenEnabled(t *testing.T) {
+ var captured soraCurlCFFISidecarRequest
+ sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, http.MethodPost, r.Method)
+ require.Equal(t, "/request", r.URL.Path)
+ raw, err := io.ReadAll(r.Body)
+ require.NoError(t, err)
+ require.NoError(t, json.Unmarshal(raw, &captured))
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "status_code": http.StatusOK,
+ "headers": map[string]any{
+ "Content-Type": "application/json",
+ "X-Sidecar": []string{"yes"},
+ },
+ "body_base64": base64.StdEncoding.EncodeToString([]byte(`{"ok":true}`)),
+ })
+ }))
+ defer sidecar.Close()
+
+ upstream := &soraClientFallbackUpstream{}
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
+ Enabled: true,
+ BaseURL: sidecar.URL,
+ Impersonate: "chrome131",
+ TimeoutSeconds: 15,
+ SessionReuseEnabled: true,
+ },
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, upstream, nil)
+ req, err := http.NewRequest(http.MethodPost, "https://sora.chatgpt.com/backend/me", strings.NewReader("hello-sidecar"))
+ require.NoError(t, err)
+ req.Header.Set("User-Agent", "test-ua")
+
+ resp, err := client.doHTTP(req, "http://127.0.0.1:18080", &Account{ID: 1})
+ require.NoError(t, err)
+ defer resp.Body.Close()
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ require.JSONEq(t, `{"ok":true}`, string(body))
+ require.Equal(t, int32(0), atomic.LoadInt32(&upstream.doWithTLSCalls))
+ require.Equal(t, "http://127.0.0.1:18080", captured.ProxyURL)
+ require.NotEmpty(t, captured.SessionKey)
+ require.Equal(t, "chrome131", captured.Impersonate)
+ require.Equal(t, "https://sora.chatgpt.com/backend/me", captured.URL)
+ decodedReqBody, err := base64.StdEncoding.DecodeString(captured.BodyBase64)
+ require.NoError(t, err)
+ require.Equal(t, "hello-sidecar", string(decodedReqBody))
+}
+
+func TestSoraDirectClient_DoHTTP_CurlCFFISidecarFailureReturnsError(t *testing.T) {
+ sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusBadGateway)
+ _, _ = w.Write([]byte(`{"error":"boom"}`))
+ }))
+ defer sidecar.Close()
+
+ upstream := &soraClientFallbackUpstream{respBody: `{"fallback":true}`}
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
+ Enabled: true,
+ BaseURL: sidecar.URL,
+ },
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, upstream, nil)
+ req, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil)
+ require.NoError(t, err)
+
+ _, err = client.doHTTP(req, "", &Account{ID: 2})
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "sora curl_cffi sidecar")
+ require.Equal(t, int32(0), atomic.LoadInt32(&upstream.doWithTLSCalls))
+}
+
+func TestSoraDirectClient_DoHTTP_CurlCFFISidecarDisabledUsesLegacyStack(t *testing.T) {
+ upstream := &soraClientFallbackUpstream{respBody: `{"legacy":true}`}
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
+ Enabled: false,
+ BaseURL: "http://127.0.0.1:18080",
+ },
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, upstream, nil)
+ req, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil)
+ require.NoError(t, err)
+
+ resp, err := client.doHTTP(req, "", &Account{ID: 3})
+ require.NoError(t, err)
+ defer resp.Body.Close()
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ require.JSONEq(t, `{"legacy":true}`, string(body))
+ require.Equal(t, int32(1), atomic.LoadInt32(&upstream.doWithTLSCalls))
+}
+
+func TestConvertSidecarHeaderValue_NilAndSlice(t *testing.T) {
+ require.Nil(t, convertSidecarHeaderValue(nil))
+ require.Equal(t, []string{"a", "b"}, convertSidecarHeaderValue([]any{"a", " ", "b"}))
+}
+
+func TestSoraDirectClient_DoHTTP_SidecarSessionKeyStableForSameAccountProxy(t *testing.T) {
+ var captured []soraCurlCFFISidecarRequest
+ sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ raw, err := io.ReadAll(r.Body)
+ require.NoError(t, err)
+ var reqPayload soraCurlCFFISidecarRequest
+ require.NoError(t, json.Unmarshal(raw, &reqPayload))
+ captured = append(captured, reqPayload)
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "status_code": http.StatusOK,
+ "headers": map[string]any{
+ "Content-Type": "application/json",
+ },
+ "body": `{"ok":true}`,
+ })
+ }))
+ defer sidecar.Close()
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
+ Enabled: true,
+ BaseURL: sidecar.URL,
+ SessionReuseEnabled: true,
+ SessionTTLSeconds: 3600,
+ },
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, nil)
+ account := &Account{ID: 1001}
+
+ req1, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil)
+ require.NoError(t, err)
+ _, err = client.doHTTP(req1, "http://127.0.0.1:18080", account)
+ require.NoError(t, err)
+
+ req2, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil)
+ require.NoError(t, err)
+ _, err = client.doHTTP(req2, "http://127.0.0.1:18080", account)
+ require.NoError(t, err)
+
+ require.Len(t, captured, 2)
+ require.NotEmpty(t, captured[0].SessionKey)
+ require.Equal(t, captured[0].SessionKey, captured[1].SessionKey)
+}
+
+func TestSoraDirectClient_DoRequestWithProxy_CloudflareChallengeSetsCooldownAfterSingleRetry(t *testing.T) {
+ var sidecarCalls int32
+ sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ atomic.AddInt32(&sidecarCalls, 1)
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "status_code": http.StatusForbidden,
+ "headers": map[string]any{
+ "cf-ray": "9d05d73dec4d8c8e-GRU",
+ "content-type": "text/html",
+ },
+ "body": `Just a moment...`,
+ })
+ }))
+ defer sidecar.Close()
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ MaxRetries: 3,
+ CloudflareChallengeCooldownSeconds: 60,
+ CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
+ Enabled: true,
+ BaseURL: sidecar.URL,
+ Impersonate: "chrome131",
+ },
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, nil)
+ headers := http.Header{}
+
+ _, _, err := client.doRequestWithProxy(
+ context.Background(),
+ &Account{ID: 99},
+ "http://127.0.0.1:18080",
+ http.MethodGet,
+ "https://sora.chatgpt.com/backend/me",
+ headers,
+ nil,
+ true,
+ )
+ require.Error(t, err)
+ var upstreamErr *SoraUpstreamError
+ require.ErrorAs(t, err, &upstreamErr)
+ require.Equal(t, http.StatusForbidden, upstreamErr.StatusCode)
+ require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls), "challenge should trigger exactly one same-proxy retry")
+
+ _, _, err = client.doRequestWithProxy(
+ context.Background(),
+ &Account{ID: 99},
+ "http://127.0.0.1:18080",
+ http.MethodGet,
+ "https://sora.chatgpt.com/backend/me",
+ headers,
+ nil,
+ true,
+ )
+ require.Error(t, err)
+ require.ErrorAs(t, err, &upstreamErr)
+ require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode)
+ require.Contains(t, upstreamErr.Message, "cooling down")
+ require.Contains(t, upstreamErr.Message, "cf-ray")
+ require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls), "cooldown should block outbound request")
+}
+
+func TestSoraDirectClient_DoRequestWithProxy_CloudflareRetrySuccessClearsCooldown(t *testing.T) {
+ var sidecarCalls int32
+ sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ call := atomic.AddInt32(&sidecarCalls, 1)
+ if call == 1 {
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "status_code": http.StatusForbidden,
+ "headers": map[string]any{
+ "cf-ray": "9d05d73dec4d8c8e-GRU",
+ "content-type": "text/html",
+ },
+ "body": `Just a moment...`,
+ })
+ return
+ }
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "status_code": http.StatusOK,
+ "headers": map[string]any{
+ "content-type": "application/json",
+ },
+ "body": `{"ok":true}`,
+ })
+ }))
+ defer sidecar.Close()
+
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ BaseURL: "https://sora.chatgpt.com/backend",
+ MaxRetries: 3,
+ CloudflareChallengeCooldownSeconds: 60,
+ CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
+ Enabled: true,
+ BaseURL: sidecar.URL,
+ Impersonate: "chrome131",
+ },
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, nil)
+ headers := http.Header{}
+ account := &Account{ID: 109}
+ proxyURL := "http://127.0.0.1:18080"
+
+ body, _, err := client.doRequestWithProxy(
+ context.Background(),
+ account,
+ proxyURL,
+ http.MethodGet,
+ "https://sora.chatgpt.com/backend/me",
+ headers,
+ nil,
+ true,
+ )
+ require.NoError(t, err)
+ require.Contains(t, string(body), `"ok":true`)
+ require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls))
+
+ _, _, err = client.doRequestWithProxy(
+ context.Background(),
+ account,
+ proxyURL,
+ http.MethodGet,
+ "https://sora.chatgpt.com/backend/me",
+ headers,
+ nil,
+ true,
+ )
+ require.NoError(t, err)
+ require.Equal(t, int32(3), atomic.LoadInt32(&sidecarCalls), "cooldown should be cleared after retry succeeds")
+}
+
+func TestSoraComputeChallengeCooldownSeconds(t *testing.T) {
+ require.Equal(t, 0, soraComputeChallengeCooldownSeconds(0, 3))
+ require.Equal(t, 10, soraComputeChallengeCooldownSeconds(10, 1))
+ require.Equal(t, 20, soraComputeChallengeCooldownSeconds(10, 2))
+ require.Equal(t, 40, soraComputeChallengeCooldownSeconds(10, 4))
+ require.Equal(t, 40, soraComputeChallengeCooldownSeconds(10, 9), "streak should cap at x4")
+ require.Equal(t, 3600, soraComputeChallengeCooldownSeconds(1200, 9), "cooldown should cap at 3600s")
+}
+
+func TestSoraDirectClient_RecordCloudflareChallengeCooldown_EscalatesStreak(t *testing.T) {
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ CloudflareChallengeCooldownSeconds: 10,
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, nil)
+ account := &Account{ID: 201}
+ proxyURL := "http://127.0.0.1:18080"
+
+ client.recordCloudflareChallengeCooldown(account, proxyURL, http.StatusForbidden, http.Header{"Cf-Ray": []string{"9d05d73dec4d8c8e-GRU"}}, nil)
+ client.recordCloudflareChallengeCooldown(account, proxyURL, http.StatusForbidden, http.Header{"Cf-Ray": []string{"9d05d73dec4d8c8f-GRU"}}, nil)
+
+ key := soraAccountProxyKey(account, proxyURL)
+ entry, ok := client.challengeCooldowns[key]
+ require.True(t, ok)
+ require.Equal(t, 2, entry.ConsecutiveChallenges)
+ require.Equal(t, "9d05d73dec4d8c8f-GRU", entry.CFRay)
+ remain := int(entry.Until.Sub(entry.LastChallengeAt).Seconds())
+ require.GreaterOrEqual(t, remain, 19)
+}
+
+func TestSoraDirectClient_SidecarSessionKey_SkipsWhenAccountMissing(t *testing.T) {
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
+ Enabled: true,
+ SessionReuseEnabled: true,
+ SessionTTLSeconds: 3600,
+ },
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, nil)
+ require.Equal(t, "", client.sidecarSessionKey(nil, "http://127.0.0.1:18080"))
+ require.Empty(t, client.sidecarSessions)
+}
+
+func TestSoraDirectClient_SidecarSessionKey_PrunesExpiredAndRecreates(t *testing.T) {
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
+ Enabled: true,
+ SessionReuseEnabled: true,
+ SessionTTLSeconds: 3600,
+ },
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, nil)
+ account := &Account{ID: 123}
+ key := soraAccountProxyKey(account, "http://127.0.0.1:18080")
+ client.sidecarSessions[key] = soraSidecarSessionEntry{
+ SessionKey: "sora-expired",
+ ExpiresAt: time.Now().Add(-time.Minute),
+ LastUsedAt: time.Now().Add(-2 * time.Minute),
+ }
+
+ sessionKey := client.sidecarSessionKey(account, "http://127.0.0.1:18080")
+ require.NotEmpty(t, sessionKey)
+ require.NotEqual(t, "sora-expired", sessionKey)
+ require.Len(t, client.sidecarSessions, 1)
+}
+
+func TestSoraDirectClient_SidecarSessionKey_TTLZeroKeepsLongLivedSession(t *testing.T) {
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{
+ Enabled: true,
+ SessionReuseEnabled: true,
+ SessionTTLSeconds: 0,
+ },
+ },
+ },
+ }
+ client := NewSoraDirectClient(cfg, nil, nil)
+ account := &Account{ID: 456}
+
+ first := client.sidecarSessionKey(account, "http://127.0.0.1:18080")
+ second := client.sidecarSessionKey(account, "http://127.0.0.1:18080")
+ require.NotEmpty(t, first)
+ require.Equal(t, first, second)
+
+ key := soraAccountProxyKey(account, "http://127.0.0.1:18080")
+ entry, ok := client.sidecarSessions[key]
+ require.True(t, ok)
+ require.True(t, entry.ExpiresAt.After(time.Now().Add(300*24*time.Hour)))
+}
diff --git a/backend/internal/service/sora_curl_cffi_sidecar.go b/backend/internal/service/sora_curl_cffi_sidecar.go
new file mode 100644
index 00000000..40f5c017
--- /dev/null
+++ b/backend/internal/service/sora_curl_cffi_sidecar.go
@@ -0,0 +1,260 @@
+package service
+
+import (
+ "bytes"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/util/logredact"
+)
+
+const soraCurlCFFISidecarDefaultTimeoutSeconds = 60
+
+type soraCurlCFFISidecarRequest struct {
+ Method string `json:"method"`
+ URL string `json:"url"`
+ Headers map[string][]string `json:"headers,omitempty"`
+ BodyBase64 string `json:"body_base64,omitempty"`
+ ProxyURL string `json:"proxy_url,omitempty"`
+ SessionKey string `json:"session_key,omitempty"`
+ Impersonate string `json:"impersonate,omitempty"`
+ TimeoutSeconds int `json:"timeout_seconds,omitempty"`
+}
+
+type soraCurlCFFISidecarResponse struct {
+ StatusCode int `json:"status_code"`
+ Status int `json:"status"`
+ Headers map[string]any `json:"headers"`
+ BodyBase64 string `json:"body_base64"`
+ Body string `json:"body"`
+ Error string `json:"error"`
+}
+
+func (c *SoraDirectClient) doHTTPViaCurlCFFISidecar(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
+ if req == nil || req.URL == nil {
+ return nil, errors.New("request url is nil")
+ }
+ if c == nil || c.cfg == nil {
+ return nil, errors.New("sora curl_cffi sidecar config is nil")
+ }
+ if !c.cfg.Sora.Client.CurlCFFISidecar.Enabled {
+ return nil, errors.New("sora curl_cffi sidecar is disabled")
+ }
+ endpoint := c.curlCFFISidecarEndpoint()
+ if endpoint == "" {
+ return nil, errors.New("sora curl_cffi sidecar base_url is empty")
+ }
+
+ bodyBytes, err := readAndRestoreRequestBody(req)
+ if err != nil {
+ return nil, fmt.Errorf("sora curl_cffi sidecar read request body failed: %w", err)
+ }
+
+ headers := make(map[string][]string, len(req.Header)+1)
+ for key, vals := range req.Header {
+ copied := make([]string, len(vals))
+ copy(copied, vals)
+ headers[key] = copied
+ }
+ if strings.TrimSpace(req.Host) != "" {
+ if _, ok := headers["Host"]; !ok {
+ headers["Host"] = []string{req.Host}
+ }
+ }
+
+ payload := soraCurlCFFISidecarRequest{
+ Method: req.Method,
+ URL: req.URL.String(),
+ Headers: headers,
+ ProxyURL: strings.TrimSpace(proxyURL),
+ SessionKey: c.sidecarSessionKey(account, proxyURL),
+ Impersonate: c.curlCFFIImpersonate(),
+ TimeoutSeconds: c.curlCFFISidecarTimeoutSeconds(),
+ }
+ if len(bodyBytes) > 0 {
+ payload.BodyBase64 = base64.StdEncoding.EncodeToString(bodyBytes)
+ }
+
+ encoded, err := json.Marshal(payload)
+ if err != nil {
+ return nil, fmt.Errorf("sora curl_cffi sidecar marshal request failed: %w", err)
+ }
+
+ sidecarReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, endpoint, bytes.NewReader(encoded))
+ if err != nil {
+ return nil, fmt.Errorf("sora curl_cffi sidecar build request failed: %w", err)
+ }
+ sidecarReq.Header.Set("Content-Type", "application/json")
+ sidecarReq.Header.Set("Accept", "application/json")
+
+ httpClient := &http.Client{Timeout: time.Duration(payload.TimeoutSeconds) * time.Second}
+ sidecarResp, err := httpClient.Do(sidecarReq)
+ if err != nil {
+ return nil, fmt.Errorf("sora curl_cffi sidecar request failed: %w", err)
+ }
+ defer func() {
+ _ = sidecarResp.Body.Close()
+ }()
+
+ sidecarRespBody, err := io.ReadAll(io.LimitReader(sidecarResp.Body, 8<<20))
+ if err != nil {
+ return nil, fmt.Errorf("sora curl_cffi sidecar read response failed: %w", err)
+ }
+ if sidecarResp.StatusCode != http.StatusOK {
+ redacted := truncateForLog([]byte(logredact.RedactText(string(sidecarRespBody))), 512)
+ return nil, fmt.Errorf("sora curl_cffi sidecar http status=%d body=%s", sidecarResp.StatusCode, redacted)
+ }
+
+ var payloadResp soraCurlCFFISidecarResponse
+ if err := json.Unmarshal(sidecarRespBody, &payloadResp); err != nil {
+ return nil, fmt.Errorf("sora curl_cffi sidecar parse response failed: %w", err)
+ }
+ if msg := strings.TrimSpace(payloadResp.Error); msg != "" {
+ return nil, fmt.Errorf("sora curl_cffi sidecar upstream error: %s", msg)
+ }
+ statusCode := payloadResp.StatusCode
+ if statusCode <= 0 {
+ statusCode = payloadResp.Status
+ }
+ if statusCode <= 0 {
+ return nil, errors.New("sora curl_cffi sidecar response missing status code")
+ }
+
+ responseBody := []byte(payloadResp.Body)
+ if strings.TrimSpace(payloadResp.BodyBase64) != "" {
+ decoded, err := base64.StdEncoding.DecodeString(payloadResp.BodyBase64)
+ if err != nil {
+ return nil, fmt.Errorf("sora curl_cffi sidecar decode body failed: %w", err)
+ }
+ responseBody = decoded
+ }
+
+ respHeaders := make(http.Header)
+ for key, rawVal := range payloadResp.Headers {
+ for _, v := range convertSidecarHeaderValue(rawVal) {
+ respHeaders.Add(key, v)
+ }
+ }
+
+ return &http.Response{
+ StatusCode: statusCode,
+ Header: respHeaders,
+ Body: io.NopCloser(bytes.NewReader(responseBody)),
+ ContentLength: int64(len(responseBody)),
+ Request: req,
+ }, nil
+}
+
+func readAndRestoreRequestBody(req *http.Request) ([]byte, error) {
+ if req == nil || req.Body == nil {
+ return nil, nil
+ }
+ bodyBytes, err := io.ReadAll(req.Body)
+ if err != nil {
+ return nil, err
+ }
+ _ = req.Body.Close()
+ req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+ req.ContentLength = int64(len(bodyBytes))
+ return bodyBytes, nil
+}
+
+func (c *SoraDirectClient) curlCFFISidecarEndpoint() string {
+ if c == nil || c.cfg == nil {
+ return ""
+ }
+ raw := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.BaseURL)
+ if raw == "" {
+ return ""
+ }
+ parsed, err := url.Parse(raw)
+ if err != nil || strings.TrimSpace(parsed.Scheme) == "" || strings.TrimSpace(parsed.Host) == "" {
+ return raw
+ }
+ if path := strings.TrimSpace(parsed.Path); path == "" || path == "/" {
+ parsed.Path = "/request"
+ }
+ return parsed.String()
+}
+
+func (c *SoraDirectClient) curlCFFISidecarTimeoutSeconds() int {
+ if c == nil || c.cfg == nil {
+ return soraCurlCFFISidecarDefaultTimeoutSeconds
+ }
+ timeoutSeconds := c.cfg.Sora.Client.CurlCFFISidecar.TimeoutSeconds
+ if timeoutSeconds <= 0 {
+ return soraCurlCFFISidecarDefaultTimeoutSeconds
+ }
+ return timeoutSeconds
+}
+
+func (c *SoraDirectClient) curlCFFIImpersonate() string {
+ if c == nil || c.cfg == nil {
+ return "chrome131"
+ }
+ impersonate := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.Impersonate)
+ if impersonate == "" {
+ return "chrome131"
+ }
+ return impersonate
+}
+
+func (c *SoraDirectClient) sidecarSessionReuseEnabled() bool {
+ if c == nil || c.cfg == nil {
+ return true
+ }
+ return c.cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled
+}
+
+func (c *SoraDirectClient) sidecarSessionTTLSeconds() int {
+ if c == nil || c.cfg == nil {
+ return 3600
+ }
+ ttl := c.cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds
+ if ttl < 0 {
+ return 3600
+ }
+ return ttl
+}
+
+func convertSidecarHeaderValue(raw any) []string {
+ switch val := raw.(type) {
+ case nil:
+ return nil
+ case string:
+ if strings.TrimSpace(val) == "" {
+ return nil
+ }
+ return []string{val}
+ case []any:
+ out := make([]string, 0, len(val))
+ for _, item := range val {
+ s := strings.TrimSpace(fmt.Sprint(item))
+ if s != "" {
+ out = append(out, s)
+ }
+ }
+ return out
+ case []string:
+ out := make([]string, 0, len(val))
+ for _, item := range val {
+ if strings.TrimSpace(item) != "" {
+ out = append(out, item)
+ }
+ }
+ return out
+ default:
+ s := strings.TrimSpace(fmt.Sprint(val))
+ if s == "" {
+ return nil
+ }
+ return []string{s}
+ }
+}
diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go
index d7ff297c..ac29ae0d 100644
--- a/backend/internal/service/sora_gateway_service.go
+++ b/backend/internal/service/sora_gateway_service.go
@@ -8,10 +8,12 @@ import (
"fmt"
"io"
"log"
+ "math"
"mime"
"net"
"net/http"
"net/url"
+ "regexp"
"strconv"
"strings"
"time"
@@ -23,6 +25,9 @@ import (
const soraImageInputMaxBytes = 20 << 20
const soraImageInputMaxRedirects = 3
const soraImageInputTimeout = 20 * time.Second
+const soraVideoInputMaxBytes = 200 << 20
+const soraVideoInputMaxRedirects = 3
+const soraVideoInputTimeout = 60 * time.Second
var soraImageSizeMap = map[string]string{
"gpt-image": "360",
@@ -61,6 +66,36 @@ type SoraGatewayService struct {
cfg *config.Config
}
+type soraWatermarkOptions struct {
+ Enabled bool
+ ParseMethod string
+ ParseURL string
+ ParseToken string
+ FallbackOnFailure bool
+ DeletePost bool
+}
+
+type soraCharacterOptions struct {
+ SetPublic bool
+ DeleteAfterGenerate bool
+}
+
+type soraCharacterFlowResult struct {
+ CameoID string
+ CharacterID string
+ Username string
+ DisplayName string
+}
+
+var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`)
+var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`)
+var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`)
+var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`)
+
+type soraPreflightChecker interface {
+ PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error
+}
+
func NewSoraGatewayService(
soraClient SoraClient,
mediaStorage *SoraMediaStorage,
@@ -112,29 +147,133 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream)
return nil, fmt.Errorf("unsupported model: %s", reqModel)
}
- if modelCfg.Type == "prompt_enhance" {
- s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream)
- return nil, fmt.Errorf("prompt-enhance not supported")
- }
-
prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody)
- if strings.TrimSpace(prompt) == "" {
+ prompt = strings.TrimSpace(prompt)
+ imageInput = strings.TrimSpace(imageInput)
+ videoInput = strings.TrimSpace(videoInput)
+ remixTargetID = strings.TrimSpace(remixTargetID)
+
+ if videoInput != "" && modelCfg.Type != "video" {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream)
+ return nil, errors.New("video input only supports video models")
+ }
+ if videoInput != "" && imageInput != "" {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream)
+ return nil, errors.New("image input and video input cannot be used together")
+ }
+ characterOnly := videoInput != "" && prompt == ""
+ if modelCfg.Type == "prompt_enhance" && prompt == "" {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
return nil, errors.New("prompt is required")
}
- if strings.TrimSpace(videoInput) != "" {
- s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream)
- return nil, errors.New("video input not supported")
+ if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream)
+ return nil, errors.New("prompt is required")
}
reqCtx, cancel := s.withSoraTimeout(ctx, reqStream)
if cancel != nil {
defer cancel()
}
+ if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly {
+ if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil {
+ return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
+ }
+ }
+
+ if modelCfg.Type == "prompt_enhance" {
+ enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS)
+ if err != nil {
+ return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream)
+ }
+ content := strings.TrimSpace(enhancedPrompt)
+ if content == "" {
+ content = prompt
+ }
+ var firstTokenMs *int
+ if clientStream {
+ ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
+ if streamErr != nil {
+ return nil, streamErr
+ }
+ firstTokenMs = ms
+ } else if c != nil {
+ c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel))
+ }
+ return &ForwardResult{
+ RequestID: "",
+ Model: reqModel,
+ Stream: clientStream,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ Usage: ClaudeUsage{},
+ MediaType: "prompt",
+ }, nil
+ }
+
+ characterOpts := parseSoraCharacterOptions(reqBody)
+ watermarkOpts := parseSoraWatermarkOptions(reqBody)
+ var characterResult *soraCharacterFlowResult
+ if videoInput != "" {
+ videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput)
+ if videoErr != nil {
+ s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream)
+ return nil, videoErr
+ }
+ characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts)
+ if videoErr != nil {
+ return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream)
+ }
+ if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly {
+ characterID := strings.TrimSpace(characterResult.CharacterID)
+ defer func() {
+ cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second)
+ defer cancelCleanup()
+ if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil {
+ log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err)
+ }
+ }()
+ }
+ if characterOnly {
+ content := "角色创建成功"
+ if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
+ content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username))
+ }
+ var firstTokenMs *int
+ if clientStream {
+ ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime)
+ if streamErr != nil {
+ return nil, streamErr
+ }
+ firstTokenMs = ms
+ } else if c != nil {
+ resp := buildSoraNonStreamResponse(content, reqModel)
+ if characterResult != nil {
+ resp["character_id"] = characterResult.CharacterID
+ resp["cameo_id"] = characterResult.CameoID
+ resp["character_username"] = characterResult.Username
+ resp["character_display_name"] = characterResult.DisplayName
+ }
+ c.JSON(http.StatusOK, resp)
+ }
+ return &ForwardResult{
+ RequestID: "",
+ Model: reqModel,
+ Stream: clientStream,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ Usage: ClaudeUsage{},
+ MediaType: "prompt",
+ }, nil
+ }
+ if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" {
+ prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt)
+ }
+ }
var imageData []byte
imageFilename := ""
- if strings.TrimSpace(imageInput) != "" {
+ if imageInput != "" {
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput)
if err != nil {
s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream)
@@ -164,15 +303,27 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
MediaID: mediaID,
})
case "video":
- taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
- Prompt: prompt,
- Orientation: modelCfg.Orientation,
- Frames: modelCfg.Frames,
- Model: modelCfg.Model,
- Size: modelCfg.Size,
- MediaID: mediaID,
- RemixTargetID: remixTargetID,
- })
+ if remixTargetID == "" && isSoraStoryboardPrompt(prompt) {
+ taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{
+ Prompt: formatSoraStoryboardPrompt(prompt),
+ Orientation: modelCfg.Orientation,
+ Frames: modelCfg.Frames,
+ Model: modelCfg.Model,
+ Size: modelCfg.Size,
+ MediaID: mediaID,
+ })
+ } else {
+ taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{
+ Prompt: prompt,
+ Orientation: modelCfg.Orientation,
+ Frames: modelCfg.Frames,
+ Model: modelCfg.Model,
+ Size: modelCfg.Size,
+ MediaID: mediaID,
+ RemixTargetID: remixTargetID,
+ CameoIDs: extractSoraCameoIDs(reqBody),
+ })
+ }
default:
err = fmt.Errorf("unsupported model type: %s", modelCfg.Type)
}
@@ -185,6 +336,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
}
var mediaURLs []string
+ videoGenerationID := ""
mediaType := modelCfg.Type
imageCount := 0
imageSize := ""
@@ -198,15 +350,32 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
imageCount = len(urls)
imageSize = soraImageSizeFromModel(reqModel)
case "video":
- urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream)
+ videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream)
if pollErr != nil {
return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream)
}
- mediaURLs = urls
+ if videoStatus != nil {
+ mediaURLs = videoStatus.URLs
+ videoGenerationID = strings.TrimSpace(videoStatus.GenerationID)
+ }
default:
mediaType = "prompt"
}
+ watermarkPostID := ""
+ if modelCfg.Type == "video" && watermarkOpts.Enabled {
+ watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts)
+ if watermarkErr != nil {
+ if !watermarkOpts.FallbackOnFailure {
+ return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream)
+ }
+ log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr)
+ } else if strings.TrimSpace(watermarkURL) != "" {
+ mediaURLs = []string{strings.TrimSpace(watermarkURL)}
+ watermarkPostID = strings.TrimSpace(postID)
+ }
+ }
+
finalURLs := s.normalizeSoraMediaURLs(mediaURLs)
if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() {
stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs)
@@ -217,6 +386,11 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun
finalURLs = s.normalizeSoraMediaURLs(stored)
}
}
+ if watermarkPostID != "" && watermarkOpts.DeletePost {
+ if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil {
+ log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr)
+ }
+ }
content := buildSoraContent(mediaType, finalURLs)
var firstTokenMs *int
@@ -265,9 +439,270 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (
return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second)
}
+func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions {
+ opts := soraWatermarkOptions{
+ Enabled: parseBoolWithDefault(body, "watermark_free", false),
+ ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))),
+ ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")),
+ ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")),
+ FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true),
+ DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false),
+ }
+ if opts.ParseMethod == "" {
+ opts.ParseMethod = "third_party"
+ }
+ return opts
+}
+
+func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions {
+ return soraCharacterOptions{
+ SetPublic: parseBoolWithDefault(body, "character_set_public", true),
+ DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true),
+ }
+}
+
+func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
+ if body == nil {
+ return def
+ }
+ val, ok := body[key]
+ if !ok {
+ return def
+ }
+ switch typed := val.(type) {
+ case bool:
+ return typed
+ case int:
+ return typed != 0
+ case int32:
+ return typed != 0
+ case int64:
+ return typed != 0
+ case float64:
+ return typed != 0
+ case string:
+ typed = strings.ToLower(strings.TrimSpace(typed))
+ if typed == "true" || typed == "1" || typed == "yes" {
+ return true
+ }
+ if typed == "false" || typed == "0" || typed == "no" {
+ return false
+ }
+ }
+ return def
+}
+
+func parseStringWithDefault(body map[string]any, key, def string) string {
+ if body == nil {
+ return def
+ }
+ val, ok := body[key]
+ if !ok {
+ return def
+ }
+ if str, ok := val.(string); ok {
+ return str
+ }
+ return def
+}
+
+func extractSoraCameoIDs(body map[string]any) []string {
+ if body == nil {
+ return nil
+ }
+ raw, ok := body["cameo_ids"]
+ if !ok {
+ return nil
+ }
+ switch typed := raw.(type) {
+ case []string:
+ out := make([]string, 0, len(typed))
+ for _, item := range typed {
+ item = strings.TrimSpace(item)
+ if item != "" {
+ out = append(out, item)
+ }
+ }
+ return out
+ case []any:
+ out := make([]string, 0, len(typed))
+ for _, item := range typed {
+ str, ok := item.(string)
+ if !ok {
+ continue
+ }
+ str = strings.TrimSpace(str)
+ if str != "" {
+ out = append(out, str)
+ }
+ }
+ return out
+ default:
+ return nil
+ }
+}
+
+func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) {
+ cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData)
+ if err != nil {
+ return nil, err
+ }
+
+ cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID)
+ if err != nil {
+ return nil, err
+ }
+ username := processSoraCharacterUsername(cameoStatus.UsernameHint)
+ displayName := strings.TrimSpace(cameoStatus.DisplayNameHint)
+ if displayName == "" {
+ displayName = "Character"
+ }
+ profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL)
+ if profileAssetURL == "" {
+ return nil, errors.New("profile asset url not found in cameo status")
+ }
+
+ avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL)
+ if err != nil {
+ return nil, err
+ }
+ assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData)
+ if err != nil {
+ return nil, err
+ }
+ instructionSet := cameoStatus.InstructionSetHint
+ if instructionSet == nil {
+ instructionSet = cameoStatus.InstructionSet
+ }
+
+ characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{
+ CameoID: strings.TrimSpace(cameoID),
+ Username: username,
+ DisplayName: displayName,
+ ProfileAssetPointer: assetPointer,
+ InstructionSet: instructionSet,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ if opts.SetPublic {
+ if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil {
+ return nil, err
+ }
+ }
+
+ return &soraCharacterFlowResult{
+ CameoID: strings.TrimSpace(cameoID),
+ CharacterID: strings.TrimSpace(characterID),
+ Username: strings.TrimSpace(username),
+ DisplayName: displayName,
+ }, nil
+}
+
+func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
+ timeout := 10 * time.Minute
+ interval := 5 * time.Second
+ maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds()))
+ if maxAttempts < 1 {
+ maxAttempts = 1
+ }
+
+ var lastErr error
+ consecutiveErrors := 0
+ for attempt := 0; attempt < maxAttempts; attempt++ {
+ status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID)
+ if err != nil {
+ lastErr = err
+ consecutiveErrors++
+ if consecutiveErrors >= 3 {
+ break
+ }
+ if attempt < maxAttempts-1 {
+ if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
+ return nil, sleepErr
+ }
+ }
+ continue
+ }
+ consecutiveErrors = 0
+ if status == nil {
+ if attempt < maxAttempts-1 {
+ if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
+ return nil, sleepErr
+ }
+ }
+ continue
+ }
+ currentStatus := strings.ToLower(strings.TrimSpace(status.Status))
+ statusMessage := strings.TrimSpace(status.StatusMessage)
+ if currentStatus == "failed" {
+ if statusMessage == "" {
+ statusMessage = "character creation failed"
+ }
+ return nil, errors.New(statusMessage)
+ }
+ if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" {
+ return status, nil
+ }
+ if attempt < maxAttempts-1 {
+ if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
+ return nil, sleepErr
+ }
+ }
+ }
+ if lastErr != nil {
+ return nil, fmt.Errorf("poll cameo status failed: %w", lastErr)
+ }
+ return nil, errors.New("cameo processing timeout")
+}
+
+func processSoraCharacterUsername(usernameHint string) string {
+ usernameHint = strings.TrimSpace(usernameHint)
+ if usernameHint == "" {
+ usernameHint = "character"
+ }
+ if strings.Contains(usernameHint, ".") {
+ parts := strings.Split(usernameHint, ".")
+ usernameHint = strings.TrimSpace(parts[len(parts)-1])
+ }
+ if usernameHint == "" {
+ usernameHint = "character"
+ }
+ return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100)
+}
+
+func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) {
+ generationID = strings.TrimSpace(generationID)
+ if generationID == "" {
+ return "", "", errors.New("generation id is required for watermark-free mode")
+ }
+ postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID)
+ if err != nil {
+ return "", "", err
+ }
+ postID = strings.TrimSpace(postID)
+ if postID == "" {
+ return "", "", errors.New("watermark-free publish returned empty post id")
+ }
+
+ switch opts.ParseMethod {
+ case "custom":
+ urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID)
+ if parseErr != nil {
+ return "", postID, parseErr
+ }
+ return strings.TrimSpace(urlVal), postID, nil
+ case "", "third_party":
+ return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil
+ default:
+ return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod)
+ }
+}
+
func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
- case 401, 402, 403, 429, 529:
+ case 401, 402, 403, 404, 429, 529:
return true
default:
return statusCode >= 500
@@ -434,7 +869,18 @@ func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType,
}
if stream {
flusher, _ := c.Writer.(http.Flusher)
- errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
+ errorData := map[string]any{
+ "error": map[string]string{
+ "type": errType,
+ "message": message,
+ },
+ }
+ jsonBytes, err := json.Marshal(errorData)
+ if err != nil {
+ _ = c.Error(err)
+ return
+ }
+ errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
_, _ = fmt.Fprint(c.Writer, errorEvent)
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
if flusher != nil {
@@ -460,7 +906,15 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account
s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body)
}
if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) {
- return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode}
+ var responseHeaders http.Header
+ if upstreamErr.Headers != nil {
+ responseHeaders = upstreamErr.Headers.Clone()
+ }
+ return &UpstreamFailoverError{
+ StatusCode: upstreamErr.StatusCode,
+ ResponseBody: upstreamErr.Body,
+ ResponseHeaders: responseHeaders,
+ }
}
msg := upstreamErr.Message
if override := soraProErrorMessage(model, msg); override != "" {
@@ -505,7 +959,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context,
return nil, errors.New("sora image generation timeout")
}
-func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) {
+func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) {
interval := s.pollInterval()
maxAttempts := s.pollMaxAttempts()
lastPing := time.Now()
@@ -516,7 +970,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context,
}
switch strings.ToLower(status.Status) {
case "completed", "succeeded":
- return status.URLs, nil
+ return status, nil
case "failed":
if status.ErrorMsg != "" {
return nil, errors.New(status.ErrorMsg)
@@ -620,7 +1074,7 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
return "", "", "", ""
}
if v, ok := body["remix_target_id"].(string); ok {
- remixTargetID = v
+ remixTargetID = strings.TrimSpace(v)
}
if v, ok := body["image"].(string); ok {
imageInput = v
@@ -661,6 +1115,10 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi
prompt = builder.String()
}
}
+ if remixTargetID == "" {
+ remixTargetID = extractRemixTargetIDFromPrompt(prompt)
+ }
+ prompt = cleanRemixLinkFromPrompt(prompt)
return prompt, imageInput, videoInput, remixTargetID
}
@@ -708,6 +1166,69 @@ func parseSoraMessageContent(content any) (text, imageInput, videoInput string)
}
}
+func isSoraStoryboardPrompt(prompt string) bool {
+ prompt = strings.TrimSpace(prompt)
+ if prompt == "" {
+ return false
+ }
+ return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1
+}
+
+func formatSoraStoryboardPrompt(prompt string) string {
+ prompt = strings.TrimSpace(prompt)
+ if prompt == "" {
+ return ""
+ }
+ matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1)
+ if len(matches) == 0 {
+ return prompt
+ }
+ firstBracketPos := strings.Index(prompt, "[")
+ instructions := ""
+ if firstBracketPos > 0 {
+ instructions = strings.TrimSpace(prompt[:firstBracketPos])
+ }
+ shots := make([]string, 0, len(matches))
+ for i, match := range matches {
+ if len(match) < 3 {
+ continue
+ }
+ duration := strings.TrimSpace(match[1])
+ scene := strings.TrimSpace(match[2])
+ if scene == "" {
+ continue
+ }
+ shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene))
+ }
+ if len(shots) == 0 {
+ return prompt
+ }
+ timeline := strings.Join(shots, "\n\n")
+ if instructions == "" {
+ return timeline
+ }
+ return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions)
+}
+
+func extractRemixTargetIDFromPrompt(prompt string) string {
+ prompt = strings.TrimSpace(prompt)
+ if prompt == "" {
+ return ""
+ }
+ return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt))
+}
+
+func cleanRemixLinkFromPrompt(prompt string) string {
+ prompt = strings.TrimSpace(prompt)
+ if prompt == "" {
+ return prompt
+ }
+ cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "")
+ cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "")
+ cleaned = strings.Join(strings.Fields(cleaned), " ")
+ return strings.TrimSpace(cleaned)
+}
+
func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) {
raw := strings.TrimSpace(input)
if raw == "" {
@@ -720,7 +1241,7 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
}
meta := parts[0]
payload := parts[1]
- decoded, err := base64.StdEncoding.DecodeString(payload)
+ decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes)
if err != nil {
return nil, "", err
}
@@ -739,15 +1260,47 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er
if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
return downloadSoraImageInput(ctx, raw)
}
- decoded, err := base64.StdEncoding.DecodeString(raw)
+ decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes)
if err != nil {
return nil, "", errors.New("invalid base64 image")
}
return decoded, "image.png", nil
}
+func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) {
+ raw := strings.TrimSpace(input)
+ if raw == "" {
+ return nil, errors.New("empty video input")
+ }
+ if strings.HasPrefix(raw, "data:") {
+ parts := strings.SplitN(raw, ",", 2)
+ if len(parts) != 2 {
+ return nil, errors.New("invalid video data url")
+ }
+ decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes)
+ if err != nil {
+ return nil, errors.New("invalid base64 video")
+ }
+ if len(decoded) == 0 {
+ return nil, errors.New("empty video data")
+ }
+ return decoded, nil
+ }
+ if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
+ return downloadSoraVideoInput(ctx, raw)
+ }
+ decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes)
+ if err != nil {
+ return nil, errors.New("invalid base64 video")
+ }
+ if len(decoded) == 0 {
+ return nil, errors.New("empty video data")
+ }
+ return decoded, nil
+}
+
func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) {
- parsed, err := validateSoraImageURL(rawURL)
+ parsed, err := validateSoraRemoteURL(rawURL)
if err != nil {
return nil, "", err
}
@@ -761,7 +1314,7 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
if len(via) >= soraImageInputMaxRedirects {
return errors.New("too many redirects")
}
- return validateSoraImageURLValue(req.URL)
+ return validateSoraRemoteURLValue(req.URL)
},
}
resp, err := client.Do(req)
@@ -784,51 +1337,103 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string,
return data, filename, nil
}
-func validateSoraImageURL(raw string) (*url.URL, error) {
+func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) {
+ parsed, err := validateSoraRemoteURL(rawURL)
+ if err != nil {
+ return nil, err
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil)
+ if err != nil {
+ return nil, err
+ }
+ client := &http.Client{
+ Timeout: soraVideoInputTimeout,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ if len(via) >= soraVideoInputMaxRedirects {
+ return errors.New("too many redirects")
+ }
+ return validateSoraRemoteURLValue(req.URL)
+ },
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("download video failed: %d", resp.StatusCode)
+ }
+ data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes))
+ if err != nil {
+ return nil, err
+ }
+ if len(data) == 0 {
+ return nil, errors.New("empty video content")
+ }
+ return data, nil
+}
+
+func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) {
+ if maxBytes <= 0 {
+ return nil, errors.New("invalid max bytes limit")
+ }
+ decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
+ limited := io.LimitReader(decoder, maxBytes+1)
+ data, err := io.ReadAll(limited)
+ if err != nil {
+ return nil, err
+ }
+ if int64(len(data)) > maxBytes {
+ return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes)
+ }
+ return data, nil
+}
+
+func validateSoraRemoteURL(raw string) (*url.URL, error) {
if strings.TrimSpace(raw) == "" {
- return nil, errors.New("empty image url")
+ return nil, errors.New("empty remote url")
}
parsed, err := url.Parse(raw)
if err != nil {
- return nil, fmt.Errorf("invalid image url: %w", err)
+ return nil, fmt.Errorf("invalid remote url: %w", err)
}
- if err := validateSoraImageURLValue(parsed); err != nil {
+ if err := validateSoraRemoteURLValue(parsed); err != nil {
return nil, err
}
return parsed, nil
}
-func validateSoraImageURLValue(parsed *url.URL) error {
+func validateSoraRemoteURLValue(parsed *url.URL) error {
if parsed == nil {
- return errors.New("invalid image url")
+ return errors.New("invalid remote url")
}
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
if scheme != "http" && scheme != "https" {
- return errors.New("only http/https image url is allowed")
+ return errors.New("only http/https remote url is allowed")
}
if parsed.User != nil {
- return errors.New("image url cannot contain userinfo")
+ return errors.New("remote url cannot contain userinfo")
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
- return errors.New("image url missing host")
+ return errors.New("remote url missing host")
}
if _, blocked := soraBlockedHostnames[host]; blocked {
- return errors.New("image url is not allowed")
+ return errors.New("remote url is not allowed")
}
if ip := net.ParseIP(host); ip != nil {
if isSoraBlockedIP(ip) {
- return errors.New("image url is not allowed")
+ return errors.New("remote url is not allowed")
}
return nil
}
ips, err := net.LookupIP(host)
if err != nil {
- return fmt.Errorf("resolve image url failed: %w", err)
+ return fmt.Errorf("resolve remote url failed: %w", err)
}
for _, ip := range ips {
if isSoraBlockedIP(ip) {
- return errors.New("image url is not allowed")
+ return errors.New("remote url is not allowed")
}
}
return nil
diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go
index d6bf9eae..5888fe92 100644
--- a/backend/internal/service/sora_gateway_service_test.go
+++ b/backend/internal/service/sora_gateway_service_test.go
@@ -4,10 +4,16 @@ package service
import (
"context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -18,6 +24,13 @@ type stubSoraClientForPoll struct {
videoStatus *SoraVideoTaskStatus
imageCalls int
videoCalls int
+ enhanced string
+ enhanceErr error
+ storyboard bool
+ videoReq SoraVideoRequest
+ parseErr error
+ postCalls int
+ deleteCalls int
}
func (s *stubSoraClientForPoll) Enabled() bool { return true }
@@ -28,8 +41,60 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac
return "task-image", nil
}
func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) {
+ s.videoReq = req
return "task-video", nil
}
+func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) {
+ s.storyboard = true
+ return "task-video", nil
+}
+func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) {
+ return "cameo-1", nil
+}
+func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
+ return &SoraCameoStatus{
+ Status: "finalized",
+ StatusMessage: "Completed",
+ DisplayNameHint: "Character",
+ UsernameHint: "user.character",
+ ProfileAssetURL: "https://example.com/avatar.webp",
+ }, nil
+}
+func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) {
+ return []byte("avatar"), nil
+}
+func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) {
+ return "asset-pointer", nil
+}
+func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) {
+ return "character-1", nil
+}
+func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error {
+ return nil
+}
+func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error {
+ return nil
+}
+func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) {
+ s.postCalls++
+ return "s_post", nil
+}
+func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error {
+ s.deleteCalls++
+ return nil
+}
+func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) {
+ if s.parseErr != nil {
+ return "", s.parseErr
+ }
+ return "https://example.com/no-watermark.mp4", nil
+}
+func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) {
+ if s.enhanced != "" {
+ return s.enhanced, s.enhanceErr
+ }
+ return "enhanced prompt", s.enhanceErr
+}
func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) {
s.imageCalls++
return s.imageStatus, nil
@@ -62,6 +127,136 @@ func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) {
require.Equal(t, 1, client.imageCalls)
}
+func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ enhanced: "cinematic prompt",
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{
+ ID: 1,
+ Platform: PlatformSora,
+ Status: StatusActive,
+ }
+ body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "prompt", result.MediaType)
+ require.Equal(t, "prompt-enhance-short-10s", result.Model)
+}
+
+func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ videoStatus: &SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/v.mp4"},
+ },
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.True(t, client.storyboard)
+}
+
+func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) {
+ client := &stubSoraClientForPoll{}
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "prompt", result.MediaType)
+ require.Equal(t, 0, client.videoCalls)
+}
+
+func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ videoStatus: &SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/original.mp4"},
+ GenerationID: "gen_1",
+ },
+ parseErr: errors.New("parse failed"),
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "https://example.com/original.mp4", result.MediaURL)
+ require.Equal(t, 1, client.postCalls)
+ require.Equal(t, 0, client.deleteCalls)
+}
+
+func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) {
+ client := &stubSoraClientForPoll{
+ videoStatus: &SoraVideoTaskStatus{
+ Status: "completed",
+ URLs: []string{"https://example.com/original.mp4"},
+ GenerationID: "gen_1",
+ },
+ }
+ cfg := &config.Config{
+ Sora: config.SoraConfig{
+ Client: config.SoraClientConfig{
+ PollIntervalSeconds: 1,
+ MaxPollAttempts: 1,
+ },
+ },
+ }
+ svc := NewSoraGatewayService(client, nil, nil, cfg)
+ account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive}
+ body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`)
+
+ result, err := svc.Forward(context.Background(), nil, account, body, false)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL)
+ require.Equal(t, 1, client.postCalls)
+ require.Equal(t, 1, client.deleteCalls)
+}
+
func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
client := &stubSoraClientForPoll{
videoStatus: &SoraVideoTaskStatus{
@@ -79,9 +274,9 @@ func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) {
}
service := NewSoraGatewayService(client, nil, nil, cfg)
- urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false)
+ status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false)
require.Error(t, err)
- require.Empty(t, urls)
+ require.Nil(t, status)
require.Contains(t, err.Error(), "reject")
require.Equal(t, 1, client.videoCalls)
}
@@ -175,9 +370,65 @@ func TestSoraProErrorMessage(t *testing.T) {
require.Empty(t, soraProErrorMessage("sora-basic", ""))
}
+func TestSoraGatewayService_WriteSoraError_StreamEscapesJSON(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
+ svc.writeSoraError(c, http.StatusBadGateway, "upstream_error", "invalid \"prompt\"\nline2", true)
+
+ body := rec.Body.String()
+ require.Contains(t, body, "event: error\n")
+ require.Contains(t, body, "data: [DONE]\n\n")
+
+ lines := strings.Split(body, "\n")
+ require.GreaterOrEqual(t, len(lines), 2)
+ require.Equal(t, "event: error", lines[0])
+ require.True(t, strings.HasPrefix(lines[1], "data: "))
+
+ data := strings.TrimPrefix(lines[1], "data: ")
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal([]byte(data), &parsed))
+ errObj, ok := parsed["error"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "upstream_error", errObj["type"])
+ require.Equal(t, "invalid \"prompt\"\nline2", errObj["message"])
+}
+
+func TestSoraGatewayService_HandleSoraRequestError_FailoverHeadersCloned(t *testing.T) {
+ svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
+ sourceHeaders := http.Header{}
+ sourceHeaders.Set("cf-ray", "9d01b0e9ecc35829-SEA")
+
+ err := svc.handleSoraRequestError(
+ context.Background(),
+ &Account{ID: 1, Platform: PlatformSora},
+ &SoraUpstreamError{
+ StatusCode: http.StatusForbidden,
+ Message: "forbidden",
+ Headers: sourceHeaders,
+ Body: []byte(`Just a moment...`),
+ },
+ "sora2-landscape-10s",
+ nil,
+ false,
+ )
+
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.NotNil(t, failoverErr.ResponseHeaders)
+ require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
+
+ sourceHeaders.Set("cf-ray", "mutated-after-return")
+ require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray"))
+}
+
func TestShouldFailoverUpstreamError(t *testing.T) {
svc := NewSoraGatewayService(nil, nil, nil, &config.Config{})
require.True(t, svc.shouldFailoverUpstreamError(401))
+ require.True(t, svc.shouldFailoverUpstreamError(404))
require.True(t, svc.shouldFailoverUpstreamError(429))
require.True(t, svc.shouldFailoverUpstreamError(500))
require.True(t, svc.shouldFailoverUpstreamError(502))
@@ -257,3 +508,19 @@ func TestDecodeSoraImageInput_DataURL(t *testing.T) {
require.NotEmpty(t, data)
require.Contains(t, filename, ".png")
}
+
+func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) {
+ data, err := decodeBase64WithLimit("aGVsbG8=", 3)
+ require.Error(t, err)
+ require.Nil(t, data)
+}
+
+func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) {
+ body := map[string]any{
+ "watermark_free": float64(1),
+ "watermark_fallback_on_failure": float64(0),
+ }
+ opts := parseSoraWatermarkOptions(body)
+ require.True(t, opts.Enabled)
+ require.False(t, opts.FallbackOnFailure)
+}
diff --git a/backend/internal/service/sora_models.go b/backend/internal/service/sora_models.go
index ab095e46..80b20a4b 100644
--- a/backend/internal/service/sora_models.go
+++ b/backend/internal/service/sora_models.go
@@ -17,6 +17,9 @@ type SoraModelConfig struct {
Model string
Size string
RequirePro bool
+ // Prompt-enhance 专用参数
+ ExpansionLevel string
+ DurationS int
}
var soraModelConfigs = map[string]SoraModelConfig{
@@ -160,31 +163,49 @@ var soraModelConfigs = map[string]SoraModelConfig{
RequirePro: true,
},
"prompt-enhance-short-10s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "short",
+ DurationS: 10,
},
"prompt-enhance-short-15s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "short",
+ DurationS: 15,
},
"prompt-enhance-short-20s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "short",
+ DurationS: 20,
},
"prompt-enhance-medium-10s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "medium",
+ DurationS: 10,
},
"prompt-enhance-medium-15s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "medium",
+ DurationS: 15,
},
"prompt-enhance-medium-20s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "medium",
+ DurationS: 20,
},
"prompt-enhance-long-10s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "long",
+ DurationS: 10,
},
"prompt-enhance-long-15s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "long",
+ DurationS: 15,
},
"prompt-enhance-long-20s": {
- Type: "prompt_enhance",
+ Type: "prompt_enhance",
+ ExpansionLevel: "long",
+ DurationS: 20,
},
}
diff --git a/backend/internal/service/sora_request_guard.go b/backend/internal/service/sora_request_guard.go
new file mode 100644
index 00000000..a118fe82
--- /dev/null
+++ b/backend/internal/service/sora_request_guard.go
@@ -0,0 +1,266 @@
+package service
+
+import (
+ "fmt"
+ "math"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
+ "github.com/google/uuid"
+)
+
+type soraChallengeCooldownEntry struct {
+ Until time.Time
+ StatusCode int
+ CFRay string
+ ConsecutiveChallenges int
+ LastChallengeAt time.Time
+}
+
+type soraSidecarSessionEntry struct {
+ SessionKey string
+ ExpiresAt time.Time
+ LastUsedAt time.Time
+}
+
+func (c *SoraDirectClient) cloudflareChallengeCooldownSeconds() int {
+ if c == nil || c.cfg == nil {
+ return 900
+ }
+ cooldown := c.cfg.Sora.Client.CloudflareChallengeCooldownSeconds
+ if cooldown <= 0 {
+ return 0
+ }
+ return cooldown
+}
+
+func (c *SoraDirectClient) checkCloudflareChallengeCooldown(account *Account, proxyURL string) error {
+ if c == nil {
+ return nil
+ }
+ if account == nil || account.ID <= 0 {
+ return nil
+ }
+ cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
+ if cooldownSeconds <= 0 {
+ return nil
+ }
+ key := soraAccountProxyKey(account, proxyURL)
+ now := time.Now()
+
+ c.challengeCooldownMu.RLock()
+ entry, ok := c.challengeCooldowns[key]
+ c.challengeCooldownMu.RUnlock()
+ if !ok {
+ return nil
+ }
+ if !entry.Until.After(now) {
+ c.challengeCooldownMu.Lock()
+ delete(c.challengeCooldowns, key)
+ c.challengeCooldownMu.Unlock()
+ return nil
+ }
+
+ remaining := int(math.Ceil(entry.Until.Sub(now).Seconds()))
+ if remaining < 1 {
+ remaining = 1
+ }
+ message := fmt.Sprintf("Sora request cooling down due to recent Cloudflare challenge. Retry in %d seconds.", remaining)
+ if entry.ConsecutiveChallenges > 1 {
+ message = fmt.Sprintf("%s (streak=%d)", message, entry.ConsecutiveChallenges)
+ }
+ if entry.CFRay != "" {
+ message = fmt.Sprintf("%s (last cf-ray: %s)", message, entry.CFRay)
+ }
+ return &SoraUpstreamError{
+ StatusCode: http.StatusTooManyRequests,
+ Message: message,
+ Headers: make(http.Header),
+ }
+}
+
+func (c *SoraDirectClient) recordCloudflareChallengeCooldown(account *Account, proxyURL string, statusCode int, headers http.Header, body []byte) {
+ if c == nil {
+ return
+ }
+ if account == nil || account.ID <= 0 {
+ return
+ }
+ cooldownSeconds := c.cloudflareChallengeCooldownSeconds()
+ if cooldownSeconds <= 0 {
+ return
+ }
+ key := soraAccountProxyKey(account, proxyURL)
+ now := time.Now()
+ cfRay := soraerror.ExtractCloudflareRayID(headers, body)
+
+ c.challengeCooldownMu.Lock()
+ c.cleanupExpiredChallengeCooldownsLocked(now)
+
+ streak := 1
+ existing, ok := c.challengeCooldowns[key]
+ if ok && now.Sub(existing.LastChallengeAt) <= 30*time.Minute {
+ streak = existing.ConsecutiveChallenges + 1
+ }
+ effectiveCooldown := soraComputeChallengeCooldownSeconds(cooldownSeconds, streak)
+ until := now.Add(time.Duration(effectiveCooldown) * time.Second)
+ if ok && existing.Until.After(until) {
+ until = existing.Until
+ if existing.ConsecutiveChallenges > streak {
+ streak = existing.ConsecutiveChallenges
+ }
+ if cfRay == "" {
+ cfRay = existing.CFRay
+ }
+ }
+ c.challengeCooldowns[key] = soraChallengeCooldownEntry{
+ Until: until,
+ StatusCode: statusCode,
+ CFRay: cfRay,
+ ConsecutiveChallenges: streak,
+ LastChallengeAt: now,
+ }
+ c.challengeCooldownMu.Unlock()
+
+ if c.debugEnabled() {
+ remain := int(math.Ceil(until.Sub(now).Seconds()))
+ if remain < 0 {
+ remain = 0
+ }
+ c.debugLogf("cloudflare_challenge_cooldown_set key=%s status=%d remain_s=%d streak=%d cf_ray=%s", key, statusCode, remain, streak, cfRay)
+ }
+}
+
+func soraComputeChallengeCooldownSeconds(baseSeconds, streak int) int {
+ if baseSeconds <= 0 {
+ return 0
+ }
+ if streak < 1 {
+ streak = 1
+ }
+ multiplier := streak
+ if multiplier > 4 {
+ multiplier = 4
+ }
+ cooldown := baseSeconds * multiplier
+ if cooldown > 3600 {
+ cooldown = 3600
+ }
+ return cooldown
+}
+
+func (c *SoraDirectClient) clearCloudflareChallengeCooldown(account *Account, proxyURL string) {
+ if c == nil {
+ return
+ }
+ if account == nil || account.ID <= 0 {
+ return
+ }
+ key := soraAccountProxyKey(account, proxyURL)
+ c.challengeCooldownMu.Lock()
+ _, existed := c.challengeCooldowns[key]
+ if existed {
+ delete(c.challengeCooldowns, key)
+ }
+ c.challengeCooldownMu.Unlock()
+ if existed && c.debugEnabled() {
+ c.debugLogf("cloudflare_challenge_cooldown_cleared key=%s", key)
+ }
+}
+
+func (c *SoraDirectClient) sidecarSessionKey(account *Account, proxyURL string) string {
+ if c == nil || !c.sidecarSessionReuseEnabled() {
+ return ""
+ }
+ if account == nil || account.ID <= 0 {
+ return ""
+ }
+ key := soraAccountProxyKey(account, proxyURL)
+ now := time.Now()
+ ttlSeconds := c.sidecarSessionTTLSeconds()
+
+ c.sidecarSessionMu.Lock()
+ defer c.sidecarSessionMu.Unlock()
+ c.cleanupExpiredSidecarSessionsLocked(now)
+ if existing, exists := c.sidecarSessions[key]; exists {
+ existing.LastUsedAt = now
+ c.sidecarSessions[key] = existing
+ return existing.SessionKey
+ }
+
+ expiresAt := now.Add(time.Duration(ttlSeconds) * time.Second)
+ if ttlSeconds <= 0 {
+ expiresAt = now.Add(365 * 24 * time.Hour)
+ }
+ newEntry := soraSidecarSessionEntry{
+ SessionKey: "sora-" + uuid.NewString(),
+ ExpiresAt: expiresAt,
+ LastUsedAt: now,
+ }
+ c.sidecarSessions[key] = newEntry
+
+ if c.debugEnabled() {
+ c.debugLogf("sidecar_session_created key=%s ttl_s=%d", key, ttlSeconds)
+ }
+ return newEntry.SessionKey
+}
+
+func (c *SoraDirectClient) cleanupExpiredChallengeCooldownsLocked(now time.Time) {
+ if c == nil || len(c.challengeCooldowns) == 0 {
+ return
+ }
+ for key, entry := range c.challengeCooldowns {
+ if !entry.Until.After(now) {
+ delete(c.challengeCooldowns, key)
+ }
+ }
+}
+
+func (c *SoraDirectClient) cleanupExpiredSidecarSessionsLocked(now time.Time) {
+ if c == nil || len(c.sidecarSessions) == 0 {
+ return
+ }
+ for key, entry := range c.sidecarSessions {
+ if !entry.ExpiresAt.After(now) {
+ delete(c.sidecarSessions, key)
+ }
+ }
+}
+
+func soraAccountProxyKey(account *Account, proxyURL string) string {
+ accountID := int64(0)
+ if account != nil {
+ accountID = account.ID
+ }
+ return fmt.Sprintf("account:%d|proxy:%s", accountID, normalizeSoraProxyKey(proxyURL))
+}
+
+func normalizeSoraProxyKey(proxyURL string) string {
+ raw := strings.TrimSpace(proxyURL)
+ if raw == "" {
+ return "direct"
+ }
+ parsed, err := url.Parse(raw)
+ if err != nil {
+ return strings.ToLower(raw)
+ }
+ scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
+ host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
+ port := strings.TrimSpace(parsed.Port())
+ if host == "" {
+ return strings.ToLower(raw)
+ }
+ if (scheme == "http" && port == "80") || (scheme == "https" && port == "443") {
+ port = ""
+ }
+ if port != "" {
+ host = host + ":" + port
+ }
+ if scheme == "" {
+ scheme = "proxy"
+ }
+ return scheme + "://" + host
+}
diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go
index 9de1c164..a37e0d0a 100644
--- a/backend/internal/service/token_refresh_service.go
+++ b/backend/internal/service/token_refresh_service.go
@@ -43,10 +43,13 @@ func NewTokenRefreshService(
stopCh: make(chan struct{}),
}
+ openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo)
+ openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts)
+
// 注册平台特定的刷新器
s.refreshers = []TokenRefresher{
NewClaudeTokenRefresher(oauthService),
- NewOpenAITokenRefresher(openaiOAuthService, accountRepo),
+ openAIRefresher,
NewGeminiTokenRefresher(geminiOAuthService),
NewAntigravityTokenRefresher(antigravityOAuthService),
}
diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go
index 46033f75..0dd3cf45 100644
--- a/backend/internal/service/token_refresher.go
+++ b/backend/internal/service/token_refresher.go
@@ -86,6 +86,7 @@ type OpenAITokenRefresher struct {
openaiOAuthService *OpenAIOAuthService
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
+ syncLinkedSora bool
}
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
@@ -103,11 +104,15 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) {
r.soraAccountRepo = repo
}
+// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。
+func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) {
+ r.syncLinkedSora = enabled
+}
+
// CanRefresh 检查是否能处理此账号
-// 只处理 openai 平台的 oauth 类型账号
+// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号)
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
- return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) &&
- account.Type == AccountTypeOAuth
+ return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
}
// NeedsRefresh 检查token是否需要刷新
@@ -141,7 +146,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
}
// 异步同步关联的 Sora 账号(不阻塞主流程)
- if r.accountRepo != nil {
+ if r.accountRepo != nil && r.syncLinkedSora {
go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
}
diff --git a/backend/internal/service/token_refresher_test.go b/backend/internal/service/token_refresher_test.go
index c7505037..264d7912 100644
--- a/backend/internal/service/token_refresher_test.go
+++ b/backend/internal/service/token_refresher_test.go
@@ -226,3 +226,43 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
})
}
}
+
+func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
+ refresher := &OpenAITokenRefresher{}
+
+ tests := []struct {
+ name string
+ platform string
+ accType string
+ want bool
+ }{
+ {
+ name: "openai oauth - can refresh",
+ platform: PlatformOpenAI,
+ accType: AccountTypeOAuth,
+ want: true,
+ },
+ {
+ name: "sora oauth - cannot refresh directly",
+ platform: PlatformSora,
+ accType: AccountTypeOAuth,
+ want: false,
+ },
+ {
+ name: "openai apikey - cannot refresh",
+ platform: PlatformOpenAI,
+ accType: AccountTypeAPIKey,
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ account := &Account{
+ Platform: tt.platform,
+ Type: tt.accType,
+ }
+ require.Equal(t, tt.want, refresher.CanRefresh(account))
+ })
+ }
+}
diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go
index b8808e13..f9824183 100644
--- a/backend/internal/service/usage_log.go
+++ b/backend/internal/service/usage_log.go
@@ -26,8 +26,8 @@ type UsageLog struct {
CacheCreationTokens int
CacheReadTokens int
- CacheCreation5mTokens int
- CacheCreation1hTokens int
+ CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"`
+ CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"`
InputCost float64
OutputCost float64
@@ -46,6 +46,9 @@ type UsageLog struct {
UserAgent *string
IPAddress *string
+ // Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
+ CacheTTLOverridden bool
+
// 图片生成字段
ImageCount int
ImageSize *string
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index 5d712f75..652f9e00 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -206,6 +206,18 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage {
return NewSoraMediaStorage(cfg)
}
+func ProvideSoraDirectClient(
+ cfg *config.Config,
+ httpUpstream HTTPUpstream,
+ tokenProvider *OpenAITokenProvider,
+ accountRepo AccountRepository,
+ soraAccountRepo SoraAccountRepository,
+) *SoraDirectClient {
+ client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider)
+ client.SetAccountRepositories(accountRepo, soraAccountRepo)
+ return client
+}
+
// ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务
func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService {
svc := NewSoraMediaCleanupService(storage, cfg)
@@ -255,7 +267,7 @@ var ProviderSet = wire.NewSet(
NewGatewayService,
ProvideSoraMediaStorage,
ProvideSoraMediaCleanupService,
- NewSoraDirectClient,
+ ProvideSoraDirectClient,
wire.Bind(new(SoraClient), new(*SoraDirectClient)),
NewSoraGatewayService,
NewOpenAIGatewayService,
diff --git a/backend/internal/util/soraerror/soraerror.go b/backend/internal/util/soraerror/soraerror.go
new file mode 100644
index 00000000..17712c10
--- /dev/null
+++ b/backend/internal/util/soraerror/soraerror.go
@@ -0,0 +1,170 @@
+package soraerror
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "regexp"
+ "strings"
+)
+
+var (
+ cfRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`)
+ cRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`)
+ htmlChallenge = []string{
+ "window._cf_chl_opt",
+ "just a moment",
+ "enable javascript and cookies to continue",
+ "__cf_chl_",
+ "challenge-platform",
+ }
+)
+
+// IsCloudflareChallengeResponse reports whether the upstream response matches Cloudflare challenge behavior.
+func IsCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
+ if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests {
+ return false
+ }
+
+ if headers != nil && strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") {
+ return true
+ }
+
+ preview := strings.ToLower(TruncateBody(body, 4096))
+ for _, marker := range htmlChallenge {
+ if strings.Contains(preview, marker) {
+ return true
+ }
+ }
+
+ contentType := ""
+ if headers != nil {
+ contentType = strings.ToLower(strings.TrimSpace(headers.Get("content-type")))
+ }
+ if strings.Contains(contentType, "text/html") &&
+ (strings.Contains(preview, "= 2 {
+ return strings.TrimSpace(matches[1])
+ }
+ if matches := cRayPattern.FindStringSubmatch(preview); len(matches) >= 2 {
+ return strings.TrimSpace(matches[1])
+ }
+ return ""
+}
+
+// FormatCloudflareChallengeMessage appends cf-ray info when available.
+func FormatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
+ rayID := ExtractCloudflareRayID(headers, body)
+ if rayID == "" {
+ return base
+ }
+ return fmt.Sprintf("%s (cf-ray: %s)", base, rayID)
+}
+
+// ExtractUpstreamErrorCodeAndMessage extracts structured error code/message from common JSON layouts.
+func ExtractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
+ trimmed := strings.TrimSpace(string(body))
+ if trimmed == "" {
+ return "", ""
+ }
+ if !json.Valid([]byte(trimmed)) {
+ return "", truncateMessage(trimmed, 256)
+ }
+
+ var payload map[string]any
+ if err := json.Unmarshal([]byte(trimmed), &payload); err != nil {
+ return "", truncateMessage(trimmed, 256)
+ }
+
+ code := firstNonEmpty(
+ extractNestedString(payload, "error", "code"),
+ extractRootString(payload, "code"),
+ )
+ message := firstNonEmpty(
+ extractNestedString(payload, "error", "message"),
+ extractRootString(payload, "message"),
+ extractNestedString(payload, "error", "detail"),
+ extractRootString(payload, "detail"),
+ )
+ return strings.TrimSpace(code), truncateMessage(strings.TrimSpace(message), 512)
+}
+
+// TruncateBody truncates body text for logging/inspection.
+func TruncateBody(body []byte, max int) string {
+ if max <= 0 {
+ max = 512
+ }
+ raw := strings.TrimSpace(string(body))
+ if len(raw) <= max {
+ return raw
+ }
+ return raw[:max] + "...(truncated)"
+}
+
+func truncateMessage(s string, max int) string {
+ if max <= 0 {
+ return ""
+ }
+ if len(s) <= max {
+ return s
+ }
+ return s[:max] + "...(truncated)"
+}
+
+func firstNonEmpty(values ...string) string {
+ for _, v := range values {
+ if strings.TrimSpace(v) != "" {
+ return v
+ }
+ }
+ return ""
+}
+
+func extractRootString(m map[string]any, key string) string {
+ if m == nil {
+ return ""
+ }
+ v, ok := m[key]
+ if !ok {
+ return ""
+ }
+ s, _ := v.(string)
+ return s
+}
+
+func extractNestedString(m map[string]any, parent, key string) string {
+ if m == nil {
+ return ""
+ }
+ node, ok := m[parent]
+ if !ok {
+ return ""
+ }
+ child, ok := node.(map[string]any)
+ if !ok {
+ return ""
+ }
+ s, _ := child[key].(string)
+ return s
+}
diff --git a/backend/internal/util/soraerror/soraerror_test.go b/backend/internal/util/soraerror/soraerror_test.go
new file mode 100644
index 00000000..4cf11169
--- /dev/null
+++ b/backend/internal/util/soraerror/soraerror_test.go
@@ -0,0 +1,47 @@
+package soraerror
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestIsCloudflareChallengeResponse(t *testing.T) {
+ headers := make(http.Header)
+ headers.Set("cf-mitigated", "challenge")
+ require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`)))
+
+ require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`Just a moment...`)))
+ require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`Just a moment...`)))
+}
+
+func TestExtractCloudflareRayID(t *testing.T) {
+ headers := make(http.Header)
+ headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
+ require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil))
+
+ body := []byte(``)
+ require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body))
+}
+
+func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) {
+ code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`))
+ require.Equal(t, "cf_shield_429", code)
+ require.Equal(t, "rate limited", msg)
+
+ code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`))
+ require.Equal(t, "unsupported_country_code", code)
+ require.Equal(t, "not available", msg)
+
+ code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`))
+ require.Equal(t, "", code)
+ require.Equal(t, "plain text", msg)
+}
+
+func TestFormatCloudflareChallengeMessage(t *testing.T) {
+ headers := make(http.Header)
+ headers.Set("cf-ray", "9d03b68c086027a1-SEA")
+ msg := FormatCloudflareChallengeMessage("blocked", headers, nil)
+ require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg)
+}
diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go
index 7f37d59c..f7ba5c9e 100644
--- a/backend/internal/web/embed_on.go
+++ b/backend/internal/web/embed_on.go
@@ -86,6 +86,7 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc {
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
+ strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/setup/") ||
path == "/health" ||
@@ -209,6 +210,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
if strings.HasPrefix(path, "/api/") ||
strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
+ strings.HasPrefix(path, "/sora/") ||
strings.HasPrefix(path, "/antigravity/") ||
strings.HasPrefix(path, "/setup/") ||
path == "/health" ||
diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go
index 50f5a323..e2cbcf15 100644
--- a/backend/internal/web/embed_test.go
+++ b/backend/internal/web/embed_test.go
@@ -362,6 +362,7 @@ func TestFrontendServer_Middleware(t *testing.T) {
"/api/v1/users",
"/v1/models",
"/v1beta/chat",
+ "/sora/v1/models",
"/antigravity/test",
"/setup/init",
"/health",
@@ -537,6 +538,7 @@ func TestServeEmbeddedFrontend(t *testing.T) {
"/api/users",
"/v1/models",
"/v1beta/chat",
+ "/sora/v1/models",
"/antigravity/test",
"/setup/init",
"/health",
diff --git a/backend/migrations/054_drop_legacy_cache_columns.sql b/backend/migrations/054_drop_legacy_cache_columns.sql
new file mode 100644
index 00000000..040828c2
--- /dev/null
+++ b/backend/migrations/054_drop_legacy_cache_columns.sql
@@ -0,0 +1,44 @@
+-- Drop legacy cache token columns that lack the underscore separator.
+-- These were created by GORM's automatic snake_case conversion:
+-- CacheCreation5mTokens → cache_creation5m_tokens (incorrect)
+-- CacheCreation1hTokens → cache_creation1h_tokens (incorrect)
+--
+-- The canonical columns are:
+-- cache_creation_5m_tokens (defined in 001_init.sql)
+-- cache_creation_1h_tokens (defined in 001_init.sql)
+--
+-- Migration 009 already copied data from legacy → canonical columns.
+-- But upgraded instances may still have post-009 writes in legacy columns.
+-- Backfill once more before dropping to prevent data loss.
+
+DO $$
+BEGIN
+ IF EXISTS (
+ SELECT 1
+ FROM information_schema.columns
+ WHERE table_schema = 'public'
+ AND table_name = 'usage_logs'
+ AND column_name = 'cache_creation5m_tokens'
+ ) THEN
+ UPDATE usage_logs
+ SET cache_creation_5m_tokens = cache_creation5m_tokens
+ WHERE cache_creation_5m_tokens = 0
+ AND cache_creation5m_tokens <> 0;
+ END IF;
+
+ IF EXISTS (
+ SELECT 1
+ FROM information_schema.columns
+ WHERE table_schema = 'public'
+ AND table_name = 'usage_logs'
+ AND column_name = 'cache_creation1h_tokens'
+ ) THEN
+ UPDATE usage_logs
+ SET cache_creation_1h_tokens = cache_creation1h_tokens
+ WHERE cache_creation_1h_tokens = 0
+ AND cache_creation1h_tokens <> 0;
+ END IF;
+END $$;
+
+ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation5m_tokens;
+ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation1h_tokens;
diff --git a/backend/migrations/055_add_cache_ttl_overridden.sql b/backend/migrations/055_add_cache_ttl_overridden.sql
new file mode 100644
index 00000000..0d42fcf7
--- /dev/null
+++ b/backend/migrations/055_add_cache_ttl_overridden.sql
@@ -0,0 +1,2 @@
+-- Add cache_ttl_overridden flag to usage_logs for tracking cache TTL override per account.
+ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS cache_ttl_overridden BOOLEAN NOT NULL DEFAULT FALSE;
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index 9fd2d391..c77ab70e 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -374,6 +374,9 @@ sora:
# Max retries for upstream requests
# 上游请求最大重试次数
max_retries: 3
+ # Account+proxy cooldown window after Cloudflare challenge (seconds, 0 to disable)
+ # Cloudflare challenge 后按账号+代理冷却窗口(秒,0 表示关闭)
+ cloudflare_challenge_cooldown_seconds: 900
# Poll interval (seconds)
# 轮询间隔(秒)
poll_interval_seconds: 2
@@ -388,7 +391,11 @@ sora:
recent_task_limit_max: 200
# Enable debug logs for Sora upstream requests
# 启用 Sora 直连调试日志
+ # 调试日志会输出上游请求尝试、重试、响应摘要;Authorization/openai-sentinel-token 等敏感头会自动脱敏
debug: false
+ # Allow Sora client to fetch token via OpenAI token provider
+ # 是否允许 Sora 客户端通过 OpenAI token provider 取 token(默认 false,避免误走 OpenAI 刷新链路)
+ use_openai_token_provider: false
# Optional custom headers (key-value)
# 额外请求头(键值对)
headers: {}
@@ -398,6 +405,27 @@ sora:
# Disable TLS fingerprint for Sora upstream
# 关闭 Sora 上游 TLS 指纹伪装
disable_tls_fingerprint: false
+ # curl_cffi sidecar for Sora only (required)
+ # 仅 Sora 链路使用的 curl_cffi sidecar(必需)
+ curl_cffi_sidecar:
+ # Sora 强制通过 sidecar 请求,必须启用
+ # Sora is forced to use sidecar only; keep enabled=true
+ enabled: true
+ # Sidecar base URL (default endpoint: /request)
+ # sidecar 基础地址(默认请求端点:/request)
+ base_url: "http://sora-curl-cffi-sidecar:8080"
+ # curl_cffi impersonate profile, e.g. chrome131/chrome124/safari18_0
+ # curl_cffi 指纹伪装 profile,例如 chrome131/chrome124/safari18_0
+ impersonate: "chrome131"
+ # Sidecar request timeout (seconds)
+ # sidecar 请求超时(秒)
+ timeout_seconds: 60
+ # Reuse session key per account+proxy to let sidecar persist cookies/session
+ # 按账号+代理复用 session key,让 sidecar 持久化 cookies/session
+ session_reuse_enabled: true
+ # Session TTL in sidecar (seconds)
+ # sidecar 会话 TTL(秒)
+ session_ttl_seconds: 3600
storage:
# Storage type (local only for now)
# 存储类型(首发仅支持 local)
@@ -431,6 +459,13 @@ sora:
# Cron 调度表达式
schedule: "0 3 * * *"
+# Token refresh behavior
+# token 刷新行为控制
+token_refresh:
+ # Whether OpenAI refresh flow is allowed to sync linked Sora accounts
+ # 是否允许 OpenAI 刷新流程同步覆盖 linked_openai_account_id 关联的 Sora 账号 token
+ sync_linked_sora_accounts: false
+
# =============================================================================
# API Key Auth Cache Configuration
# API Key 认证缓存配置
diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml
index e5c97bf8..f18a1b64 100644
--- a/deploy/docker-compose.yml
+++ b/deploy/docker-compose.yml
@@ -173,6 +173,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-sub2api}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required}
- POSTGRES_DB=${POSTGRES_DB:-sub2api}
+ - PGDATA=/var/lib/postgresql/data
- TZ=${TZ:-Asia/Shanghai}
networks:
- sub2api-network
diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts
index 0c4856a9..89b11783 100644
--- a/frontend/src/api/admin/accounts.ts
+++ b/frontend/src/api/admin/accounts.ts
@@ -32,6 +32,7 @@ export async function list(
platform?: string
type?: string
status?: string
+ group?: string
search?: string
},
options?: {
@@ -271,7 +272,7 @@ export async function generateAuthUrl(
*/
export async function exchangeCode(
endpoint: string,
- exchangeData: { session_id: string; code: string; proxy_id?: number }
+ exchangeData: { session_id: string; code: string; state?: string; proxy_id?: number }
): Promise> {
const { data } = await apiClient.post>(endpoint, exchangeData)
return data
@@ -493,7 +494,8 @@ export async function getAntigravityDefaultModelMapping(): Promise> {
const payload: { refresh_token: string; proxy_id?: number } = {
refresh_token: refreshToken
@@ -501,7 +503,29 @@ export async function refreshOpenAIToken(
if (proxyId) {
payload.proxy_id = proxyId
}
- const { data } = await apiClient.post>('/admin/openai/refresh-token', payload)
+ const { data } = await apiClient.post>(endpoint, payload)
+ return data
+}
+
+/**
+ * Validate Sora session token and exchange to access token
+ * @param sessionToken - Sora session token
+ * @param proxyId - Optional proxy ID
+ * @param endpoint - API endpoint path
+ * @returns Token information including access_token
+ */
+export async function validateSoraSessionToken(
+ sessionToken: string,
+ proxyId?: number | null,
+ endpoint: string = '/admin/sora/st2at'
+): Promise> {
+ const payload: { session_token: string; proxy_id?: number } = {
+ session_token: sessionToken
+ }
+ if (proxyId) {
+ payload.proxy_id = proxyId
+ }
+ const { data } = await apiClient.post>(endpoint, payload)
return data
}
@@ -527,6 +551,7 @@ export const accountsAPI = {
generateAuthUrl,
exchangeCode,
refreshOpenAIToken,
+ validateSoraSessionToken,
batchCreate,
batchUpdateCredentials,
bulkUpdate,
diff --git a/frontend/src/api/admin/proxies.ts b/frontend/src/api/admin/proxies.ts
index b6aaf595..5e31ae20 100644
--- a/frontend/src/api/admin/proxies.ts
+++ b/frontend/src/api/admin/proxies.ts
@@ -7,6 +7,7 @@ import { apiClient } from '../client'
import type {
Proxy,
ProxyAccountSummary,
+ ProxyQualityCheckResult,
CreateProxyRequest,
UpdateProxyRequest,
PaginatedResponse,
@@ -143,6 +144,16 @@ export async function testProxy(id: number): Promise<{
return data
}
+/**
+ * Check proxy quality across common AI targets
+ * @param id - Proxy ID
+ * @returns Quality check result
+ */
+export async function checkProxyQuality(id: number): Promise {
+ const { data } = await apiClient.post(`/admin/proxies/${id}/quality-check`)
+ return data
+}
+
/**
* Get proxy usage statistics
* @param id - Proxy ID
@@ -248,6 +259,7 @@ export const proxiesAPI = {
delete: deleteProxy,
toggleStatus,
testProxy,
+ checkProxyQuality,
getStats,
getProxyAccounts,
batchCreate,
diff --git a/frontend/src/components/account/AccountGroupsCell.vue b/frontend/src/components/account/AccountGroupsCell.vue
index 512383a5..37771275 100644
--- a/frontend/src/components/account/AccountGroupsCell.vue
+++ b/frontend/src/components/account/AccountGroupsCell.vue
@@ -41,7 +41,7 @@
>
- {{ t('admin.accounts.allGroups', { count: groups.length }) }}
+ {{ t('admin.accounts.groupCountTotal', { count: groups.length }) }}
-
+
@@ -54,6 +54,12 @@
:placeholder="loadingModels ? t('common.loading') + '...' : t('admin.accounts.selectTestModel')"
/>
+
+ {{ t('admin.accounts.soraTestHint') }}
+
@@ -135,12 +141,12 @@
- {{ t('admin.accounts.testModel') }}
+ {{ isSoraAccount ? t('admin.accounts.soraTestTarget') : t('admin.accounts.testModel') }}
- {{ t('admin.accounts.testPrompt') }}
+ {{ isSoraAccount ? t('admin.accounts.soraTestMode') : t('admin.accounts.testPrompt') }}
@@ -156,10 +162,10 @@
+