diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml index 05dd1d1a..fd0c7a41 100644 --- a/.github/workflows/security-scan.yml +++ b/.github/workflows/security-scan.yml @@ -32,7 +32,7 @@ jobs: working-directory: backend run: | go install github.com/securego/gosec/v2/cmd/gosec@latest - gosec -severity high -confidence high ./... + gosec -conf .gosec.json -severity high -confidence high ./... frontend-security: runs-on: ubuntu-latest diff --git a/backend/.gosec.json b/backend/.gosec.json new file mode 100644 index 00000000..b34e140c --- /dev/null +++ b/backend/.gosec.json @@ -0,0 +1,5 @@ +{ + "global": { + "exclude": "G704" + } +} diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 62d7b53f..f788a87d 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.74.9 +0.1.83.2 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index be17fb01..a0f8807a 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -184,7 +184,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig) - soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider) + soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig) soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, configConfig) diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index d7d31f08..3c8d4870 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -669,6 +669,7 @@ var ( {Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, {Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16}, + {Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "api_key_id", Type: field.TypeInt64}, {Name: "account_id", Type: field.TypeInt64}, @@ -684,31 +685,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -717,32 +718,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[27]}, }, { Name: "usagelog_model", @@ -757,12 +758,12 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 2e32d228..678e98c4 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -15980,6 +15980,7 @@ type UsageLogMutation struct { addimage_count *int image_size *string media_type *string + cache_ttl_overridden *bool created_at *time.Time clearedFields map[string]struct{} user *int64 @@ -17655,6 +17656,42 @@ func (m *UsageLogMutation) ResetMediaType() { delete(m.clearedFields, usagelog.FieldMediaType) } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) { + m.cache_ttl_overridden = &b +} + +// CacheTTLOverridden returns the value of the "cache_ttl_overridden" field in the mutation. +func (m *UsageLogMutation) CacheTTLOverridden() (r bool, exists bool) { + v := m.cache_ttl_overridden + if v == nil { + return + } + return *v, true +} + +// OldCacheTTLOverridden returns the old "cache_ttl_overridden" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldCacheTTLOverridden(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheTTLOverridden is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheTTLOverridden requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheTTLOverridden: %w", err) + } + return oldValue.CacheTTLOverridden, nil +} + +// ResetCacheTTLOverridden resets all changes to the "cache_ttl_overridden" field. +func (m *UsageLogMutation) ResetCacheTTLOverridden() { + m.cache_ttl_overridden = nil +} + // SetCreatedAt sets the "created_at" field. func (m *UsageLogMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -17860,7 +17897,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 31) + fields := make([]string, 0, 32) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -17951,6 +17988,9 @@ func (m *UsageLogMutation) Fields() []string { if m.media_type != nil { fields = append(fields, usagelog.FieldMediaType) } + if m.cache_ttl_overridden != nil { + fields = append(fields, usagelog.FieldCacheTTLOverridden) + } if m.created_at != nil { fields = append(fields, usagelog.FieldCreatedAt) } @@ -18022,6 +18062,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.ImageSize() case usagelog.FieldMediaType: return m.MediaType() + case usagelog.FieldCacheTTLOverridden: + return m.CacheTTLOverridden() case usagelog.FieldCreatedAt: return m.CreatedAt() } @@ -18093,6 +18135,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldImageSize(ctx) case usagelog.FieldMediaType: return m.OldMediaType(ctx) + case usagelog.FieldCacheTTLOverridden: + return m.OldCacheTTLOverridden(ctx) case usagelog.FieldCreatedAt: return m.OldCreatedAt(ctx) } @@ -18314,6 +18358,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetMediaType(v) return nil + case usagelog.FieldCacheTTLOverridden: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheTTLOverridden(v) + return nil case usagelog.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -18736,6 +18787,9 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldMediaType: m.ResetMediaType() return nil + case usagelog.FieldCacheTTLOverridden: + m.ResetCacheTTLOverridden() + return nil case usagelog.FieldCreatedAt: m.ResetCreatedAt() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 8da5f84c..5e980be0 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -821,8 +821,12 @@ func init() { usagelogDescMediaType := usagelogFields[29].Descriptor() // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) + // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. + usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor() + // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. + usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[30].Descriptor() + usagelogDescCreatedAt := usagelogFields[31].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index 602f23f6..ffcae840 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -124,6 +124,10 @@ func (UsageLog) Fields() []ent.Field { Optional(). Nillable(), + // Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费) + field.Bool("cache_ttl_overridden"). + Default(false), + // 时间戳(只有 created_at,日志不可修改) field.Time("created_at"). Default(time.Now). diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index 63a14197..f6968d0d 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -82,6 +82,8 @@ type UsageLog struct { ImageSize *string `json:"image_size,omitempty"` // MediaType holds the value of the "media_type" field. MediaType *string `json:"media_type,omitempty"` + // CacheTTLOverridden holds the value of the "cache_ttl_overridden" field. + CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. @@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case usagelog.FieldStream: + case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden: values[i] = new(sql.NullBool) case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier: values[i] = new(sql.NullFloat64) @@ -387,6 +389,12 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { _m.MediaType = new(string) *_m.MediaType = value.String } + case usagelog.FieldCacheTTLOverridden: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i]) + } else if value.Valid { + _m.CacheTTLOverridden = value.Bool + } case usagelog.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -562,6 +570,9 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") + builder.WriteString("cache_ttl_overridden=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden)) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteByte(')') diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index 3ea5d054..ba97b843 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -74,6 +74,8 @@ const ( FieldImageSize = "image_size" // FieldMediaType holds the string denoting the media_type field in the database. FieldMediaType = "media_type" + // FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database. + FieldCacheTTLOverridden = "cache_ttl_overridden" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // EdgeUser holds the string denoting the user edge name in mutations. @@ -158,6 +160,7 @@ var Columns = []string{ FieldImageCount, FieldImageSize, FieldMediaType, + FieldCacheTTLOverridden, FieldCreatedAt, } @@ -216,6 +219,8 @@ var ( ImageSizeValidator func(string) error // MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. MediaTypeValidator func(string) error + // DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field. + DefaultCacheTTLOverridden bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time ) @@ -378,6 +383,11 @@ func ByMediaType(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldMediaType, opts...).ToFunc() } +// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field. +func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index 0a33dba2..af960335 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -205,6 +205,11 @@ func MediaType(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) } +// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ. +func CacheTTLOverridden(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) @@ -1520,6 +1525,16 @@ func MediaTypeContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v)) } +// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field. +func CacheTTLOverriddenEQ(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) +} + +// CacheTTLOverriddenNEQ applies the NEQ predicate on the "cache_ttl_overridden" field. +func CacheTTLOverriddenNEQ(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheTTLOverridden, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index 668a0ede..e0285a5e 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -407,6 +407,20 @@ func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate { return _c } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate { + _c.mutation.SetCacheTTLOverridden(v) + return _c +} + +// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheTTLOverridden(v *bool) *UsageLogCreate { + if v != nil { + _c.SetCacheTTLOverridden(*v) + } + return _c +} + // SetCreatedAt sets the "created_at" field. func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate { _c.mutation.SetCreatedAt(v) @@ -545,6 +559,10 @@ func (_c *UsageLogCreate) defaults() { v := usagelog.DefaultImageCount _c.mutation.SetImageCount(v) } + if _, ok := _c.mutation.CacheTTLOverridden(); !ok { + v := usagelog.DefaultCacheTTLOverridden + _c.mutation.SetCacheTTLOverridden(v) + } if _, ok := _c.mutation.CreatedAt(); !ok { v := usagelog.DefaultCreatedAt() _c.mutation.SetCreatedAt(v) @@ -646,6 +664,9 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} } } + if _, ok := _c.mutation.CacheTTLOverridden(); !ok { + return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)} + } if _, ok := _c.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)} } @@ -785,6 +806,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) _node.MediaType = &value } + if value, ok := _c.mutation.CacheTTLOverridden(); ok { + _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) + _node.CacheTTLOverridden = value + } if value, ok := _c.mutation.CreatedAt(); ok { _spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -1448,6 +1473,18 @@ func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert { return u } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert { + u.Set(usagelog.FieldCacheTTLOverridden, v) + return u +} + +// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheTTLOverridden() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheTTLOverridden) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -2102,6 +2139,20 @@ func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne { }) } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheTTLOverridden(v) + }) +} + +// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheTTLOverridden() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheTTLOverridden() + }) +} + // Exec executes the query. func (u *UsageLogUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2922,6 +2973,20 @@ func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk { }) } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheTTLOverridden(v) + }) +} + +// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheTTLOverridden() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheTTLOverridden() + }) +} + // Exec executes the query. func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index 22f2613f..b46e5b56 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -632,6 +632,20 @@ func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate { return _u } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate { + _u.mutation.SetCacheTTLOverridden(v) + return _u +} + +// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdate { + if v != nil { + _u.SetCacheTTLOverridden(*v) + } + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate { return _u.SetUserID(v.ID) @@ -925,6 +939,9 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.MediaTypeCleared() { _spec.ClearField(usagelog.FieldMediaType, field.TypeString) } + if value, ok := _u.mutation.CacheTTLOverridden(); ok { + _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -1690,6 +1707,20 @@ func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne { return _u } +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne { + _u.mutation.SetCacheTTLOverridden(v) + return _u +} + +// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheTTLOverridden(*v) + } + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne { return _u.SetUserID(v.ID) @@ -2013,6 +2044,9 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.MediaTypeCleared() { _spec.ClearField(usagelog.FieldMediaType, field.TypeString) } + if value, ok := _u.mutation.CacheTTLOverridden(); ok { + _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index b9f31ba9..330ae0c1 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -162,6 +162,8 @@ type TokenRefreshConfig struct { MaxRetries int `mapstructure:"max_retries"` // 重试退避基础时间(秒) RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` + // 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭) + SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"` } type PricingConfig struct { @@ -269,17 +271,30 @@ type SoraConfig struct { // SoraClientConfig 直连 Sora 客户端配置 type SoraClientConfig struct { - BaseURL string `mapstructure:"base_url"` - TimeoutSeconds int `mapstructure:"timeout_seconds"` - MaxRetries int `mapstructure:"max_retries"` - PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` - MaxPollAttempts int `mapstructure:"max_poll_attempts"` - RecentTaskLimit int `mapstructure:"recent_task_limit"` - RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` - Debug bool `mapstructure:"debug"` - Headers map[string]string `mapstructure:"headers"` - UserAgent string `mapstructure:"user_agent"` - DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` + BaseURL string `mapstructure:"base_url"` + TimeoutSeconds int `mapstructure:"timeout_seconds"` + MaxRetries int `mapstructure:"max_retries"` + CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"` + PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` + MaxPollAttempts int `mapstructure:"max_poll_attempts"` + RecentTaskLimit int `mapstructure:"recent_task_limit"` + RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` + Debug bool `mapstructure:"debug"` + UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"` + Headers map[string]string `mapstructure:"headers"` + UserAgent string `mapstructure:"user_agent"` + DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` + CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"` +} + +// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置 +type SoraCurlCFFISidecarConfig struct { + Enabled bool `mapstructure:"enabled"` + BaseURL string `mapstructure:"base_url"` + Impersonate string `mapstructure:"impersonate"` + TimeoutSeconds int `mapstructure:"timeout_seconds"` + SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"` + SessionTTLSeconds int `mapstructure:"session_ttl_seconds"` } // SoraStorageConfig 媒体存储配置 @@ -1111,14 +1126,22 @@ func setDefaults() { viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend") viper.SetDefault("sora.client.timeout_seconds", 120) viper.SetDefault("sora.client.max_retries", 3) + viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900) viper.SetDefault("sora.client.poll_interval_seconds", 2) viper.SetDefault("sora.client.max_poll_attempts", 600) viper.SetDefault("sora.client.recent_task_limit", 50) viper.SetDefault("sora.client.recent_task_limit_max", 200) viper.SetDefault("sora.client.debug", false) + viper.SetDefault("sora.client.use_openai_token_provider", false) viper.SetDefault("sora.client.headers", map[string]string{}) viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") viper.SetDefault("sora.client.disable_tls_fingerprint", false) + viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true) + viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080") + viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131") + viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60) + viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true) + viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600) viper.SetDefault("sora.storage.type", "local") viper.SetDefault("sora.storage.local_path", "") @@ -1137,6 +1160,7 @@ func setDefaults() { viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token) viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 + viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token // Gemini OAuth - configure via environment variables or config file // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET @@ -1505,6 +1529,9 @@ func (c *Config) Validate() error { if c.Sora.Client.MaxRetries < 0 { return fmt.Errorf("sora.client.max_retries must be non-negative") } + if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 { + return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") + } if c.Sora.Client.PollIntervalSeconds < 0 { return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative") } @@ -1521,6 +1548,18 @@ func (c *Config) Validate() error { c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit { c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit } + if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 { + return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative") + } + if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 { + return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") + } + if !c.Sora.Client.CurlCFFISidecar.Enabled { + return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true") + } + if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" { + return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required") + } if c.Sora.Storage.MaxConcurrentDownloads < 0 { return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index a3c65c41..dcc60879 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1024,3 +1024,91 @@ func TestValidateConfigErrors(t *testing.T) { }) } } + +func TestSoraCurlCFFISidecarDefaults(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Sora.Client.CurlCFFISidecar.Enabled { + t.Fatalf("Sora curl_cffi sidecar should be enabled by default") + } + if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 { + t.Fatalf("Sora cloudflare challenge cooldown should be positive by default") + } + if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" { + t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default") + } + if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" { + t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default") + } + if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled { + t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default") + } + if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 { + t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default") + } +} + +func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CurlCFFISidecar.Enabled = false + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") { + t.Fatalf("Validate() error = %v, want sidecar enabled error", err) + } +} + +func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CurlCFFISidecar.BaseURL = " " + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") { + t.Fatalf("Validate() error = %v, want sidecar base_url required error", err) + } +} + +func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1 + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") { + t.Fatalf("Validate() error = %v, want sidecar session ttl error", err) + } +} + +func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1 + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") { + t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err) + } +} diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index b5d1dd0a..34397696 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc pageSize := dataPageCap var out []service.Account for { - items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search) + items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0) if err != nil { return nil, err } diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 4a9185eb..1aa0cf2b 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -200,7 +200,12 @@ func (h *AccountHandler) List(c *gin.Context) { search = search[:100] } - accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search) + var groupID int64 + if groupIDStr := c.Query("group"); groupIDStr != "" { + groupID, _ = strconv.ParseInt(groupIDStr, 10, 64) + } + + accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID) if err != nil { response.ErrorFrom(c, err) return @@ -1433,6 +1438,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { return } + // Handle Sora accounts + if account.Platform == service.PlatformSora { + response.Success(c, service.DefaultSoraModels(nil)) + return + } + // Handle Claude/Anthropic accounts // For OAuth and Setup-Token accounts: return default models if account.IsOAuth() { @@ -1542,7 +1553,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { accounts := make([]*service.Account, 0) if len(req.AccountIDs) == 0 { - allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "") + allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index 20a25222..aeb4097f 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete) router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete) router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test) + router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality) router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats) router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts) @@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil) router.ServeHTTP(rec, req) diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index cbbfe942..9f3dcf80 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p return s.apiKeys, int64(len(s.apiKeys)), nil } -func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) { +func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) { return s.accounts, int64(len(s.accounts)), nil } @@ -327,6 +327,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr return &service.ProxyTestResult{Success: true, Message: "ok"}, nil } +func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) { + return &service.ProxyQualityCheckResult{ + ProxyID: id, + Score: 95, + Grade: "A", + Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项", + PassedCount: 5, + WarnCount: 0, + FailedCount: 0, + ChallengeCount: 0, + CheckedAt: time.Now().Unix(), + Items: []service.ProxyQualityCheckItem{ + {Target: "base_connectivity", Status: "pass", Message: "ok"}, + {Target: "openai", Status: "pass", HTTPStatus: 401}, + {Target: "anthropic", Status: "pass", HTTPStatus: 401}, + {Target: "gemini", Status: "pass", HTTPStatus: 200}, + {Target: "sora", Status: "pass", HTTPStatus: 401}, + }, + }, nil +} + func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) { return s.redeems, int64(len(s.redeems)), nil } diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go index ed86fea9..cf43f89e 100644 --- a/backend/internal/handler/admin/openai_oauth_handler.go +++ b/backend/internal/handler/admin/openai_oauth_handler.go @@ -2,6 +2,7 @@ package admin import ( "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct { adminService service.AdminService } +func oauthPlatformFromPath(c *gin.Context) string { + if strings.Contains(c.FullPath(), "/admin/sora/") { + return service.PlatformSora + } + return service.PlatformOpenAI +} + // NewOpenAIOAuthHandler creates a new OpenAI OAuth handler func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler { return &OpenAIOAuthHandler{ @@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { type OpenAIExchangeCodeRequest struct { SessionID string `json:"session_id" binding:"required"` Code string `json:"code" binding:"required"` + State string `json:"state" binding:"required"` RedirectURI string `json:"redirect_uri"` ProxyID *int64 `json:"proxy_id"` } @@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ SessionID: req.SessionID, Code: req.Code, + State: req.State, RedirectURI: req.RedirectURI, ProxyID: req.ProxyID, }) @@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { // OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token type OpenAIRefreshTokenRequest struct { - RefreshToken string `json:"refresh_token" binding:"required"` + RefreshToken string `json:"refresh_token"` + RT string `json:"rt"` + ClientID string `json:"client_id"` ProxyID *int64 `json:"proxy_id"` } // RefreshToken refreshes an OpenAI OAuth token // POST /api/v1/admin/openai/refresh-token +// POST /api/v1/admin/sora/rt2at func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { var req OpenAIRefreshTokenRequest if err := c.ShouldBindJSON(&req); err != nil { response.BadRequest(c, "Invalid request: "+err.Error()) return } + refreshToken := strings.TrimSpace(req.RefreshToken) + if refreshToken == "" { + refreshToken = strings.TrimSpace(req.RT) + } + if refreshToken == "" { + response.BadRequest(c, "refresh_token is required") + return + } var proxyURL string if req.ProxyID != nil { @@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { } } - tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL) + tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID)) if err != nil { response.ErrorFrom(c, err) return @@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { response.Success(c, tokenInfo) } -// RefreshAccountToken refreshes token for a specific OpenAI account +// ExchangeSoraSessionToken exchanges Sora session token to access token +// POST /api/v1/admin/sora/st2at +func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) { + var req struct { + SessionToken string `json:"session_token"` + ST string `json:"st"` + ProxyID *int64 `json:"proxy_id"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + sessionToken := strings.TrimSpace(req.SessionToken) + if sessionToken == "" { + sessionToken = strings.TrimSpace(req.ST) + } + if sessionToken == "" { + response.BadRequest(c, "session_token is required") + return + } + + tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, tokenInfo) +} + +// RefreshAccountToken refreshes token for a specific OpenAI/Sora account // POST /api/v1/admin/openai/accounts/:id/refresh +// POST /api/v1/admin/sora/accounts/:id/refresh func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { @@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { return } - // Ensure account is OpenAI platform - if !account.IsOpenAI() { - response.BadRequest(c, "Account is not an OpenAI account") + platform := oauthPlatformFromPath(c) + if account.Platform != platform { + response.BadRequest(c, "Account platform does not match OAuth endpoint") return } @@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { response.Success(c, dto.AccountFromService(updatedAccount)) } -// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info +// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info // POST /api/v1/admin/openai/create-from-oauth +// POST /api/v1/admin/sora/create-from-oauth func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { var req struct { SessionID string `json:"session_id" binding:"required"` Code string `json:"code" binding:"required"` + State string `json:"state" binding:"required"` RedirectURI string `json:"redirect_uri"` ProxyID *int64 `json:"proxy_id"` Name string `json:"name"` @@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ SessionID: req.SessionID, Code: req.Code, + State: req.State, RedirectURI: req.RedirectURI, ProxyID: req.ProxyID, }) @@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { // Build credentials from token info credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo) + platform := oauthPlatformFromPath(c) + // Use email as default name if not provided name := req.Name if name == "" && tokenInfo.Email != "" { name = tokenInfo.Email } if name == "" { - name = "OpenAI OAuth Account" + if platform == service.PlatformSora { + name = "Sora OAuth Account" + } else { + name = "OpenAI OAuth Account" + } } // Create account account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ Name: name, - Platform: "openai", + Platform: platform, Type: "oauth", Credentials: credentials, ProxyID: req.ProxyID, diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index a6758f69..5a9cd7a0 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -236,6 +236,24 @@ func (h *ProxyHandler) Test(c *gin.Context) { response.Success(c, result) } +// CheckQuality handles checking proxy quality across common AI targets. +// POST /api/v1/admin/proxies/:id/quality-check +func (h *ProxyHandler) CheckQuality(c *gin.Context) { + proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid proxy ID") + return + } + + result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, result) +} + // GetStats handles getting proxy statistics // GET /api/v1/admin/proxies/:id/stats func (h *ProxyHandler) GetStats(c *gin.Context) { diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 3c216d65..dbc7a8bc 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -214,6 +214,13 @@ func AccountFromServiceShallow(a *service.Account) *Account { enabled := true out.EnableSessionIDMasking = &enabled } + // 缓存 TTL 强制替换 + if a.IsCacheTTLOverrideEnabled() { + enabled := true + out.CacheTTLOverrideEnabled = &enabled + target := a.GetCacheTTLOverrideTarget() + out.CacheTTLOverrideTarget = &target + } } return out @@ -296,6 +303,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi CountryCode: p.CountryCode, Region: p.Region, City: p.City, + QualityStatus: p.QualityStatus, + QualityScore: p.QualityScore, + QualityGrade: p.QualityGrade, + QualitySummary: p.QualitySummary, + QualityChecked: p.QualityChecked, } } @@ -402,6 +414,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { ImageSize: l.ImageSize, MediaType: l.MediaType, UserAgent: l.UserAgent, + CacheTTLOverridden: l.CacheTTLOverridden, CreatedAt: l.CreatedAt, User: UserFromServiceShallow(l.User), APIKey: APIKeyFromService(l.APIKey), diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index daac42bd..f2605ffc 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -156,6 +156,11 @@ type Account struct { // 从 extra 字段提取,方便前端显示和编辑 EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"` + // 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效) + // 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费 + CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"` + CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"` + Proxy *Proxy `json:"proxy,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` @@ -197,6 +202,11 @@ type ProxyWithAccountCount struct { CountryCode string `json:"country_code,omitempty"` Region string `json:"region,omitempty"` City string `json:"city,omitempty"` + QualityStatus string `json:"quality_status,omitempty"` + QualityScore *int `json:"quality_score,omitempty"` + QualityGrade string `json:"quality_grade,omitempty"` + QualitySummary string `json:"quality_summary,omitempty"` + QualityChecked *int64 `json:"quality_checked,omitempty"` } type ProxyAccountSummary struct { @@ -280,6 +290,9 @@ type UsageLog struct { // User-Agent UserAgent *string `json:"user_agent"` + // Cache TTL Override 标记 + CacheTTLOverridden bool `json:"cache_ttl_overridden"` + CreatedAt time.Time `json:"created_at"` User *User `json:"user,omitempty"` diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 80932899..b958a133 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/hex" + "encoding/json" "errors" "fmt" "io" @@ -20,6 +21,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" "github.com/gin-gonic/gin" "github.com/tidwall/gjson" @@ -35,6 +37,7 @@ type SoraGatewayHandler struct { concurrencyHelper *ConcurrencyHelper maxAccountSwitches int streamMode string + soraTLSEnabled bool soraMediaSigningKey string soraMediaRoot string } @@ -50,6 +53,7 @@ func NewSoraGatewayHandler( pingInterval := time.Duration(0) maxAccountSwitches := 3 streamMode := "force" + soraTLSEnabled := true signKey := "" mediaRoot := "/app/data/sora" if cfg != nil { @@ -60,6 +64,7 @@ func NewSoraGatewayHandler( if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" { streamMode = mode } + soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey) if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" { mediaRoot = root @@ -72,6 +77,7 @@ func NewSoraGatewayHandler( concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), maxAccountSwitches: maxAccountSwitches, streamMode: strings.ToLower(streamMode), + soraTLSEnabled: soraTLSEnabled, soraMediaSigningKey: signKey, soraMediaRoot: mediaRoot, } @@ -212,6 +218,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { switchCount := 0 failedAccountIDs := make(map[int64]struct{}) lastFailoverStatus := 0 + var lastFailoverBody []byte + var lastFailoverHeaders http.Header for { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") @@ -224,11 +232,31 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) + fields := []zap.Field{ + zap.Int("last_upstream_status", lastFailoverStatus), + } + if rayID != "" { + fields = append(fields, zap.String("last_upstream_cf_ray", rayID)) + } + if mitigated != "" { + fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated)) + } + if contentType != "" { + fields = append(fields, zap.String("last_upstream_content_type", contentType)) + } + reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...) + h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) return } account := selection.Account setOpsSelectedAccount(c, account.ID, account.Platform) + proxyBound := account.ProxyID != nil + proxyID := int64(0) + if account.ProxyID != nil { + proxyID = *account.ProxyID + } + tlsFingerprintEnabled := h.soraTLSEnabled accountReleaseFunc := selection.ReleaseFunc if !selection.Acquired { @@ -239,10 +267,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { accountWaitCounted := false canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { - reqLog.Warn("sora.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + reqLog.Warn("sora.account_wait_counter_increment_failed", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Error(err), + ) } else if !canWait { reqLog.Info("sora.account_wait_queue_full", zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), ) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) @@ -266,7 +303,13 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { &streamStarted, ) if err != nil { - reqLog.Warn("sora.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + reqLog.Warn("sora.account_slot_acquire_failed", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Error(err), + ) h.handleConcurrencyError(c, err, "account", streamStarted) return } @@ -287,20 +330,67 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { failedAccountIDs[account.ID] = struct{}{} if switchCount >= maxAccountSwitches { lastFailoverStatus = failoverErr.StatusCode - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) + lastFailoverBody = failoverErr.ResponseBody + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + } + if rayID != "" { + fields = append(fields, zap.String("upstream_cf_ray", rayID)) + } + if mitigated != "" { + fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) + } + if contentType != "" { + fields = append(fields, zap.String("upstream_content_type", contentType)) + } + reqLog.Warn("sora.upstream_failover_exhausted", fields...) + h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) return } lastFailoverStatus = failoverErr.StatusCode + lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) + lastFailoverBody = failoverErr.ResponseBody switchCount++ - reqLog.Warn("sora.upstream_failover_switching", + upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody) + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) + fields := []zap.Field{ zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), zap.Int("upstream_status", failoverErr.StatusCode), + zap.String("upstream_error_code", upstreamErrCode), + zap.String("upstream_error_message", upstreamErrMsg), zap.Int("switch_count", switchCount), zap.Int("max_switches", maxAccountSwitches), - ) + } + if rayID != "" { + fields = append(fields, zap.String("upstream_cf_ray", rayID)) + } + if mitigated != "" { + fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) + } + if contentType != "" { + fields = append(fields, zap.String("upstream_content_type", contentType)) + } + reqLog.Warn("sora.upstream_failover_switching", fields...) continue } - reqLog.Error("sora.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + reqLog.Error("sora.forward_failed", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Error(err), + ) return } @@ -331,6 +421,9 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { }(result, account, userAgent, clientIP) reqLog.Debug("sora.request_completed", zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), zap.Int("switch_count", switchCount), ) return @@ -360,17 +453,41 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } -func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { - status, errType, errMsg := h.mapUpstreamError(statusCode) +func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) { + status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) } -func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) { +func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) { + if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) { + baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode) + return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) + } + + upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody) + if strings.EqualFold(upstreamCode, "cf_shield_429") { + baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry." + return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) + } + if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) { + switch statusCode { + case 401, 403, 404, 500, 502, 503, 504: + return http.StatusBadGateway, "upstream_error", upstreamMessage + case 429: + return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage + } + } + switch statusCode { case 401: return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" case 403: return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" + case 404: + if strings.EqualFold(upstreamCode, "unsupported_country_code") { + return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator" + } + return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator" case 429: return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" case 529: @@ -382,11 +499,67 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri } } +func cloneHTTPHeaders(headers http.Header) http.Header { + if headers == nil { + return nil + } + return headers.Clone() +} + +func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) { + if headers != nil { + mitigated = strings.TrimSpace(headers.Get("cf-mitigated")) + contentType = strings.TrimSpace(headers.Get("content-type")) + if contentType == "" { + contentType = strings.TrimSpace(headers.Get("Content-Type")) + } + } + rayID = soraerror.ExtractCloudflareRayID(headers, body) + return rayID, mitigated, contentType +} + +func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { + return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) +} + +func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool { + message = strings.TrimSpace(message) + if message == "" { + return false + } + if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests { + lower := strings.ToLower(message) + if strings.Contains(lower, "Just a moment...`) + + h := &SoraGatewayHandler{} + h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true) + + lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") + require.Len(t, lines, 2) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "upstream_error", errorObj["type"]) + msg, _ := errorObj["message"].(string) + require.Contains(t, msg, "Cloudflare challenge") + require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA") +} + +func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + headers := http.Header{} + headers.Set("cf-ray", "9d03b68c086027a1-SEA") + body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`) + + h := &SoraGatewayHandler{} + h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true) + + lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") + require.Len(t, lines, 2) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "rate_limit_error", errorObj["type"]) + msg, _ := errorObj["message"].(string) + require.Contains(t, msg, "Cloudflare shield") + require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA") +} + +func TestExtractSoraFailoverHeaderInsights(t *testing.T) { + headers := http.Header{} + headers.Set("cf-mitigated", "challenge") + headers.Set("content-type", "text/html") + body := []byte(``) + + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body) + require.Equal(t, "9cff2d62d83bb98d", rayID) + require.Equal(t, "challenge", mitigated) + require.Equal(t, "text/html", contentType) +} diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index eecee11e..423ad925 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -10,6 +10,7 @@ const ( BetaInterleavedThinking = "interleaved-thinking-2025-05-14" BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14" BetaTokenCounting = "token-counting-2024-11-01" + BetaContext1M = "context-1m-2025-08-07" ) // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header @@ -77,6 +78,12 @@ var DefaultModels = []Model{ DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-06T00:00:00Z", }, + { + ID: "claude-sonnet-4-6", + Type: "model", + DisplayName: "Claude Sonnet 4.6", + CreatedAt: "2026-02-18T00:00:00Z", + }, { ID: "claude-sonnet-4-5-20250929", Type: "model", diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index bb120b57..e3b931be 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -17,6 +17,8 @@ import ( const ( // OAuth Client ID for OpenAI (Codex CLI official) ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + // OAuth Client ID for Sora mobile flow (aligned with sora2api) + SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK" // OAuth endpoints AuthorizeURL = "https://auth.openai.com/oauth/authorize" diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 58b824c9..3f77a57e 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -435,10 +435,10 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { } func (r *accountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { - return r.ListWithFilters(ctx, params, "", "", "", "") + return r.ListWithFilters(ctx, params, "", "", "", "", 0) } -func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { +func (r *accountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { q := r.client.Account.Query() if platform != "" { @@ -458,6 +458,9 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati if search != "" { q = q.Where(dbaccount.NameContainsFold(search)) } + if groupID > 0 { + q = q.Where(dbaccount.HasAccountGroupsWith(dbaccountgroup.GroupIDEQ(groupID))) + } total, err := q.Count(ctx) if err != nil { diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index a054b6d6..4f9d0152 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -238,7 +238,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { tt.setup(client) - accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search) + accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0) s.Require().NoError(err) s.Require().Len(accounts, tt.wantCount) if tt.validate != nil { @@ -305,7 +305,7 @@ func (s *AccountRepoSuite) TestPreload_And_VirtualFields() { s.Require().Len(got.Groups, 1, "expected Groups to be populated") s.Require().Equal(group.ID, got.Groups[0].ID) - accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc") + accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0) s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total) s.Require().Len(accounts, 1) diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index 394d3a1a..088e7d7f 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/url" + "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie } func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") +} + +func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + if strings.TrimSpace(clientID) != "" { + return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID)) + } + + clientIDs := []string{ + openai.ClientID, + openai.SoraClientID, + } + seen := make(map[string]struct{}, len(clientIDs)) + var lastErr error + for _, clientID := range clientIDs { + clientID = strings.TrimSpace(clientID) + if clientID == "" { + continue + } + if _, ok := seen[clientID]; ok { + continue + } + seen[clientID] = struct{}{} + + tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) + if err == nil { + return tokenResp, nil + } + lastErr = err + } + if lastErr != nil { + return nil, lastErr + } + return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed") +} + +func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) { client := createOpenAIReqClient(proxyURL) formData := url.Values{} formData.Set("grant_type", "refresh_token") formData.Set("refresh_token", refreshToken) - formData.Set("client_id", openai.ClientID) + formData.Set("client_id", clientID) formData.Set("scope", openai.RefreshScopes) var tokenResp openai.TokenResponse diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go index f9df08c8..5938272a 100644 --- a/backend/internal/repository/openai_oauth_service_test.go +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() { require.Equal(s.T(), "rt2", resp.RefreshToken) } +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() { + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + clientID := r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) + if clientID == openai.ClientID { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, "invalid_grant") + return + } + if clientID == openai.SoraClientID { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + + resp, err := s.svc.RefreshToken(s.ctx, "rt", "") + require.NoError(s.T(), err, "RefreshToken") + require.Equal(s.T(), "at-sora", resp.AccessToken) + require.Equal(s.T(), "rt-sora", resp.RefreshToken) + require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs) +} + +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() { + const customClientID = "custom-client-id" + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + clientID := r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) + if clientID != customClientID { + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`) + })) + + resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID) + require.NoError(s.T(), err, "RefreshTokenWithClientID") + require.Equal(s.T(), "at-custom", resp.AccessToken) + require.Equal(s.T(), "rt-custom", resp.RefreshToken) + require.Equal(s.T(), []string{customClientID}, seenClientIDs) +} + func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 681b1664..0389a008 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at" // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL var dateFormatWhitelist = map[string]string{ @@ -132,6 +132,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) image_size, media_type, reasoning_effort, + cache_ttl_overridden, created_at ) VALUES ( $1, $2, $3, $4, $5, @@ -139,7 +140,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32 + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -192,6 +193,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) imageSize, mediaType, reasoningEffort, + log.CacheTTLOverridden, createdAt, } if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { @@ -2221,6 +2223,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e imageSize sql.NullString mediaType sql.NullString reasoningEffort sql.NullString + cacheTTLOverridden bool createdAt time.Time ) @@ -2257,6 +2260,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &imageSize, &mediaType, &reasoningEffort, + &cacheTTLOverridden, &createdAt, ); err != nil { return nil, err @@ -2285,6 +2289,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e BillingType: int8(billingType), Stream: stream, ImageCount: imageCount, + CacheTTLOverridden: cacheTTLOverridden, CreatedAt: createdAt, } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index c574219b..d87d97b5 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -406,6 +406,7 @@ func TestAPIContracts(t *testing.T) { "image_count": 0, "image_size": null, "media_type": null, + "cache_ttl_overridden": false, "created_at": "2025-01-02T03:04:05Z", "user_agent": null } @@ -945,7 +946,7 @@ func (s *stubAccountRepo) List(ctx context.Context, params pagination.Pagination return nil, nil, errors.New("not implemented") } -func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { +func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } diff --git a/backend/internal/server/middleware/cors.go b/backend/internal/server/middleware/cors.go index 704b0907..03d5d025 100644 --- a/backend/internal/server/middleware/cors.go +++ b/backend/internal/server/middleware/cors.go @@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { } allowedSet[origin] = struct{}{} } + allowHeaders := []string{ + "Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", + "accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key", + } + // OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。 + openAIProperties := []string{ + "lang", "package-version", "os", "arch", "retry-count", "runtime", + "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout", + } + for _, prop := range openAIProperties { + allowHeaders = append(allowHeaders, "x-stainless-"+prop) + } + allowHeadersValue := strings.Join(allowHeaders, ", ") return func(c *gin.Context) { origin := strings.TrimSpace(c.GetHeader("Origin")) @@ -68,12 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { if allowCredentials { c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") } - c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") + c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeadersValue) c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") c.Writer.Header().Set("Access-Control-Expose-Headers", "ETag") c.Writer.Header().Set("Access-Control-Max-Age", "86400") } - // 处理预检请求 if c.Request.Method == http.MethodOptions { if originAllowed { diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 57d54a54..4b4d97c3 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -34,6 +34,8 @@ func RegisterAdminRoutes( // OpenAI OAuth registerOpenAIOAuthRoutes(admin, h) + // Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立) + registerSoraOAuthRoutes(admin, h) // Gemini OAuth registerGeminiOAuthRoutes(admin, h) @@ -276,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { } } +func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + sora := admin.Group("/sora") + { + sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL) + sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode) + sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken) + sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken) + sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken) + sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken) + sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth) + } +} + func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { gemini := admin.Group("/gemini") { @@ -306,6 +321,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { proxies.PUT("/:id", h.Admin.Proxy.Update) proxies.DELETE("/:id", h.Admin.Proxy.Delete) proxies.POST("/:id/test", h.Admin.Proxy.Test) + proxies.POST("/:id/quality-check", h.Admin.Proxy.CheckQuality) proxies.GET("/:id/stats", h.Admin.Proxy.GetStats) proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts) proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete) diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 32f34e0c..69881e70 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -1,6 +1,8 @@ package routes import ( + "net/http" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -41,16 +43,15 @@ func RegisterGatewayRoutes( gateway.GET("/usage", h.Gateway.Usage) // OpenAI Responses API gateway.POST("/responses", h.OpenAIGateway.Responses) - } - - // Sora Chat Completions - soraGateway := r.Group("/v1") - soraGateway.Use(soraBodyLimit) - soraGateway.Use(clientRequestID) - soraGateway.Use(opsErrorLogger) - soraGateway.Use(gin.HandlerFunc(apiKeyAuth)) - { - soraGateway.POST("/chat/completions", h.SoraGateway.ChatCompletions) + // 明确阻止旧入口误用到 Sora,避免客户端把 OpenAI Chat Completions 当作 Sora 入口 + gateway.POST("/chat/completions", func(c *gin.Context) { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "For Sora, use /sora/v1/chat/completions. OpenAI should use /v1/responses.", + }, + }) + }) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 592c5139..bce3f98f 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -786,6 +786,38 @@ func (a *Account) IsSessionIDMaskingEnabled() bool { return false } +// IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换 +// 仅适用于 Anthropic OAuth/SetupToken 类型账号 +// 启用后将所有 cache creation tokens 归入指定的 TTL 类型(5m 或 1h) +func (a *Account) IsCacheTTLOverrideEnabled() bool { + if !a.IsAnthropicOAuthOrSetupToken() { + return false + } + if a.Extra == nil { + return false + } + if v, ok := a.Extra["cache_ttl_override_enabled"]; ok { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + +// GetCacheTTLOverrideTarget 获取缓存 TTL 强制替换的目标类型 +// 返回 "5m" 或 "1h",默认 "5m" +func (a *Account) GetCacheTTLOverrideTarget() string { + if a.Extra == nil { + return "5m" + } + if v, ok := a.Extra["cache_ttl_override_target"]; ok { + if target, ok := v.(string); ok && (target == "5m" || target == "1h") { + return target + } + } + return "5m" +} + // GetWindowCostLimit 获取 5h 窗口费用阈值(美元) // 返回 0 表示未启用 func (a *Account) GetWindowCostLimit() float64 { diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 3cddd2c7..b301049f 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -35,7 +35,7 @@ type AccountRepository interface { Delete(ctx context.Context, id int64) error List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) ListActive(ctx context.Context) ([]Account, error) ListByPlatform(ctx context.Context, platform string) ([]Account, error) diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index 414b3678..a466b68a 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -79,7 +79,7 @@ func (s *accountRepoStub) List(ctx context.Context, params pagination.Pagination panic("unexpected List call") } -func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { panic("unexpected ListWithFilters call") } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 093f7d4d..a507efb4 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -12,13 +12,17 @@ import ( "io" "log" "net/http" + "net/url" "regexp" "strings" + "sync" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -32,6 +36,10 @@ const ( testClaudeAPIURL = "https://api.anthropic.com/v1/messages" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 + soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions" + soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine" + soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap" + soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check" ) // TestEvent represents a SSE event for account testing @@ -39,6 +47,9 @@ type TestEvent struct { Type string `json:"type"` Text string `json:"text,omitempty"` Model string `json:"model,omitempty"` + Status string `json:"status,omitempty"` + Code string `json:"code,omitempty"` + Data any `json:"data,omitempty"` Success bool `json:"success,omitempty"` Error string `json:"error,omitempty"` } @@ -50,8 +61,13 @@ type AccountTestService struct { antigravityGatewayService *AntigravityGatewayService httpUpstream HTTPUpstream cfg *config.Config + soraTestGuardMu sync.Mutex + soraTestLastRun map[int64]time.Time + soraTestCooldown time.Duration } +const defaultSoraTestCooldown = 10 * time.Second + // NewAccountTestService creates a new AccountTestService func NewAccountTestService( accountRepo AccountRepository, @@ -66,6 +82,8 @@ func NewAccountTestService( antigravityGatewayService: antigravityGatewayService, httpUpstream: httpUpstream, cfg: cfg, + soraTestLastRun: make(map[int64]time.Time), + soraTestCooldown: defaultSoraTestCooldown, } } @@ -467,13 +485,129 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account return s.processGeminiStream(c, resp.Body) } +type soraProbeStep struct { + Name string `json:"name"` + Status string `json:"status"` + HTTPStatus int `json:"http_status,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + Message string `json:"message,omitempty"` +} + +type soraProbeSummary struct { + Status string `json:"status"` + Steps []soraProbeStep `json:"steps"` +} + +type soraProbeRecorder struct { + steps []soraProbeStep +} + +func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) { + r.steps = append(r.steps, soraProbeStep{ + Name: name, + Status: status, + HTTPStatus: httpStatus, + ErrorCode: strings.TrimSpace(errorCode), + Message: strings.TrimSpace(message), + }) +} + +func (r *soraProbeRecorder) finalize() soraProbeSummary { + meSuccess := false + partial := false + for _, step := range r.steps { + if step.Name == "me" { + meSuccess = strings.EqualFold(step.Status, "success") + continue + } + if strings.EqualFold(step.Status, "failed") { + partial = true + } + } + + status := "success" + if !meSuccess { + status = "failed" + } else if partial { + status = "partial_success" + } + + return soraProbeSummary{ + Status: status, + Steps: append([]soraProbeStep(nil), r.steps...), + } +} + +func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) { + if rec == nil { + return + } + summary := rec.finalize() + code := "" + for _, step := range summary.Steps { + if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" { + code = step.ErrorCode + break + } + } + s.sendEvent(c, TestEvent{ + Type: "sora_test_result", + Status: summary.Status, + Code: code, + Data: summary, + }) +} + +func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) { + if accountID <= 0 { + return 0, true + } + s.soraTestGuardMu.Lock() + defer s.soraTestGuardMu.Unlock() + + if s.soraTestLastRun == nil { + s.soraTestLastRun = make(map[int64]time.Time) + } + cooldown := s.soraTestCooldown + if cooldown <= 0 { + cooldown = defaultSoraTestCooldown + } + + now := time.Now() + if lastRun, ok := s.soraTestLastRun[accountID]; ok { + elapsed := now.Sub(lastRun) + if elapsed < cooldown { + return cooldown - elapsed, false + } + } + s.soraTestLastRun[accountID] = now + return 0, true +} + +func ceilSeconds(d time.Duration) int { + if d <= 0 { + return 1 + } + sec := int(d / time.Second) + if d%time.Second != 0 { + sec++ + } + if sec < 1 { + sec = 1 + } + return sec +} + // testSoraAccountConnection 测试 Sora 账号的连接 // 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token) func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error { ctx := c.Request.Context() + recorder := &soraProbeRecorder{} authToken := account.GetCredential("access_token") if authToken == "" { + recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available") + s.emitSoraProbeSummary(c, recorder) return s.sendErrorAndEnd(c, "No access token available") } @@ -484,11 +618,20 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * c.Writer.Header().Set("X-Accel-Buffering", "no") c.Writer.Flush() + if wait, ok := s.acquireSoraTestPermit(account.ID); !ok { + msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait)) + recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg) + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, msg) + } + // Send test_start event s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"}) req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil) if err != nil { + recorder.addStep("me", "failed", 0, "request_build_failed", err.Error()) + s.emitSoraProbeSummary(c, recorder) return s.sendErrorAndEnd(c, "Failed to create request") } @@ -496,15 +639,21 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * req.Header.Set("Authorization", "Bearer "+authToken) req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") // Get proxy URL proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } + enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint() - resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint) if err != nil { + recorder.addStep("me", "failed", 0, "network_error", err.Error()) + s.emitSoraProbeSummary(c, recorder) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } defer func() { _ = resp.Body.Close() }() @@ -512,8 +661,33 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * body, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { - return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, string(body))) + if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { + recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected") + s.emitSoraProbeSummary(c, recorder) + s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body) + return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body)) + } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body) + switch { + case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"): + recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号") + case strings.EqualFold(upstreamCode, "unsupported_country_code"): + recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试") + case strings.TrimSpace(upstreamMessage) != "": + recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage) + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage)) + default: + recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512))) + } } + recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok") // 解析 /me 响应,提取用户信息 var meResp map[string]any @@ -531,10 +705,384 @@ func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account * s.sendEvent(c, TestEvent{Type: "content", Text: info}) } + // 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试) + subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil) + if err == nil { + subReq.Header.Set("Authorization", "Bearer "+authToken) + subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + subReq.Header.Set("Accept", "application/json") + subReq.Header.Set("Accept-Language", "en-US,en;q=0.9") + subReq.Header.Set("Origin", "https://sora.chatgpt.com") + subReq.Header.Set("Referer", "https://sora.chatgpt.com/") + + subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint) + if subErr != nil { + recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error()) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())}) + } else { + subBody, _ := io.ReadAll(subResp.Body) + _ = subResp.Body.Close() + if subResp.StatusCode == http.StatusOK { + recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok") + if summary := parseSoraSubscriptionSummary(subBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"}) + } + } else { + if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) { + recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected") + s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)}) + } else { + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody) + recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)}) + } + } + } + } + + // 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。 + s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder) + + s.emitSoraProbeSummary(c, recorder) s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) return nil } +func (s *AccountTestService) testSora2Capabilities( + c *gin.Context, + ctx context.Context, + account *Account, + authToken string, + proxyURL string, + enableTLSFingerprint bool, + recorder *soraProbeRecorder, +) { + inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraInviteMineURL, + proxyURL, + enableTLSFingerprint, + ) + if err != nil { + if recorder != nil { + recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error()) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())}) + return + } + + if inviteStatus == http.StatusUnauthorized { + bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraBootstrapURL, + proxyURL, + enableTLSFingerprint, + ) + if bootstrapErr == nil && bootstrapStatus == http.StatusOK { + if recorder != nil { + recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok") + } + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"}) + inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraInviteMineURL, + proxyURL, + enableTLSFingerprint, + ) + if err != nil { + if recorder != nil { + recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error()) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())}) + return + } + } else if recorder != nil { + code := "" + msg := "" + if bootstrapErr != nil { + code = "network_error" + msg = bootstrapErr.Error() + } + recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg) + } + } + + if inviteStatus != http.StatusOK { + if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) { + if recorder != nil { + recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected") + } + s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)}) + return + } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody) + if recorder != nil { + recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)}) + return + } + if recorder != nil { + recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok") + } + + if summary := parseSoraInviteSummary(inviteBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"}) + } + + remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraRemainingURL, + proxyURL, + enableTLSFingerprint, + ) + if remainingErr != nil { + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error()) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())}) + return + } + if remainingStatus != http.StatusOK { + if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) { + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected") + } + s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)}) + return + } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody) + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)}) + return + } + if recorder != nil { + recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok") + } + if summary := parseSoraRemainingSummary(remainingBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"}) + } +} + +func (s *AccountTestService) fetchSoraTestEndpoint( + ctx context.Context, + account *Account, + authToken string, + url string, + proxyURL string, + enableTLSFingerprint bool, +) (int, http.Header, []byte, error) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return 0, nil, nil, err + } + req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableTLSFingerprint) + if err != nil { + return 0, nil, nil, err + } + defer func() { _ = resp.Body.Close() }() + + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return resp.StatusCode, resp.Header, nil, readErr + } + return resp.StatusCode, resp.Header, body, nil +} + +func parseSoraSubscriptionSummary(body []byte) string { + var subResp struct { + Data []struct { + Plan struct { + ID string `json:"id"` + Title string `json:"title"` + } `json:"plan"` + EndTS string `json:"end_ts"` + } `json:"data"` + } + if err := json.Unmarshal(body, &subResp); err != nil { + return "" + } + if len(subResp.Data) == 0 { + return "" + } + + first := subResp.Data[0] + parts := make([]string, 0, 3) + if first.Plan.Title != "" { + parts = append(parts, first.Plan.Title) + } + if first.Plan.ID != "" { + parts = append(parts, first.Plan.ID) + } + if first.EndTS != "" { + parts = append(parts, "end="+first.EndTS) + } + if len(parts) == 0 { + return "" + } + return "Subscription: " + strings.Join(parts, " | ") +} + +func parseSoraInviteSummary(body []byte) string { + var inviteResp struct { + InviteCode string `json:"invite_code"` + RedeemedCount int64 `json:"redeemed_count"` + TotalCount int64 `json:"total_count"` + } + if err := json.Unmarshal(body, &inviteResp); err != nil { + return "" + } + + parts := []string{"Sora2: supported"} + if inviteResp.InviteCode != "" { + parts = append(parts, "invite="+inviteResp.InviteCode) + } + if inviteResp.TotalCount > 0 { + parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount)) + } + return strings.Join(parts, " | ") +} + +func parseSoraRemainingSummary(body []byte) string { + var remainingResp struct { + RateLimitAndCreditBalance struct { + EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"` + RateLimitReached bool `json:"rate_limit_reached"` + AccessResetsInSeconds int64 `json:"access_resets_in_seconds"` + } `json:"rate_limit_and_credit_balance"` + } + if err := json.Unmarshal(body, &remainingResp); err != nil { + return "" + } + info := remainingResp.RateLimitAndCreditBalance + parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)} + if info.RateLimitReached { + parts = append(parts, "rate_limited=true") + } + if info.AccessResetsInSeconds > 0 { + parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds)) + } + return strings.Join(parts, " | ") +} + +func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool { + if s == nil || s.cfg == nil { + return true + } + return !s.cfg.Sora.Client.DisableTLSFingerprint +} + +func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { + return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) +} + +func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string { + return soraerror.FormatCloudflareChallengeMessage(base, headers, body) +} + +func extractCloudflareRayID(headers http.Header, body []byte) string { + return soraerror.ExtractCloudflareRayID(headers, body) +} + +func extractSoraEgressIPHint(headers http.Header) string { + if headers == nil { + return "unknown" + } + candidates := []string{ + "x-openai-public-ip", + "x-envoy-external-address", + "cf-connecting-ip", + "x-forwarded-for", + } + for _, key := range candidates { + if value := strings.TrimSpace(headers.Get(key)); value != "" { + return value + } + } + return "unknown" +} + +func sanitizeProxyURLForLog(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + u, err := url.Parse(raw) + if err != nil { + return "" + } + if u.User != nil { + u.User = nil + } + return u.String() +} + +func endpointPathForLog(endpoint string) string { + parsed, err := url.Parse(strings.TrimSpace(endpoint)) + if err != nil || parsed.Path == "" { + return endpoint + } + return parsed.Path +} + +func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) { + accountID := int64(0) + platform := "" + proxyID := "none" + if account != nil { + accountID = account.ID + platform = account.Platform + if account.ProxyID != nil { + proxyID = fmt.Sprintf("%d", *account.ProxyID) + } + } + cfRay := extractCloudflareRayID(headers, body) + if cfRay == "" { + cfRay = "unknown" + } + log.Printf( + "[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s", + accountID, + platform, + endpoint, + endpointPathForLog(endpoint), + proxyID, + sanitizeProxyURLForLog(proxyURL), + cfRay, + extractSoraEgressIPHint(headers), + ) +} + +func truncateSoraErrorBody(body []byte, max int) string { + return soraerror.TruncateBody(body, max) +} + // testAntigravityAccountConnection tests an Antigravity account's connection // 支持 Claude 和 Gemini 两种协议,使用非流式请求 func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go new file mode 100644 index 00000000..3dfac786 --- /dev/null +++ b/backend/internal/service/account_test_service_sora_test.go @@ -0,0 +1,319 @@ +package service + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type queuedHTTPUpstream struct { + responses []*http.Response + requests []*http.Request + tlsFlags []bool +} + +func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return nil, fmt.Errorf("unexpected Do call") +} + +func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) { + u.requests = append(u.requests, req) + u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint) + if len(u.responses) == 0 { + return nil, fmt.Errorf("no mocked response") + } + resp := u.responses[0] + u.responses = u.responses[1:] + return resp, nil +} + +func newJSONResponse(status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func newJSONResponseWithHeader(status int, body, key, value string) *http.Response { + resp := newJSONResponse(status, body) + resp.Header.Set(key, value) + return resp +} + +func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil) + return c, rec +} + +func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), + newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`), + newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`), + newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`), + }, + } + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + TLSFingerprint: config.TLSFingerprintConfig{ + Enabled: true, + }, + }, + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + DisableTLSFingerprint: false, + }, + }, + }, + } + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + require.Len(t, upstream.requests, 4) + require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String()) + require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String()) + require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String()) + require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String()) + require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization")) + require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization")) + require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags) + + body := rec.Body.String() + require.Contains(t, body, `"type":"test_start"`) + require.Contains(t, body, "Sora connection OK - Email: demo@example.com") + require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z") + require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50") + require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s") + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"success"`) + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), + newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), + newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`), + newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + require.Len(t, upstream.requests, 4) + body := rec.Body.String() + require.Contains(t, body, "Sora connection OK - User: demo-user") + require.Contains(t, body, "Subscription check returned 403") + require.Contains(t, body, "Sora2 invite check returned 401") + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"partial_success"`) + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponseWithHeader(http.StatusForbidden, `Just a moment...`, "cf-ray", "9cff2d62d83bb98d"), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "Cloudflare challenge") + require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d") + body := rec.Body.String() + require.Contains(t, body, `"type":"error"`) + require.Contains(t, body, "Cloudflare challenge") + require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") +} + +func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponseWithHeader(http.StatusTooManyRequests, `Just a moment...`, "cf-mitigated", "challenge"), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "Cloudflare challenge") + require.Contains(t, err.Error(), "HTTP 429") + body := rec.Body.String() + require.Contains(t, body, "Cloudflare challenge") +} + +func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "token_invalidated") + body := rec.Body.String() + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"failed"`) + require.Contains(t, body, "token_invalidated") + require.NotContains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), + }, + } + svc := &AccountTestService{ + httpUpstream: upstream, + soraTestCooldown: time.Hour, + } + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c1, _ := newSoraTestContext() + err := svc.testSoraAccountConnection(c1, account) + require.NoError(t, err) + + c2, rec2 := newSoraTestContext() + err = svc.testSoraAccountConnection(c2, account) + require.Error(t, err) + require.Contains(t, err.Error(), "测试过于频繁") + body := rec2.Body.String() + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"code":"test_rate_limited"`) + require.Contains(t, body, `"status":"failed"`) + require.NotContains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), + newJSONResponse(http.StatusForbidden, `Just a moment...`), + newJSONResponse(http.StatusForbidden, `Just a moment...`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + body := rec.Body.String() + require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)") + require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)") + require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestSanitizeProxyURLForLog(t *testing.T) { + require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080")) + require.Equal(t, "", sanitizeProxyURLForLog("")) + require.Equal(t, "", sanitizeProxyURLForLog("://invalid")) +} + +func TestExtractSoraEgressIPHint(t *testing.T) { + h := make(http.Header) + h.Set("x-openai-public-ip", "203.0.113.10") + require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h)) + + h2 := make(http.Header) + h2.Set("x-envoy-external-address", "198.51.100.9") + require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2)) + + require.Equal(t, "unknown", extractSoraEgressIPHint(nil)) + require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{})) +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index f5130527..8614f24a 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -4,11 +4,15 @@ import ( "context" "errors" "fmt" + "io" + "net/http" "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" ) // AdminService interface defines admin management operations @@ -39,7 +43,7 @@ type AdminService interface { UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error // Account management - ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) + ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) GetAccount(ctx context.Context, id int64) (*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) @@ -65,6 +69,7 @@ type AdminService interface { GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) + CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) // Redeem code management ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) @@ -288,6 +293,32 @@ type ProxyTestResult struct { CountryCode string `json:"country_code,omitempty"` } +type ProxyQualityCheckResult struct { + ProxyID int64 `json:"proxy_id"` + Score int `json:"score"` + Grade string `json:"grade"` + Summary string `json:"summary"` + ExitIP string `json:"exit_ip,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` + BaseLatencyMs int64 `json:"base_latency_ms,omitempty"` + PassedCount int `json:"passed_count"` + WarnCount int `json:"warn_count"` + FailedCount int `json:"failed_count"` + ChallengeCount int `json:"challenge_count"` + CheckedAt int64 `json:"checked_at"` + Items []ProxyQualityCheckItem `json:"items"` +} + +type ProxyQualityCheckItem struct { + Target string `json:"target"` + Status string `json:"status"` // pass/warn/fail/challenge + HTTPStatus int `json:"http_status,omitempty"` + LatencyMs int64 `json:"latency_ms,omitempty"` + Message string `json:"message,omitempty"` + CFRay string `json:"cf_ray,omitempty"` +} + // ProxyExitInfo represents proxy exit information from ip-api.com type ProxyExitInfo struct { IP string @@ -302,6 +333,58 @@ type ProxyExitInfoProber interface { ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) } +type proxyQualityTarget struct { + Target string + URL string + Method string + AllowedStatuses map[int]struct{} +} + +var proxyQualityTargets = []proxyQualityTarget{ + { + Target: "openai", + URL: "https://api.openai.com/v1/models", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + }, + { + Target: "anthropic", + URL: "https://api.anthropic.com/v1/messages", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + http.StatusMethodNotAllowed: {}, + http.StatusNotFound: {}, + http.StatusBadRequest: {}, + }, + }, + { + Target: "gemini", + URL: "https://generativelanguage.googleapis.com/$discovery/rest?version=v1beta", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusOK: {}, + }, + }, + { + Target: "sora", + URL: "https://sora.chatgpt.com/backend/me", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + }, +} + +const ( + proxyQualityRequestTimeout = 15 * time.Second + proxyQualityResponseHeaderTimeout = 10 * time.Second + proxyQualityMaxBodyBytes = int64(8 * 1024) + proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" +) + // adminServiceImpl implements AdminService type adminServiceImpl struct { userRepo UserRepository @@ -1054,9 +1137,9 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates [] } // Account management implementations -func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) { +func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) + accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID) if err != nil { return nil, 0, err } @@ -1690,6 +1773,270 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR }, nil } +func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + result := &ProxyQualityCheckResult{ + ProxyID: id, + Score: 100, + Grade: "A", + CheckedAt: time.Now().Unix(), + Items: make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1), + } + + proxyURL := proxy.URL() + if s.proxyProber == nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "fail", + Message: "代理探测服务未配置", + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, nil) + return result, nil + } + + exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) + if err != nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "fail", + LatencyMs: latencyMs, + Message: err.Error(), + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, nil) + return result, nil + } + + result.ExitIP = exitInfo.IP + result.Country = exitInfo.Country + result.CountryCode = exitInfo.CountryCode + result.BaseLatencyMs = latencyMs + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "pass", + LatencyMs: latencyMs, + Message: "代理出口连通正常", + }) + result.PassedCount++ + + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: proxyQualityRequestTimeout, + ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout, + ProxyStrict: true, + }) + if err != nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "http_client", + Status: "fail", + Message: fmt.Sprintf("创建检测客户端失败: %v", err), + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) + return result, nil + } + + for _, target := range proxyQualityTargets { + item := runProxyQualityTarget(ctx, client, target) + result.Items = append(result.Items, item) + switch item.Status { + case "pass": + result.PassedCount++ + case "warn": + result.WarnCount++ + case "challenge": + result.ChallengeCount++ + default: + result.FailedCount++ + } + } + + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) + return result, nil +} + +func runProxyQualityTarget(ctx context.Context, client *http.Client, target proxyQualityTarget) ProxyQualityCheckItem { + item := ProxyQualityCheckItem{ + Target: target.Target, + } + + req, err := http.NewRequestWithContext(ctx, target.Method, target.URL, nil) + if err != nil { + item.Status = "fail" + item.Message = fmt.Sprintf("构建请求失败: %v", err) + return item + } + req.Header.Set("Accept", "application/json,text/html,*/*") + req.Header.Set("User-Agent", proxyQualityClientUserAgent) + + start := time.Now() + resp, err := client.Do(req) + if err != nil { + item.Status = "fail" + item.LatencyMs = time.Since(start).Milliseconds() + item.Message = fmt.Sprintf("请求失败: %v", err) + return item + } + defer func() { _ = resp.Body.Close() }() + item.LatencyMs = time.Since(start).Milliseconds() + item.HTTPStatus = resp.StatusCode + + body, readErr := io.ReadAll(io.LimitReader(resp.Body, proxyQualityMaxBodyBytes+1)) + if readErr != nil { + item.Status = "fail" + item.Message = fmt.Sprintf("读取响应失败: %v", readErr) + return item + } + if int64(len(body)) > proxyQualityMaxBodyBytes { + body = body[:proxyQualityMaxBodyBytes] + } + + if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { + item.Status = "challenge" + item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body) + item.Message = "Sora 命中 Cloudflare challenge" + return item + } + + if _, ok := target.AllowedStatuses[resp.StatusCode]; ok { + if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices { + item.Status = "pass" + item.Message = fmt.Sprintf("HTTP %d", resp.StatusCode) + } else { + item.Status = "warn" + item.Message = fmt.Sprintf("HTTP %d(目标可达,但鉴权或方法受限)", resp.StatusCode) + } + return item + } + + if resp.StatusCode == http.StatusTooManyRequests { + item.Status = "warn" + item.Message = "目标返回 429,可能存在频控" + return item + } + + item.Status = "fail" + item.Message = fmt.Sprintf("非预期状态码: %d", resp.StatusCode) + return item +} + +func finalizeProxyQualityResult(result *ProxyQualityCheckResult) { + if result == nil { + return + } + score := 100 - result.WarnCount*10 - result.FailedCount*22 - result.ChallengeCount*30 + if score < 0 { + score = 0 + } + result.Score = score + result.Grade = proxyQualityGrade(score) + result.Summary = fmt.Sprintf( + "通过 %d 项,告警 %d 项,失败 %d 项,挑战 %d 项", + result.PassedCount, + result.WarnCount, + result.FailedCount, + result.ChallengeCount, + ) +} + +func proxyQualityGrade(score int) string { + switch { + case score >= 90: + return "A" + case score >= 75: + return "B" + case score >= 60: + return "C" + case score >= 40: + return "D" + default: + return "F" + } +} + +func proxyQualityOverallStatus(result *ProxyQualityCheckResult) string { + if result == nil { + return "" + } + if result.ChallengeCount > 0 { + return "challenge" + } + if result.FailedCount > 0 { + return "failed" + } + if result.WarnCount > 0 { + return "warn" + } + if result.PassedCount > 0 { + return "healthy" + } + return "failed" +} + +func proxyQualityFirstCFRay(result *ProxyQualityCheckResult) string { + if result == nil { + return "" + } + for _, item := range result.Items { + if item.CFRay != "" { + return item.CFRay + } + } + return "" +} + +func proxyQualityBaseConnectivityPass(result *ProxyQualityCheckResult) bool { + if result == nil { + return false + } + for _, item := range result.Items { + if item.Target == "base_connectivity" { + return item.Status == "pass" + } + } + return false +} + +func (s *adminServiceImpl) saveProxyQualitySnapshot(ctx context.Context, proxyID int64, result *ProxyQualityCheckResult, exitInfo *ProxyExitInfo) { + if result == nil { + return + } + score := result.Score + checkedAt := result.CheckedAt + info := &ProxyLatencyInfo{ + Success: proxyQualityBaseConnectivityPass(result), + Message: result.Summary, + QualityStatus: proxyQualityOverallStatus(result), + QualityScore: &score, + QualityGrade: result.Grade, + QualitySummary: result.Summary, + QualityCheckedAt: &checkedAt, + QualityCFRay: proxyQualityFirstCFRay(result), + UpdatedAt: time.Now(), + } + if result.BaseLatencyMs > 0 { + latency := result.BaseLatencyMs + info.LatencyMs = &latency + } + if exitInfo != nil { + info.IPAddress = exitInfo.IP + info.Country = exitInfo.Country + info.CountryCode = exitInfo.CountryCode + info.Region = exitInfo.Region + info.City = exitInfo.City + } + s.saveProxyLatency(ctx, proxyID, info) +} + func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) { if s.proxyProber == nil || proxy == nil { return @@ -1800,6 +2147,11 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro proxies[i].CountryCode = info.CountryCode proxies[i].Region = info.Region proxies[i].City = info.City + proxies[i].QualityStatus = info.QualityStatus + proxies[i].QualityScore = info.QualityScore + proxies[i].QualityGrade = info.QualityGrade + proxies[i].QualitySummary = info.QualitySummary + proxies[i].QualityChecked = info.QualityCheckedAt } } @@ -1807,7 +2159,27 @@ func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, if s.proxyLatencyCache == nil || info == nil { return } - if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil { + + merged := *info + if latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, []int64{proxyID}); err == nil { + if existing := latencies[proxyID]; existing != nil { + if merged.QualityCheckedAt == nil && + merged.QualityScore == nil && + merged.QualityGrade == "" && + merged.QualityStatus == "" && + merged.QualitySummary == "" && + merged.QualityCFRay == "" { + merged.QualityStatus = existing.QualityStatus + merged.QualityScore = existing.QualityScore + merged.QualityGrade = existing.QualityGrade + merged.QualitySummary = existing.QualitySummary + merged.QualityCheckedAt = existing.QualityCheckedAt + merged.QualityCFRay = existing.QualityCFRay + } + } + } + + if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, &merged); err != nil { logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err) } } diff --git a/backend/internal/service/admin_service_proxy_quality_test.go b/backend/internal/service/admin_service_proxy_quality_test.go new file mode 100644 index 00000000..5a43cd9c --- /dev/null +++ b/backend/internal/service/admin_service_proxy_quality_test.go @@ -0,0 +1,95 @@ +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) { + result := &ProxyQualityCheckResult{ + PassedCount: 2, + WarnCount: 1, + FailedCount: 1, + ChallengeCount: 1, + } + + finalizeProxyQualityResult(result) + + require.Equal(t, 38, result.Score) + require.Equal(t, "F", result.Grade) + require.Contains(t, result.Summary, "通过 2 项") + require.Contains(t, result.Summary, "告警 1 项") + require.Contains(t, result.Summary, "失败 1 项") + require.Contains(t, result.Summary, "挑战 1 项") +} + +func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("cf-ray", "test-ray-123") + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte("Just a moment...")) + })) + defer server.Close() + + target := proxyQualityTarget{ + Target: "sora", + URL: server.URL, + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + } + + item := runProxyQualityTarget(context.Background(), server.Client(), target) + require.Equal(t, "challenge", item.Status) + require.Equal(t, http.StatusForbidden, item.HTTPStatus) + require.Equal(t, "test-ray-123", item.CFRay) +} + +func TestRunProxyQualityTarget_AllowedStatusPass(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models":[]}`)) + })) + defer server.Close() + + target := proxyQualityTarget{ + Target: "gemini", + URL: server.URL, + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusOK: {}, + }, + } + + item := runProxyQualityTarget(context.Background(), server.Client(), target) + require.Equal(t, "pass", item.Status) + require.Equal(t, http.StatusOK, item.HTTPStatus) +} + +func TestRunProxyQualityTarget_AllowedStatusWarnForUnauthorized(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + })) + defer server.Close() + + target := proxyQualityTarget{ + Target: "openai", + URL: server.URL, + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + } + + item := runProxyQualityTarget(context.Background(), server.Client(), target) + require.Equal(t, "warn", item.Status) + require.Equal(t, http.StatusUnauthorized, item.HTTPStatus) + require.Contains(t, item.Message, "目标可达") +} diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go index d661b710..ff58fd01 100644 --- a/backend/internal/service/admin_service_search_test.go +++ b/backend/internal/service/admin_service_search_test.go @@ -24,7 +24,7 @@ type accountRepoStubForAdminList struct { listWithFiltersErr error } -func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { s.listWithFiltersCalls++ s.listWithFiltersParams = params s.listWithFiltersPlatform = platform @@ -168,7 +168,7 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc") + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0) require.NoError(t, err) require.Equal(t, int64(10), total) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index ed33d992..cf87b282 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -4117,6 +4117,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 { usage.CacheCreationInputTokens = int(v) } + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + if cc, ok := u["cache_creation"].(map[string]any); ok { + if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok { + usage.CacheCreation5mTokens = int(v) + } + if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok { + usage.CacheCreation1hTokens = int(v) + } + } } // extractClaudeUsage 从非流式 Claude 响应提取 usage @@ -4139,6 +4148,15 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage if v, ok := u["cache_creation_input_tokens"].(float64); ok { usage.CacheCreationInputTokens = int(v) } + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + if cc, ok := u["cache_creation"].(map[string]any); ok { + if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok { + usage.CacheCreation5mTokens = int(v) + } + if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok { + usage.CacheCreation1hTokens = int(v) + } + } } return usage } diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index e6660399..f100be0b 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -31,8 +31,8 @@ type ModelPricing struct { OutputPricePerToken float64 // 每token输出价格 (USD) CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) - CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退 - CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退 + CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD) + CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD) SupportsCacheBreakdown bool // 是否支持详细的缓存分类 } @@ -172,12 +172,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { if s.pricingService != nil { litellmPricing := s.pricingService.GetModelPricing(model) if litellmPricing != nil { + // 启用 5m/1h 分类计费的条件: + // 1. 存在 1h 价格 + // 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费) + price5m := litellmPricing.CacheCreationInputTokenCost + price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr + enableBreakdown := price1h > 0 && price1h > price5m return &ModelPricing{ InputPricePerToken: litellmPricing.InputCostPerToken, OutputPricePerToken: litellmPricing.OutputCostPerToken, CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, - SupportsCacheBreakdown: false, + CacheCreation5mPrice: price5m, + CacheCreation1hPrice: price1h, + SupportsCacheBreakdown: enableBreakdown, }, nil } } @@ -209,9 +217,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul // 计算缓存费用 if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { - // 支持详细缓存分类的模型(5分钟/1小时缓存) - breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice + - float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice + // 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token) + if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 { + // API 未返回 ephemeral 明细,回退到全部按 5m 单价计费 + breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice + } else { + breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice + + float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice + } } else { // 标准缓存创建价格(per-token) breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken @@ -280,10 +293,12 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage // 范围内部分:正常计费 inRangeTokens := UsageTokens{ - InputTokens: inRangeInputTokens, - OutputTokens: tokens.OutputTokens, // 输出只算一次 - CacheCreationTokens: tokens.CacheCreationTokens, - CacheReadTokens: inRangeCacheTokens, + InputTokens: inRangeInputTokens, + OutputTokens: tokens.OutputTokens, // 输出只算一次 + CacheCreationTokens: tokens.CacheCreationTokens, + CacheReadTokens: inRangeCacheTokens, + CacheCreation5mTokens: tokens.CacheCreation5mTokens, + CacheCreation1hTokens: tokens.CacheCreation1hTokens, } inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) if err != nil { diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index bd173b96..5eb278f6 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -399,8 +399,8 @@ func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) { InputPricePerToken: 3e-6, OutputPricePerToken: 15e-6, SupportsCacheBreakdown: true, - CacheCreation5mPrice: 4.0, // per million tokens - CacheCreation1hPrice: 5.0, // per million tokens + CacheCreation5mPrice: 4e-6, // per token + CacheCreation1hPrice: 5e-6, // per token }, }, } @@ -414,8 +414,8 @@ func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) { cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) require.NoError(t, err) - expected5m := float64(100000) / 1_000_000 * 4.0 - expected1h := float64(50000) / 1_000_000 * 5.0 + expected5m := float64(tokens.CacheCreation5mTokens) * 4e-6 + expected1h := float64(tokens.CacheCreation1hTokens) * 5e-6 require.InDelta(t, expected5m+expected1h, cost.CacheCreationCost, 1e-10) } diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go index dd58c183..d7108c8d 100644 --- a/backend/internal/service/gateway_beta_test.go +++ b/backend/internal/service/gateway_beta_test.go @@ -21,3 +21,72 @@ func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) { ) require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got) } + +func TestStripBetaToken(t *testing.T) { + tests := []struct { + name string + header string + token string + want string + }{ + { + name: "token in middle", + header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "token at start", + header: "context-1m-2025-08-07,oauth-2025-04-20,interleaved-thinking-2025-05-14", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "token at end", + header: "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "token not present", + header: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "empty header", + header: "", + token: "context-1m-2025-08-07", + want: "", + }, + { + name: "with spaces", + header: "oauth-2025-04-20, context-1m-2025-08-07 , interleaved-thinking-2025-05-14", + token: "context-1m-2025-08-07", + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "only token", + header: "context-1m-2025-08-07", + token: "context-1m-2025-08-07", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripBetaToken(tt.header, tt.token) + require.Equal(t, tt.want, got) + }) + } +} + +func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) { + required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"} + incoming := "context-1m-2025-08-07,foo-beta,oauth-2025-04-20" + drop := map[string]struct{}{"context-1m-2025-08-07": {}} + + got := mergeAnthropicBetaDropping(required, incoming, drop) + require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got) + require.NotContains(t, got, "context-1m-2025-08-07") +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index c7104fde..70d5068b 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -92,7 +92,7 @@ func (m *mockAccountRepoForPlatform) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForPlatform) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForPlatform) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForPlatform) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 0502d352..063a5ae6 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -349,6 +349,8 @@ type ClaudeUsage struct { OutputTokens int `json:"output_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"` + CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象) + CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象) } // ForwardResult 转发结果 @@ -373,9 +375,10 @@ type ForwardResult struct { // UpstreamFailoverError indicates an upstream error that should trigger account failover. type UpstreamFailoverError struct { StatusCode int - ResponseBody []byte // 上游响应体,用于错误透传规则匹配 - ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true - RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 + ResponseBody []byte // 上游响应体,用于错误透传规则匹配 + ResponseHeaders http.Header // 上游响应头,用于透传 cf-ray/cf-mitigated/content-type 等诊断信息 + ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true + RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 } func (e *UpstreamFailoverError) Error() string { @@ -3580,12 +3583,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // messages requests typically use only oauth + interleaved-thinking. // Also drop claude-code beta if a downstream client added it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - drop := map[string]struct{}{claude.BetaClaudeCode: {}} + drop := map[string]struct{}{claude.BetaClaudeCode: {}, claude.BetaContext1M: {}} req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := req.Header.Get("anthropic-beta") - req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader)) + req.Header.Set("anthropic-beta", stripBetaToken(s.getBetaHeader(modelID, clientBetaHeader), claude.BetaContext1M)) } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) @@ -3739,6 +3742,23 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str return strings.Join(out, ",") } +// stripBetaToken removes a single beta token from a comma-separated header value. +// It short-circuits when the token is not present to avoid unnecessary allocations. +func stripBetaToken(header, token string) string { + if !strings.Contains(header, token) { + return header + } + out := make([]string, 0, 8) + for _, p := range strings.Split(header, ",") { + p = strings.TrimSpace(p) + if p == "" || p == token { + continue + } + out = append(out, p) + } + return strings.Join(out, ",") +} + // applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. // This mirrors opencode-anthropic-auth behavior: do not trust downstream // headers when using Claude Code-scoped OAuth credentials. @@ -4305,6 +4325,23 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } } + // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 + if account.IsCacheTTLOverrideEnabled() { + overrideTarget := account.GetCacheTTLOverrideTarget() + if eventType == "message_start" { + if msg, ok := event["message"].(map[string]any); ok { + if u, ok := msg["usage"].(map[string]any); ok { + rewriteCacheCreationJSON(u, overrideTarget) + } + } + } + if eventType == "message_delta" { + if u, ok := event["usage"].(map[string]any); ok { + rewriteCacheCreationJSON(u, overrideTarget) + } + } + } + if needModelReplace { if msg, ok := event["message"].(map[string]any); ok { if model, ok := msg["model"].(string); ok && model == mappedModel { @@ -4432,6 +4469,14 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { usage.InputTokens = msgStart.Message.Usage.InputTokens usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens + + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens") + cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() || cc1h.Exists() { + usage.CacheCreation5mTokens = int(cc5m.Int()) + usage.CacheCreation1hTokens = int(cc1h.Int()) + } } // 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API) @@ -4460,6 +4505,68 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { if msgDelta.Usage.CacheReadInputTokens > 0 { usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens } + + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens") + cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() && cc5m.Int() > 0 { + usage.CacheCreation5mTokens = int(cc5m.Int()) + } + if cc1h.Exists() && cc1h.Int() > 0 { + usage.CacheCreation1hTokens = int(cc1h.Int()) + } + } +} + +// applyCacheTTLOverride 将所有 cache creation tokens 归入指定的 TTL 类型。 +// target 为 "5m" 或 "1h"。返回 true 表示发生了变更。 +func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool { + // Fallback: 如果只有聚合字段但无 5m/1h 明细,将聚合字段归入 5m 默认类别 + if usage.CacheCreation5mTokens == 0 && usage.CacheCreation1hTokens == 0 && usage.CacheCreationInputTokens > 0 { + usage.CacheCreation5mTokens = usage.CacheCreationInputTokens + } + + total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens + if total == 0 { + return false + } + switch target { + case "1h": + if usage.CacheCreation1hTokens == total { + return false // 已经全是 1h + } + usage.CacheCreation1hTokens = total + usage.CacheCreation5mTokens = 0 + default: // "5m" + if usage.CacheCreation5mTokens == total { + return false // 已经全是 5m + } + usage.CacheCreation5mTokens = total + usage.CacheCreation1hTokens = 0 + } + return true +} + +// rewriteCacheCreationJSON 在 JSON usage 对象中重写 cache_creation 嵌套对象的 TTL 分类。 +// usageObj 是 usage JSON 对象(map[string]any)。 +func rewriteCacheCreationJSON(usageObj map[string]any, target string) { + ccObj, ok := usageObj["cache_creation"].(map[string]any) + if !ok { + return + } + v5m, _ := ccObj["ephemeral_5m_input_tokens"].(float64) + v1h, _ := ccObj["ephemeral_1h_input_tokens"].(float64) + total := v5m + v1h + if total == 0 { + return + } + switch target { + case "1h": + ccObj["ephemeral_1h_input_tokens"] = total + ccObj["ephemeral_5m_input_tokens"] = float64(0) + default: // "5m" + ccObj["ephemeral_5m_input_tokens"] = total + ccObj["ephemeral_1h_input_tokens"] = float64(0) } } @@ -4491,6 +4598,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h return nil, fmt.Errorf("parse response: %w", err) } + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens") + cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() || cc1h.Exists() { + response.Usage.CacheCreation5mTokens = int(cc5m.Int()) + response.Usage.CacheCreation1hTokens = int(cc1h.Int()) + } + // 兼容 Kimi cached_tokens → cache_read_input_tokens if response.Usage.CacheReadInputTokens == 0 { cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() @@ -4502,6 +4617,20 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } + // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 + if account.IsCacheTTLOverrideEnabled() { + overrideTarget := account.GetCacheTTLOverrideTarget() + if applyCacheTTLOverride(&response.Usage, overrideTarget) { + // 同步更新 body JSON 中的嵌套 cache_creation 对象 + if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil { + body = newBody + } + if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens", response.Usage.CacheCreation1hTokens); err == nil { + body = newBody + } + } + } + // 如果有模型映射,替换响应中的model字段 if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) @@ -4570,6 +4699,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu result.Usage.InputTokens = 0 } + // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + cacheTTLOverridden := false + if account.IsCacheTTLOverrideEnabled() { + applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) + cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 + } + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { @@ -4617,10 +4753,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } else { // Token 计费 tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) @@ -4658,6 +4796,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, InputCost: cost.InputCost, OutputCost: cost.OutputCost, CacheCreationCost: cost.CacheCreationCost, @@ -4673,6 +4813,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu ImageCount: result.ImageCount, ImageSize: imageSize, MediaType: mediaType, + CacheTTLOverridden: cacheTTLOverridden, CreatedAt: time.Now(), } @@ -4773,6 +4914,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * result.Usage.InputTokens = 0 } + // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + cacheTTLOverridden := false + if account.IsCacheTTLOverrideEnabled() { + applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) + cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 + } + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { @@ -4803,10 +4951,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } else { // Token 计费(使用长上下文计费方法) tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) @@ -4840,6 +4990,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, InputCost: cost.InputCost, OutputCost: cost.OutputCost, CacheCreationCost: cost.CacheCreationCost, @@ -4854,6 +5006,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: imageSize, + CacheTTLOverridden: cacheTTLOverridden, CreatedAt: time.Now(), } @@ -5170,7 +5323,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con incomingBeta := req.Header.Get("anthropic-beta") requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} - req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta)) + drop := map[string]struct{}{claude.BetaContext1M: {}} + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) } else { clientBetaHeader := req.Header.Get("anthropic-beta") if clientBetaHeader == "" { @@ -5180,7 +5334,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if !strings.Contains(beta, claude.BetaTokenCounting) { beta = beta + "," + claude.BetaTokenCounting } - req.Header.Set("anthropic-beta", beta) + req.Header.Set("anthropic-beta", stripBetaToken(beta, claude.BetaContext1M)) } } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go index 50b998a3..cd690cbd 100644 --- a/backend/internal/service/gateway_streaming_test.go +++ b/backend/internal/service/gateway_streaming_test.go @@ -79,6 +79,22 @@ func TestParseSSEUsage_DeltaOverwritesWithNonZero(t *testing.T) { require.Equal(t, 60, usage.CacheReadInputTokens) } +func TestParseSSEUsage_DeltaDoesNotResetCacheCreationBreakdown(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 先在 message_start 中写入非零 5m/1h 明细 + svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation":{"ephemeral_5m_input_tokens":30,"ephemeral_1h_input_tokens":70}}}}`, usage) + require.Equal(t, 30, usage.CacheCreation5mTokens) + require.Equal(t, 70, usage.CacheCreation1hTokens) + + // 后续 delta 带默认 0,不应覆盖已有非零值 + svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":12,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0}}}`, usage) + require.Equal(t, 30, usage.CacheCreation5mTokens, "delta 的 0 值不应重置 5m 明细") + require.Equal(t, 70, usage.CacheCreation1hTokens, "delta 的 0 值不应重置 1h 明细") + require.Equal(t, 12, usage.OutputTokens) +} + func TestParseSSEUsage_InvalidJSON(t *testing.T) { svc := newMinimalGatewayService() usage := &ClaudeUsage{} diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 2d596f33..86bc9476 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -79,7 +79,7 @@ func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { +func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) { diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index e247e654..6f6261d8 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -14,6 +14,7 @@ import ( type OpenAIOAuthClient interface { ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) + RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) } // ClaudeOAuthClient handles HTTP requests for Claude OAuth flows diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 5764788a..16befb82 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -99,13 +99,19 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran result.Modified = true } - if _, ok := reqBody["max_output_tokens"]; ok { - delete(reqBody, "max_output_tokens") - result.Modified = true - } - if _, ok := reqBody["max_completion_tokens"]; ok { - delete(reqBody, "max_completion_tokens") - result.Modified = true + // Strip parameters unsupported by codex models via the Responses API. + for _, key := range []string{ + "max_output_tokens", + "max_completion_tokens", + "temperature", + "top_p", + "frequency_penalty", + "presence_penalty", + } { + if _, ok := reqBody[key]; ok { + delete(reqBody, key) + result.Modified = true + } } if normalizeCodexTools(reqBody) { diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index ca7470b9..087ad4ec 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -2,13 +2,20 @@ package service import ( "context" + "crypto/subtle" + "encoding/json" + "io" "net/http" + "net/url" + "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) +var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" + // OpenAIOAuthService handles OpenAI OAuth authentication flows type OpenAIOAuthService struct { sessionStore *openai.SessionStore @@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 type OpenAIExchangeCodeInput struct { SessionID string Code string + State string RedirectURI string ProxyID *int64 } @@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch if !ok { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired") } + if input.State == "" { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required") + } + if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state") + } // Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL proxyURL := session.ProxyURL @@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch // RefreshToken refreshes an OpenAI OAuth token func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) { - tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") +} + +// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id. +func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) { + tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) if err != nil { return nil, err } @@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri return tokenInfo, nil } -// RefreshAccountToken refreshes token for an OpenAI account -func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { - if !account.IsOpenAI() { - return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account") +// ExchangeSoraSessionToken exchanges Sora session_token to access_token. +func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) { + if strings.TrimSpace(sessionToken) == "" { + return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required") } - refreshToken := account.GetOpenAIRefreshToken() + proxyURL, err := s.resolveProxyURL(ctx, proxyID) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil) + if err != nil { + return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err) + } + req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + + client := newOpenAIOAuthHTTPClient(proxyURL) + resp, err := client.Do(req) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if resp.StatusCode != http.StatusOK { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var sessionResp struct { + AccessToken string `json:"accessToken"` + Expires string `json:"expires"` + User struct { + Email string `json:"email"` + Name string `json:"name"` + } `json:"user"` + } + if err := json.Unmarshal(body, &sessionResp); err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err) + } + if strings.TrimSpace(sessionResp.AccessToken) == "" { + return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token") + } + + expiresAt := time.Now().Add(time.Hour).Unix() + if strings.TrimSpace(sessionResp.Expires) != "" { + if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil { + expiresAt = parsed.Unix() + } + } + expiresIn := expiresAt - time.Now().Unix() + if expiresIn < 0 { + expiresIn = 0 + } + + return &OpenAITokenInfo{ + AccessToken: strings.TrimSpace(sessionResp.AccessToken), + ExpiresIn: expiresIn, + ExpiresAt: expiresAt, + Email: strings.TrimSpace(sessionResp.User.Email), + }, nil +} + +// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account +func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { + if account.Platform != PlatformOpenAI && account.Platform != PlatformSora { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account") + } + if account.Type != AccountTypeOAuth { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account") + } + + refreshToken := account.GetCredential("refresh_token") if refreshToken == "" { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") } @@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A } } - return s.RefreshToken(ctx, refreshToken, proxyURL) + clientID := account.GetCredential("client_id") + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) } // BuildAccountCredentials builds credentials map from token info @@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) func (s *OpenAIOAuthService) Stop() { s.sessionStore.Stop() } + +func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) { + if proxyID == nil { + return "", nil + } + proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) + if err != nil { + return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err) + } + if proxy == nil { + return "", nil + } + return proxy.URL(), nil +} + +func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client { + transport := &http.Transport{} + if strings.TrimSpace(proxyURL) != "" { + if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" { + transport.Proxy = http.ProxyURL(parsed) + } + } + return &http.Client{ + Timeout: 120 * time.Second, + Transport: transport, + } +} diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go new file mode 100644 index 00000000..fb76f6c1 --- /dev/null +++ b/backend/internal/service/openai_oauth_service_sora_session_test.go @@ -0,0 +1,69 @@ +package service + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientNoopStub struct{} + +func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "at-token", info.AccessToken) + require.Equal(t, "demo@example.com", info.Email) + require.Greater(t, info.ExpiresAt, int64(0)) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + _, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "missing access token") +} diff --git a/backend/internal/service/openai_oauth_service_state_test.go b/backend/internal/service/openai_oauth_service_state_test.go new file mode 100644 index 00000000..0a2a195f --- /dev/null +++ b/backend/internal/service/openai_oauth_service_state_test.go @@ -0,0 +1,102 @@ +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientStateStub struct { + exchangeCalled int32 +} + +func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { + atomic.AddInt32(&s.exchangeCalled, 1) + return &openai.TokenResponse{ + AccessToken: "at", + RefreshToken: "rt", + ExpiresIn: 3600, + }, nil +} + +func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return s.RefreshToken(ctx, refreshToken, proxyURL) +} + +func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "oauth state is required") + require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) +} + +func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + State: "wrong-state", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid oauth state") + require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) +} + +func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + State: "expected-state", + }) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "at", info.AccessToken) + require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled)) + + _, ok := svc.sessionStore.Get("sid") + require.False(t, ok) +} diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index 3842f0a4..a8a6b96c 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -157,7 +157,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } expiresAt = account.GetCredentialAsTime("expires_at") if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if p.openAIOAuthService == nil { + if account.Platform == PlatformSora { + slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID) + // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 + refreshFailed = true + } else if p.openAIOAuthService == nil { slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) p.metrics.refreshFailure.Add(1) refreshFailed = true // 无法刷新,标记失败 @@ -206,7 +210,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou // 仅在 expires_at 已过期/接近过期时才执行无锁刷新 if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if p.openAIOAuthService == nil { + if account.Platform == PlatformSora { + slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID) + // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 + refreshFailed = true + } else if p.openAIOAuthService == nil { slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) p.metrics.refreshFailure.Add(1) refreshFailed = true diff --git a/backend/internal/service/ops_concurrency.go b/backend/internal/service/ops_concurrency.go index f6541d08..92b37e73 100644 --- a/backend/internal/service/ops_concurrency.go +++ b/backend/internal/service/ops_concurrency.go @@ -24,7 +24,7 @@ func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter s accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{ Page: page, PageSize: opsAccountsPageSize, - }, platformFilter, "", "", "") + }, platformFilter, "", "", "", 0) if err != nil { return nil, err } diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index ee979d84..41e8b5eb 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -28,14 +28,15 @@ var ( // LiteLLMModelPricing LiteLLM价格数据结构 // 只保留我们需要的字段,使用指针来处理可能缺失的值 type LiteLLMModelPricing struct { - InputCostPerToken float64 `json:"input_cost_per_token"` - OutputCostPerToken float64 `json:"output_cost_per_token"` - CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` - CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` - LiteLLMProvider string `json:"litellm_provider"` - Mode string `json:"mode"` - SupportsPromptCaching bool `json:"supports_prompt_caching"` - OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 + InputCostPerToken float64 `json:"input_cost_per_token"` + OutputCostPerToken float64 `json:"output_cost_per_token"` + CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` + CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"` + CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` + LiteLLMProvider string `json:"litellm_provider"` + Mode string `json:"mode"` + SupportsPromptCaching bool `json:"supports_prompt_caching"` + OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 } // PricingRemoteClient 远程价格数据获取接口 @@ -46,14 +47,15 @@ type PricingRemoteClient interface { // LiteLLMRawEntry 用于解析原始JSON数据 type LiteLLMRawEntry struct { - InputCostPerToken *float64 `json:"input_cost_per_token"` - OutputCostPerToken *float64 `json:"output_cost_per_token"` - CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` - CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` - LiteLLMProvider string `json:"litellm_provider"` - Mode string `json:"mode"` - SupportsPromptCaching bool `json:"supports_prompt_caching"` - OutputCostPerImage *float64 `json:"output_cost_per_image"` + InputCostPerToken *float64 `json:"input_cost_per_token"` + OutputCostPerToken *float64 `json:"output_cost_per_token"` + CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` + CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"` + CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` + LiteLLMProvider string `json:"litellm_provider"` + Mode string `json:"mode"` + SupportsPromptCaching bool `json:"supports_prompt_caching"` + OutputCostPerImage *float64 `json:"output_cost_per_image"` } // PricingService 动态价格服务 @@ -319,6 +321,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel if entry.CacheCreationInputTokenCost != nil { pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost } + if entry.CacheCreationInputTokenCostAbove1hr != nil { + pricing.CacheCreationInputTokenCostAbove1hr = *entry.CacheCreationInputTokenCostAbove1hr + } if entry.CacheReadInputTokenCost != nil { pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost } diff --git a/backend/internal/service/proxy.go b/backend/internal/service/proxy.go index 7eb7728f..fc449091 100644 --- a/backend/internal/service/proxy.go +++ b/backend/internal/service/proxy.go @@ -40,6 +40,11 @@ type ProxyWithAccountCount struct { CountryCode string Region string City string + QualityStatus string + QualityScore *int + QualityGrade string + QualitySummary string + QualityChecked *int64 } type ProxyAccountSummary struct { diff --git a/backend/internal/service/proxy_latency_cache.go b/backend/internal/service/proxy_latency_cache.go index 4a1cc77b..f54bff88 100644 --- a/backend/internal/service/proxy_latency_cache.go +++ b/backend/internal/service/proxy_latency_cache.go @@ -6,15 +6,21 @@ import ( ) type ProxyLatencyInfo struct { - Success bool `json:"success"` - LatencyMs *int64 `json:"latency_ms,omitempty"` - Message string `json:"message,omitempty"` - IPAddress string `json:"ip_address,omitempty"` - Country string `json:"country,omitempty"` - CountryCode string `json:"country_code,omitempty"` - Region string `json:"region,omitempty"` - City string `json:"city,omitempty"` - UpdatedAt time.Time `json:"updated_at"` + Success bool `json:"success"` + LatencyMs *int64 `json:"latency_ms,omitempty"` + Message string `json:"message,omitempty"` + IPAddress string `json:"ip_address,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` + Region string `json:"region,omitempty"` + City string `json:"city,omitempty"` + QualityStatus string `json:"quality_status,omitempty"` + QualityScore *int `json:"quality_score,omitempty"` + QualityGrade string `json:"quality_grade,omitempty"` + QualitySummary string `json:"quality_summary,omitempty"` + QualityCheckedAt *int64 `json:"quality_checked_at,omitempty"` + QualityCFRay string `json:"quality_cf_ray,omitempty"` + UpdatedAt time.Time `json:"updated_at"` } type ProxyLatencyCache interface { diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 12c48ab8..b1d767fc 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -381,10 +381,31 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head } } - // 2. 尝试从响应头解析重置时间(Anthropic) + // 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口 + if result := calculateAnthropic429ResetTime(headers); result != nil { + if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil { + slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) + return + } + + // 更新 session window:优先使用 5h-reset 头精确计算,否则从 resetAt 反推 + windowEnd := result.resetAt + if result.fiveHourReset != nil { + windowEnd = *result.fiveHourReset + } + windowStart := windowEnd.Add(-5 * time.Hour) + if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil { + slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err) + } + + slog.Info("anthropic_account_rate_limited", "account_id", account.ID, "reset_at", result.resetAt, "reset_in", time.Until(result.resetAt).Truncate(time.Second)) + return + } + + // 3. 尝试从响应头解析重置时间(Anthropic 聚合头,向后兼容) resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset") - // 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini) + // 4. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini) if resetTimestamp == "" { switch account.Platform { case PlatformOpenAI: @@ -497,6 +518,112 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim return nil } +// anthropic429Result holds the parsed Anthropic 429 rate-limit information. +type anthropic429Result struct { + resetAt time.Time // The correct reset time to use for SetRateLimited + fiveHourReset *time.Time // 5h window reset timestamp (for session window calculation), nil if not available +} + +// calculateAnthropic429ResetTime parses Anthropic's per-window rate-limit headers +// to determine which window (5h or 7d) actually triggered the 429. +// +// Headers used: +// - anthropic-ratelimit-unified-5h-utilization / anthropic-ratelimit-unified-5h-surpassed-threshold +// - anthropic-ratelimit-unified-5h-reset +// - anthropic-ratelimit-unified-7d-utilization / anthropic-ratelimit-unified-7d-surpassed-threshold +// - anthropic-ratelimit-unified-7d-reset +// +// Returns nil when the per-window headers are absent (caller should fall back to +// the aggregated anthropic-ratelimit-unified-reset header). +func calculateAnthropic429ResetTime(headers http.Header) *anthropic429Result { + reset5hStr := headers.Get("anthropic-ratelimit-unified-5h-reset") + reset7dStr := headers.Get("anthropic-ratelimit-unified-7d-reset") + + if reset5hStr == "" && reset7dStr == "" { + return nil + } + + var reset5h, reset7d *time.Time + if ts, err := strconv.ParseInt(reset5hStr, 10, 64); err == nil { + t := time.Unix(ts, 0) + reset5h = &t + } + if ts, err := strconv.ParseInt(reset7dStr, 10, 64); err == nil { + t := time.Unix(ts, 0) + reset7d = &t + } + + is5hExceeded := isAnthropicWindowExceeded(headers, "5h") + is7dExceeded := isAnthropicWindowExceeded(headers, "7d") + + slog.Info("anthropic_429_window_analysis", + "is_5h_exceeded", is5hExceeded, + "is_7d_exceeded", is7dExceeded, + "reset_5h", reset5hStr, + "reset_7d", reset7dStr, + ) + + // Select the correct reset time based on which window(s) are exceeded. + var chosen *time.Time + switch { + case is5hExceeded && is7dExceeded: + // Both exceeded → prefer 7d (longer cooldown), fall back to 5h + chosen = reset7d + if chosen == nil { + chosen = reset5h + } + case is5hExceeded: + chosen = reset5h + case is7dExceeded: + chosen = reset7d + default: + // Neither flag clearly exceeded — pick the sooner reset as best guess + chosen = pickSooner(reset5h, reset7d) + } + + if chosen == nil { + return nil + } + return &anthropic429Result{resetAt: *chosen, fiveHourReset: reset5h} +} + +// isAnthropicWindowExceeded checks whether a given Anthropic rate-limit window +// (e.g. "5h" or "7d") has been exceeded, using utilization and surpassed-threshold headers. +func isAnthropicWindowExceeded(headers http.Header, window string) bool { + prefix := "anthropic-ratelimit-unified-" + window + "-" + + // Check surpassed-threshold first (most explicit signal) + if st := headers.Get(prefix + "surpassed-threshold"); strings.EqualFold(st, "true") { + return true + } + + // Fall back to utilization >= 1.0 + if utilStr := headers.Get(prefix + "utilization"); utilStr != "" { + if util, err := strconv.ParseFloat(utilStr, 64); err == nil && util >= 1.0-1e-9 { + // Use a small epsilon to handle floating point: treat 0.9999999... as >= 1.0 + return true + } + } + + return false +} + +// pickSooner returns whichever of the two time pointers is earlier. +// If only one is non-nil, it is returned. If both are nil, returns nil. +func pickSooner(a, b *time.Time) *time.Time { + switch { + case a != nil && b != nil: + if a.Before(*b) { + return a + } + return b + case a != nil: + return a + default: + return b + } +} + // parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳 // OpenAI 的 usage_limit_reached 错误格式: // diff --git a/backend/internal/service/ratelimit_service_anthropic_test.go b/backend/internal/service/ratelimit_service_anthropic_test.go new file mode 100644 index 00000000..eaeaf30e --- /dev/null +++ b/backend/internal/service/ratelimit_service_anthropic_test.go @@ -0,0 +1,202 @@ +package service + +import ( + "net/http" + "testing" + "time" +) + +func TestCalculateAnthropic429ResetTime_Only5hExceeded(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.02") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.32") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) + + if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) { + t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset) + } +} + +func TestCalculateAnthropic429ResetTime_Only7dExceeded(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.50") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.05") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1771549200) + + // fiveHourReset should still be populated for session window calculation + if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) { + t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset) + } +} + +func TestCalculateAnthropic429ResetTime_BothExceeded(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.10") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.02") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1771549200) +} + +func TestCalculateAnthropic429ResetTime_NoPerWindowHeaders(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + if result != nil { + t.Errorf("expected nil result when no per-window headers, got resetAt=%v", result.resetAt) + } +} + +func TestCalculateAnthropic429ResetTime_NoHeaders(t *testing.T) { + result := calculateAnthropic429ResetTime(http.Header{}) + if result != nil { + t.Errorf("expected nil result for empty headers, got resetAt=%v", result.resetAt) + } +} + +func TestCalculateAnthropic429ResetTime_SurpassedThreshold(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "false") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_UtilizationExactlyOne(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.0") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.5") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_NeitherExceeded_UsesShorter(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.95") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") // sooner + headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.80") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") // later + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_Only5hResetHeader(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.05") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_Only7dResetHeader(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.03") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1771549200) + + if result.fiveHourReset != nil { + t.Errorf("expected fiveHourReset=nil when no 5h headers, got %v", result.fiveHourReset) + } +} + +func TestIsAnthropicWindowExceeded(t *testing.T) { + tests := []struct { + name string + headers http.Header + window string + expected bool + }{ + { + name: "utilization above 1.0", + headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.02"), + window: "5h", + expected: true, + }, + { + name: "utilization exactly 1.0", + headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.0"), + window: "5h", + expected: true, + }, + { + name: "utilization below 1.0", + headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "0.99"), + window: "5h", + expected: false, + }, + { + name: "surpassed-threshold true", + headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "true"), + window: "7d", + expected: true, + }, + { + name: "surpassed-threshold True (case insensitive)", + headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "True"), + window: "7d", + expected: true, + }, + { + name: "surpassed-threshold false", + headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "false"), + window: "7d", + expected: false, + }, + { + name: "no headers", + headers: http.Header{}, + window: "5h", + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := isAnthropicWindowExceeded(tc.headers, tc.window) + if got != tc.expected { + t.Errorf("expected %v, got %v", tc.expected, got) + } + }) + } +} + +// assertAnthropicResult is a test helper that verifies the result is non-nil and +// has the expected resetAt unix timestamp. +func assertAnthropicResult(t *testing.T, result *anthropic429Result, wantUnix int64) { + t.Helper() + if result == nil { + t.Fatal("expected non-nil result") + return // unreachable, but satisfies staticcheck SA5011 + } + want := time.Unix(wantUnix, 0) + if !result.resetAt.Equal(want) { + t.Errorf("expected resetAt=%v, got %v", want, result.resetAt) + } +} + +func makeHeader(key, value string) http.Header { + h := http.Header{} + h.Set(key, value) + return h +} diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go index de097d5e..7cecfa03 100644 --- a/backend/internal/service/sora_client.go +++ b/backend/internal/service/sora_client.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "hash/fnv" "io" "log" "math/rand" @@ -17,12 +18,16 @@ import ( "net/textproto" "net/url" "path" + "sort" "strconv" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" + openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" "github.com/google/uuid" "github.com/tidwall/gjson" "golang.org/x/crypto/sha3" @@ -34,6 +39,11 @@ const ( soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)" ) +var ( + soraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" + soraOAuthTokenURL = "https://auth.openai.com/oauth/token" +) + const ( soraPowMaxIteration = 500000 ) @@ -86,9 +96,20 @@ var soraDesktopUserAgents = []string{ "Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", } +var soraMobileUserAgents = []string{ + "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)", + "Sora/1.2026.007 (Android 14; SM-G998B; build 2600700)", + "Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)", + "Sora/1.2026.007 (Android 14; Pixel 7; build 2600700)", + "Sora/1.2026.007 (Android 15; 2211133C; build 2600700)", + "Sora/1.2026.007 (Android 14; SM-S918B; build 2600700)", + "Sora/1.2026.007 (Android 15; OnePlus 12; build 2600700)", +} + var soraRand = rand.New(rand.NewSource(time.Now().UnixNano())) var soraRandMu sync.Mutex var soraPerfStart = time.Now() +var soraPowTokenGenerator = soraGetPowToken // SoraClient 定义直连 Sora 的任务操作接口。 type SoraClient interface { @@ -96,6 +117,18 @@ type SoraClient interface { UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) + CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) + UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) + GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) + DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) + UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) + FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) + SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error + DeleteCharacter(ctx context.Context, account *Account, characterID string) error + PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) + DeletePost(ctx context.Context, account *Account, postID string) error + GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) + EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) } @@ -117,6 +150,17 @@ type SoraVideoRequest struct { Size string MediaID string RemixTargetID string + CameoIDs []string +} + +// SoraStoryboardRequest 分镜视频生成请求参数 +type SoraStoryboardRequest struct { + Prompt string + Orientation string + Frames int + Model string + Size string + MediaID string } // SoraImageTaskStatus 图片任务状态 @@ -130,11 +174,32 @@ type SoraImageTaskStatus struct { // SoraVideoTaskStatus 视频任务状态 type SoraVideoTaskStatus struct { - ID string - Status string - ProgressPct int - URLs []string - ErrorMsg string + ID string + Status string + ProgressPct int + URLs []string + GenerationID string + ErrorMsg string +} + +// SoraCameoStatus 角色处理中间态 +type SoraCameoStatus struct { + Status string + StatusMessage string + DisplayNameHint string + UsernameHint string + ProfileAssetURL string + InstructionSetHint any + InstructionSet any +} + +// SoraCharacterFinalizeRequest 角色定稿请求参数 +type SoraCharacterFinalizeRequest struct { + CameoID string + Username string + DisplayName string + ProfileAssetPointer string + InstructionSet any } // SoraUpstreamError 上游错误 @@ -157,26 +222,110 @@ func (e *SoraUpstreamError) Error() string { // SoraDirectClient 直连 Sora 实现 type SoraDirectClient struct { - cfg *config.Config - httpUpstream HTTPUpstream - tokenProvider *OpenAITokenProvider + cfg *config.Config + httpUpstream HTTPUpstream + tokenProvider *OpenAITokenProvider + accountRepo AccountRepository + soraAccountRepo SoraAccountRepository + baseURL string + challengeCooldownMu sync.RWMutex + challengeCooldowns map[string]soraChallengeCooldownEntry + sidecarSessionMu sync.RWMutex + sidecarSessions map[string]soraSidecarSessionEntry +} + +type soraRequestTraceContextKey struct{} + +type soraRequestTrace struct { + ID string + ProxyKey string + UAHash string } // NewSoraDirectClient 创建 Sora 直连客户端 func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient { - return &SoraDirectClient{ - cfg: cfg, - httpUpstream: httpUpstream, - tokenProvider: tokenProvider, + baseURL := "" + if cfg != nil { + rawBaseURL := strings.TrimRight(strings.TrimSpace(cfg.Sora.Client.BaseURL), "/") + baseURL = normalizeSoraBaseURL(rawBaseURL) + if rawBaseURL != "" && baseURL != rawBaseURL { + log.Printf("[SoraClient] normalized base_url from %s to %s", sanitizeSoraLogURL(rawBaseURL), sanitizeSoraLogURL(baseURL)) + } } + return &SoraDirectClient{ + cfg: cfg, + httpUpstream: httpUpstream, + tokenProvider: tokenProvider, + baseURL: baseURL, + challengeCooldowns: make(map[string]soraChallengeCooldownEntry), + sidecarSessions: make(map[string]soraSidecarSessionEntry), + } +} + +func (c *SoraDirectClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) { + if c == nil { + return + } + c.accountRepo = accountRepo + c.soraAccountRepo = soraAccountRepo } // Enabled 判断是否启用 Sora 直连 func (c *SoraDirectClient) Enabled() bool { - if c == nil || c.cfg == nil { + if c == nil { return false } - return strings.TrimSpace(c.cfg.Sora.Client.BaseURL) != "" + if strings.TrimSpace(c.baseURL) != "" { + return true + } + if c.cfg == nil { + return false + } + return strings.TrimSpace(normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)) != "" +} + +// PreflightCheck 在创建任务前执行账号能力预检。 +// 当前仅对视频模型执行 /nf/check 预检,用于提前识别额度耗尽或能力缺失。 +func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error { + if modelCfg.Type != "video" { + return nil + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Accept", "application/json") + body, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false) + if err != nil { + var upstreamErr *SoraUpstreamError + if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound { + return &SoraUpstreamError{ + StatusCode: http.StatusForbidden, + Message: "当前账号未开通 Sora2 能力或无可用配额", + Headers: upstreamErr.Headers, + Body: upstreamErr.Body, + } + } + return err + } + + rateLimitReached := gjson.GetBytes(body, "rate_limit_and_credit_balance.rate_limit_reached").Bool() + remaining := gjson.GetBytes(body, "rate_limit_and_credit_balance.estimated_num_videos_remaining") + if rateLimitReached || (remaining.Exists() && remaining.Int() <= 0) { + msg := "当前账号 Sora2 可用配额不足" + if requestedModel != "" { + msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel) + } + return &SoraUpstreamError{ + StatusCode: http.StatusTooManyRequests, + Message: msg, + Headers: http.Header{}, + } + } + return nil } func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { @@ -187,6 +336,8 @@ func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, da if err != nil { return "", err } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) if filename == "" { filename = "image.png" } @@ -213,10 +364,10 @@ func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, da return "", err } - headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers := c.buildBaseHeaders(token, userAgent) headers.Set("Content-Type", writer.FormDataContentType()) - respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/uploads"), headers, &body, false) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/uploads"), headers, &body, false) if err != nil { return "", err } @@ -232,6 +383,9 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account if err != nil { return "", err } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) operation := "simple_compose" inpaintItems := []map[string]any{} if strings.TrimSpace(req.MediaID) != "" { @@ -252,7 +406,7 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account "n_frames": 1, "inpaint_items": inpaintItems, } - headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers := c.buildBaseHeaders(token, userAgent) headers.Set("Content-Type", "application/json") headers.Set("Origin", "https://sora.chatgpt.com") headers.Set("Referer", "https://sora.chatgpt.com/") @@ -261,13 +415,13 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account if err != nil { return "", err } - sentinel, err := c.generateSentinelToken(ctx, account, token) + sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) if err != nil { return "", err } headers.Set("openai-sentinel-token", sentinel) - respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true) if err != nil { return "", err } @@ -283,6 +437,9 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account if err != nil { return "", err } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) orientation := req.Orientation if orientation == "" { orientation = "landscape" @@ -320,9 +477,12 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account payload["remix_target_id"] = req.RemixTargetID payload["cameo_ids"] = []string{} payload["cameo_replacements"] = map[string]any{} + } else if len(req.CameoIDs) > 0 { + payload["cameo_ids"] = req.CameoIDs + payload["cameo_replacements"] = map[string]any{} } - headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + headers := c.buildBaseHeaders(token, userAgent) headers.Set("Content-Type", "application/json") headers.Set("Origin", "https://sora.chatgpt.com") headers.Set("Referer", "https://sora.chatgpt.com/") @@ -330,13 +490,13 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account if err != nil { return "", err } - sentinel, err := c.generateSentinelToken(ctx, account, token) + sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) if err != nil { return "", err } headers.Set("openai-sentinel-token", sentinel) - respBody, _, err := c.doRequest(ctx, account, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true) if err != nil { return "", err } @@ -347,6 +507,469 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account return taskID, nil } +func (c *SoraDirectClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) + orientation := req.Orientation + if orientation == "" { + orientation = "landscape" + } + nFrames := req.Frames + if nFrames <= 0 { + nFrames = 450 + } + model := req.Model + if model == "" { + model = "sy_8" + } + size := req.Size + if size == "" { + size = "small" + } + + inpaintItems := []map[string]any{} + if strings.TrimSpace(req.MediaID) != "" { + inpaintItems = append(inpaintItems, map[string]any{ + "kind": "upload", + "upload_id": req.MediaID, + }) + } + payload := map[string]any{ + "kind": "video", + "prompt": req.Prompt, + "title": "Draft your video", + "orientation": orientation, + "size": size, + "n_frames": nFrames, + "storyboard_id": nil, + "inpaint_items": inpaintItems, + "remix_target_id": nil, + "model": model, + "metadata": nil, + "style_id": nil, + "cameo_ids": nil, + "cameo_replacements": nil, + "audio_caption": nil, + "audio_transcript": nil, + "video_caption": nil, + } + + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) + if err != nil { + return "", err + } + headers.Set("openai-sentinel-token", sentinel) + + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create/storyboard"), headers, bytes.NewReader(body), true) + if err != nil { + return "", err + } + taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) + if taskID == "" { + return "", errors.New("storyboard task response missing id") + } + return taskID, nil +} + +func (c *SoraDirectClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) { + if len(data) == 0 { + return "", errors.New("empty video data") + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + partHeader := make(textproto.MIMEHeader) + partHeader.Set("Content-Disposition", `form-data; name="file"; filename="video.mp4"`) + partHeader.Set("Content-Type", "video/mp4") + part, err := writer.CreatePart(partHeader) + if err != nil { + return "", err + } + if _, err := part.Write(data); err != nil { + return "", err + } + if err := writer.WriteField("timestamps", "0,3"); err != nil { + return "", err + } + if err := writer.Close(); err != nil { + return "", err + } + + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", writer.FormDataContentType()) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/upload"), headers, &body, false) + if err != nil { + return "", err + } + cameoID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) + if cameoID == "" { + return "", errors.New("character upload response missing id") + } + return cameoID, nil +} + +func (c *SoraDirectClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + respBody, _, err := c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodGet, + c.buildURL("/project_y/cameos/in_progress/"+strings.TrimSpace(cameoID)), + headers, + nil, + false, + ) + if err != nil { + return nil, err + } + return &SoraCameoStatus{ + Status: strings.TrimSpace(gjson.GetBytes(respBody, "status").String()), + StatusMessage: strings.TrimSpace(gjson.GetBytes(respBody, "status_message").String()), + DisplayNameHint: strings.TrimSpace(gjson.GetBytes(respBody, "display_name_hint").String()), + UsernameHint: strings.TrimSpace(gjson.GetBytes(respBody, "username_hint").String()), + ProfileAssetURL: strings.TrimSpace(gjson.GetBytes(respBody, "profile_asset_url").String()), + InstructionSetHint: gjson.GetBytes(respBody, "instruction_set_hint").Value(), + InstructionSet: gjson.GetBytes(respBody, "instruction_set").Value(), + }, nil +} + +func (c *SoraDirectClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Accept", "image/*,*/*;q=0.8") + + respBody, _, err := c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodGet, + strings.TrimSpace(imageURL), + headers, + nil, + false, + ) + if err != nil { + return nil, err + } + return respBody, nil +} + +func (c *SoraDirectClient) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) { + if len(data) == 0 { + return "", errors.New("empty character image") + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + partHeader := make(textproto.MIMEHeader) + partHeader.Set("Content-Disposition", `form-data; name="file"; filename="profile.webp"`) + partHeader.Set("Content-Type", "image/webp") + part, err := writer.CreatePart(partHeader) + if err != nil { + return "", err + } + if _, err := part.Write(data); err != nil { + return "", err + } + if err := writer.WriteField("use_case", "profile"); err != nil { + return "", err + } + if err := writer.Close(); err != nil { + return "", err + } + + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", writer.FormDataContentType()) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/file/upload"), headers, &body, false) + if err != nil { + return "", err + } + assetPointer := strings.TrimSpace(gjson.GetBytes(respBody, "asset_pointer").String()) + if assetPointer == "" { + return "", errors.New("character image upload response missing asset_pointer") + } + return assetPointer, nil +} + +func (c *SoraDirectClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) + payload := map[string]any{ + "cameo_id": req.CameoID, + "username": req.Username, + "display_name": req.DisplayName, + "profile_asset_pointer": req.ProfileAssetPointer, + "instruction_set": nil, + "safety_instruction_set": nil, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/finalize"), headers, bytes.NewReader(body), false) + if err != nil { + return "", err + } + characterID := strings.TrimSpace(gjson.GetBytes(respBody, "character.character_id").String()) + if characterID == "" { + return "", errors.New("character finalize response missing character_id") + } + return characterID, nil +} + +func (c *SoraDirectClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + payload := map[string]any{"visibility": "public"} + body, err := json.Marshal(payload) + if err != nil { + return err + } + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + _, _, err = c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodPost, + c.buildURL("/project_y/cameos/by_id/"+strings.TrimSpace(cameoID)+"/update_v2"), + headers, + bytes.NewReader(body), + false, + ) + return err +} + +func (c *SoraDirectClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + _, _, err = c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodDelete, + c.buildURL("/project_y/characters/"+strings.TrimSpace(characterID)), + headers, + nil, + false, + ) + return err +} + +func (c *SoraDirectClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) + payload := map[string]any{ + "attachments_to_create": []map[string]any{ + { + "generation_id": generationID, + "kind": "sora", + }, + }, + "post_text": "", + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) + if err != nil { + return "", err + } + headers.Set("openai-sentinel-token", sentinel) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/post"), headers, bytes.NewReader(body), true) + if err != nil { + return "", err + } + postID := strings.TrimSpace(gjson.GetBytes(respBody, "post.id").String()) + if postID == "" { + return "", errors.New("watermark-free publish response missing post.id") + } + return postID, nil +} + +func (c *SoraDirectClient) DeletePost(ctx context.Context, account *Account, postID string) error { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + _, _, err = c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodDelete, + c.buildURL("/project_y/post/"+strings.TrimSpace(postID)), + headers, + nil, + false, + ) + return err +} + +func (c *SoraDirectClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) { + parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/") + if parseURL == "" { + return "", errors.New("custom parse url is required") + } + if strings.TrimSpace(parseToken) == "" { + return "", errors.New("custom parse token is required") + } + shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID) + payload := map[string]any{ + "url": shareURL, + "token": strings.TrimSpace(parseToken), + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + proxyURL := c.resolveProxyURL(account) + accountID := int64(0) + accountConcurrency := 0 + if account != nil { + accountID = account.ID + accountConcurrency = account.Concurrency + } + var resp *http.Response + if c.httpUpstream != nil { + resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency) + } else { + resp, err = http.DefaultClient.Do(req) + } + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) + if err != nil { + return "", err + } + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256)) + } + downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String()) + if downloadLink == "" { + return "", errors.New("custom parse response missing download_link") + } + return downloadLink, nil +} + +func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + if strings.TrimSpace(expansionLevel) == "" { + expansionLevel = "medium" + } + if durationS <= 0 { + durationS = 10 + } + + payload := map[string]any{ + "prompt": prompt, + "expansion_level": expansionLevel, + "duration_s": durationS, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Accept", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false) + if err != nil { + return "", err + } + enhancedPrompt := strings.TrimSpace(gjson.GetBytes(respBody, "enhanced_prompt").String()) + if enhancedPrompt == "" { + return "", errors.New("enhance_prompt response missing enhanced_prompt") + } + return enhancedPrompt, nil +} + func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit()) if err != nil { @@ -373,12 +996,14 @@ func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Ac if err != nil { return nil, false, err } - headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) if limit <= 0 { limit = 20 } endpoint := fmt.Sprintf("/v2/recent_tasks?limit=%d", limit) - respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL(endpoint), headers, nil, false) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL(endpoint), headers, nil, false) if err != nil { return nil, false, err } @@ -435,9 +1060,11 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t if err != nil { return nil, err } - headers := c.buildBaseHeaders(token, c.defaultUserAgent()) + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) - respBody, _, err := c.doRequest(ctx, account, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, nil, false) if err != nil { return nil, err } @@ -466,7 +1093,7 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t } } - respBody, _, err = c.doRequest(ctx, account, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false) + respBody, _, err = c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false) if err != nil { return nil, err } @@ -475,6 +1102,7 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t if draft.Get("task_id").String() != taskID { return true } + generationID := strings.TrimSpace(draft.Get("id").String()) kind := strings.TrimSpace(draft.Get("kind").String()) reason := strings.TrimSpace(draft.Get("reason_str").String()) if reason == "" { @@ -491,15 +1119,17 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t msg = "Content violates guardrails" } draftFound = &SoraVideoTaskStatus{ - ID: taskID, - Status: "failed", - ErrorMsg: msg, + ID: taskID, + Status: "failed", + GenerationID: generationID, + ErrorMsg: msg, } } else { draftFound = &SoraVideoTaskStatus{ - ID: taskID, - Status: "completed", - URLs: []string{urlStr}, + ID: taskID, + Status: "completed", + GenerationID: generationID, + URLs: []string{urlStr}, } } return false @@ -512,9 +1142,10 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t } func (c *SoraDirectClient) buildURL(endpoint string) string { - base := "" - if c != nil && c.cfg != nil { - base = strings.TrimRight(strings.TrimSpace(c.cfg.Sora.Client.BaseURL), "/") + base := strings.TrimRight(strings.TrimSpace(c.baseURL), "/") + if base == "" && c != nil && c.cfg != nil { + base = normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL) + c.baseURL = base } if base == "" { return endpoint @@ -536,18 +1167,278 @@ func (c *SoraDirectClient) defaultUserAgent() string { return ua } +func (c *SoraDirectClient) taskUserAgent() string { + if c != nil && c.cfg != nil { + if ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent); ua != "" { + return ua + } + } + if len(soraMobileUserAgents) > 0 { + return soraMobileUserAgents[soraRandInt(len(soraMobileUserAgents))] + } + if len(soraDesktopUserAgents) > 0 { + return soraDesktopUserAgents[soraRandInt(len(soraDesktopUserAgents))] + } + return soraDefaultUserAgent +} + +func (c *SoraDirectClient) resolveProxyURL(account *Account) string { + if account == nil || account.ProxyID == nil || account.Proxy == nil { + return "" + } + return strings.TrimSpace(account.Proxy.URL()) +} + func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) (string, error) { if account == nil { return "", errors.New("account is nil") } - if c.tokenProvider != nil { - return c.tokenProvider.GetAccessToken(ctx, account) + + allowProvider := c.allowOpenAITokenProvider(account) + var providerErr error + if allowProvider && c.tokenProvider != nil { + token, err := c.tokenProvider.GetAccessToken(ctx, account) + if err == nil && strings.TrimSpace(token) != "" { + c.logTokenSource(account, "openai_token_provider") + return token, nil + } + providerErr = err + if err != nil && c.debugEnabled() { + c.debugLogf( + "token_provider_failed account_id=%d platform=%s err=%s", + account.ID, + account.Platform, + logredact.RedactText(err.Error()), + ) + } } token := strings.TrimSpace(account.GetCredential("access_token")) - if token == "" { - return "", errors.New("access_token not found") + if token != "" { + expiresAt := account.GetCredentialAsTime("expires_at") + if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute { + refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring") + if refreshErr == nil && strings.TrimSpace(refreshed) != "" { + c.logTokenSource(account, "refresh_token_recovered") + return refreshed, nil + } + if refreshErr != nil && c.debugEnabled() { + c.debugLogf("token_refresh_before_use_failed account_id=%d err=%s", account.ID, logredact.RedactText(refreshErr.Error())) + } + } + c.logTokenSource(account, "account_credentials") + return token, nil } - return token, nil + + recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing") + if recoverErr == nil && strings.TrimSpace(recovered) != "" { + c.logTokenSource(account, "session_or_refresh_recovered") + return recovered, nil + } + if recoverErr != nil && c.debugEnabled() { + c.debugLogf("token_recover_failed account_id=%d platform=%s err=%s", account.ID, account.Platform, logredact.RedactText(recoverErr.Error())) + } + if providerErr != nil { + return "", providerErr + } + if c.tokenProvider != nil && !allowProvider { + c.logTokenSource(account, "account_credentials(provider_disabled)") + } + return "", errors.New("access_token not found") +} + +func (c *SoraDirectClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + + if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" { + accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken) + if err == nil && strings.TrimSpace(accessToken) != "" { + c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken) + c.logTokenRecover(account, "session_token", reason, true, nil) + return accessToken, nil + } + c.logTokenRecover(account, "session_token", reason, false, err) + } + + refreshToken := strings.TrimSpace(account.GetCredential("refresh_token")) + if refreshToken == "" { + return "", errors.New("session_token/refresh_token not found") + } + accessToken, newRefreshToken, expiresAt, err := c.exchangeRefreshToken(ctx, account, refreshToken) + if err != nil { + c.logTokenRecover(account, "refresh_token", reason, false, err) + return "", err + } + if strings.TrimSpace(accessToken) == "" { + return "", errors.New("refreshed access_token is empty") + } + c.applyRecoveredToken(ctx, account, accessToken, newRefreshToken, expiresAt, "") + c.logTokenRecover(account, "refresh_token", reason, true, nil) + return accessToken, nil +} + +func (c *SoraDirectClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) { + headers := http.Header{} + headers.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken) + headers.Set("Accept", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + headers.Set("User-Agent", c.defaultUserAgent()) + body, _, err := c.doRequest(ctx, account, http.MethodGet, soraSessionAuthURL, headers, nil, false) + if err != nil { + return "", "", err + } + accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String()) + if accessToken == "" { + return "", "", errors.New("session exchange missing accessToken") + } + expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String()) + return accessToken, expiresAt, nil +} + +func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Account, refreshToken string) (string, string, string, error) { + clientIDs := []string{ + strings.TrimSpace(account.GetCredential("client_id")), + openaioauth.SoraClientID, + openaioauth.ClientID, + } + tried := make(map[string]struct{}, len(clientIDs)) + var lastErr error + + for _, clientID := range clientIDs { + if clientID == "" { + continue + } + if _, ok := tried[clientID]; ok { + continue + } + tried[clientID] = struct{}{} + + formData := url.Values{} + formData.Set("client_id", clientID) + formData.Set("grant_type", "refresh_token") + formData.Set("refresh_token", refreshToken) + formData.Set("redirect_uri", "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback") + headers := http.Header{} + headers.Set("Accept", "application/json") + headers.Set("Content-Type", "application/x-www-form-urlencoded") + headers.Set("User-Agent", c.defaultUserAgent()) + + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, strings.NewReader(formData.Encode()), false) + if err != nil { + lastErr = err + if c.debugEnabled() { + c.debugLogf("refresh_token_exchange_failed account_id=%d client_id=%s err=%s", account.ID, clientID, logredact.RedactText(err.Error())) + } + continue + } + accessToken := strings.TrimSpace(gjson.GetBytes(respBody, "access_token").String()) + if accessToken == "" { + lastErr = errors.New("oauth refresh response missing access_token") + continue + } + newRefreshToken := strings.TrimSpace(gjson.GetBytes(respBody, "refresh_token").String()) + expiresIn := gjson.GetBytes(respBody, "expires_in").Int() + expiresAt := "" + if expiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339) + } + return accessToken, newRefreshToken, expiresAt, nil + } + + if lastErr != nil { + return "", "", "", lastErr + } + return "", "", "", errors.New("no available client_id for refresh_token exchange") +} + +func (c *SoraDirectClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) { + if account == nil { + return + } + if account.Credentials == nil { + account.Credentials = make(map[string]any) + } + if strings.TrimSpace(accessToken) != "" { + account.Credentials["access_token"] = accessToken + } + if strings.TrimSpace(refreshToken) != "" { + account.Credentials["refresh_token"] = refreshToken + } + if strings.TrimSpace(expiresAt) != "" { + account.Credentials["expires_at"] = expiresAt + } + if strings.TrimSpace(sessionToken) != "" { + account.Credentials["session_token"] = sessionToken + } + + if c.accountRepo != nil { + if err := c.accountRepo.Update(ctx, account); err != nil { + if c.debugEnabled() { + c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) + } + } + } + c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken) +} + +func (c *SoraDirectClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) { + if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 { + return + } + updates := make(map[string]any) + if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" { + updates["access_token"] = accessToken + updates["refresh_token"] = refreshToken + } + if strings.TrimSpace(sessionToken) != "" { + updates["session_token"] = sessionToken + } + if len(updates) == 0 { + return + } + if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() { + c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) + } +} + +func (c *SoraDirectClient) logTokenRecover(account *Account, source, reason string, success bool, err error) { + if !c.debugEnabled() || account == nil { + return + } + if success { + c.debugLogf("token_recover_success account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason) + return + } + if err == nil { + c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason) + return + } + c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s err=%s", account.ID, account.Platform, source, reason, logredact.RedactText(err.Error())) +} + +func (c *SoraDirectClient) allowOpenAITokenProvider(account *Account) bool { + if c == nil || c.tokenProvider == nil { + return false + } + if account != nil && account.Platform == PlatformSora { + return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider + } + return true +} + +func (c *SoraDirectClient) logTokenSource(account *Account, source string) { + if !c.debugEnabled() || account == nil { + return + } + c.debugLogf( + "token_selected account_id=%d platform=%s account_type=%s source=%s", + account.ID, + account.Platform, + account.Type, + source, + ) } func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header { @@ -570,9 +1461,30 @@ func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header } func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, method, urlStr string, headers http.Header, body io.Reader, allowRetry bool) ([]byte, http.Header, error) { + return c.doRequestWithProxy(ctx, account, c.resolveProxyURL(account), method, urlStr, headers, body, allowRetry) +} + +func (c *SoraDirectClient) doRequestWithProxy( + ctx context.Context, + account *Account, + proxyURL string, + method, + urlStr string, + headers http.Header, + body io.Reader, + allowRetry bool, +) ([]byte, http.Header, error) { if strings.TrimSpace(urlStr) == "" { return nil, nil, errors.New("empty upstream url") } + proxyURL = strings.TrimSpace(proxyURL) + if proxyURL == "" { + proxyURL = c.resolveProxyURL(account) + } + if cooldownErr := c.checkCloudflareChallengeCooldown(account, proxyURL); cooldownErr != nil { + return nil, nil, cooldownErr + } + traceID, traceProxyKey, traceUAHash := c.requestTraceFields(ctx, proxyURL, headers.Get("User-Agent")) timeout := 0 if c != nil && c.cfg != nil { timeout = c.cfg.Sora.Client.TimeoutSeconds @@ -600,7 +1512,29 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth } attempts := maxRetries + 1 + authRecovered := false + authRecoverExtraAttemptGranted := false + challengeRetried := false + sawCFChallenge := false + var lastErr error for attempt := 1; attempt <= attempts; attempt++ { + if c.debugEnabled() { + c.debugLogf( + "request_start trace_id=%s method=%s url=%s attempt=%d/%d timeout_s=%d body_bytes=%d proxy_bound=%t proxy_key=%s ua_hash=%s headers=%s", + traceID, + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + timeout, + len(bodyBytes), + proxyURL != "", + traceProxyKey, + traceUAHash, + formatSoraHeaders(headers), + ) + } + var reader io.Reader if bodyBytes != nil { reader = bytes.NewReader(bodyBytes) @@ -612,13 +1546,24 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth req.Header = headers.Clone() start := time.Now() - proxyURL := "" - if account != nil && account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } resp, err := c.doHTTP(req, proxyURL, account) if err != nil { + lastErr = err + if c.debugEnabled() { + c.debugLogf( + "request_transport_error trace_id=%s method=%s url=%s attempt=%d/%d err=%s", + traceID, + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + logredact.RedactText(err.Error()), + ) + } if attempt < attempts && allowRetry { + if c.debugEnabled() { + c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=transport_error next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), attempt+1, attempts) + } c.sleepRetry(attempt) continue } @@ -632,24 +1577,119 @@ func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, meth } if c.cfg != nil && c.cfg.Sora.Client.Debug { - log.Printf("[SoraClient] %s %s status=%d cost=%s", method, sanitizeSoraLogURL(urlStr), resp.StatusCode, time.Since(start)) + c.debugLogf( + "response_received trace_id=%s method=%s url=%s attempt=%d/%d status=%d cost=%s resp_bytes=%d resp_headers=%s", + traceID, + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + resp.StatusCode, + time.Since(start), + len(respBody), + formatSoraHeaders(resp.Header), + ) } if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { - upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody) + isCFChallenge := soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, respBody) + if isCFChallenge { + sawCFChallenge = true + c.recordCloudflareChallengeCooldown(account, proxyURL, resp.StatusCode, resp.Header, respBody) + if allowRetry && attempt < attempts && !challengeRetried { + challengeRetried = true + if c.debugEnabled() { + c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=cloudflare_challenge status=%d next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts) + } + c.sleepRetry(attempt) + continue + } + } + if !isCFChallenge && !authRecovered && shouldAttemptSoraTokenRecover(resp.StatusCode, urlStr) && account != nil { + if recovered, recoverErr := c.recoverAccessToken(ctx, account, fmt.Sprintf("upstream_status_%d", resp.StatusCode)); recoverErr == nil && strings.TrimSpace(recovered) != "" { + headers.Set("Authorization", "Bearer "+recovered) + authRecovered = true + if attempt == attempts && !authRecoverExtraAttemptGranted { + attempts++ + authRecoverExtraAttemptGranted = true + } + if c.debugEnabled() { + c.debugLogf("request_retry_with_recovered_token trace_id=%s method=%s url=%s status=%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode) + } + continue + } else if recoverErr != nil && c.debugEnabled() { + c.debugLogf("request_recover_token_failed trace_id=%s method=%s url=%s status=%d err=%s", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, logredact.RedactText(recoverErr.Error())) + } + } + if c.debugEnabled() { + c.debugLogf( + "response_non_success trace_id=%s method=%s url=%s attempt=%d/%d status=%d body=%s", + traceID, + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + resp.StatusCode, + summarizeSoraResponseBody(respBody, 512), + ) + } + upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody, urlStr) + lastErr = upstreamErr + if isCFChallenge { + return nil, resp.Header, upstreamErr + } if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) { + if c.debugEnabled() { + c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=status_%d next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts) + } c.sleepRetry(attempt) continue } return nil, resp.Header, upstreamErr } + if sawCFChallenge { + c.clearCloudflareChallengeCooldown(account, proxyURL) + } return respBody, resp.Header, nil } + if lastErr != nil { + return nil, nil, lastErr + } return nil, nil, errors.New("upstream retries exhausted") } +func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool { + switch statusCode { + case http.StatusUnauthorized, http.StatusForbidden: + parsed, err := url.Parse(strings.TrimSpace(rawURL)) + if err != nil { + return false + } + host := strings.ToLower(parsed.Hostname()) + if host != "sora.chatgpt.com" && host != "chatgpt.com" { + return false + } + // 避免在 ST->AT 转换接口上递归触发 token 恢复导致死循环。 + path := strings.ToLower(strings.TrimSpace(parsed.Path)) + if path == "/api/auth/session" { + return false + } + return true + default: + return false + } +} + func (c *SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { - enableTLS := c != nil && c.cfg != nil && c.cfg.Gateway.TLSFingerprint.Enabled && !c.cfg.Sora.Client.DisableTLSFingerprint + if c != nil && c.cfg != nil && c.cfg.Sora.Client.CurlCFFISidecar.Enabled { + resp, err := c.doHTTPViaCurlCFFISidecar(req, proxyURL, account) + if err != nil { + return nil, err + } + return resp, nil + } + + enableTLS := c == nil || c.cfg == nil || !c.cfg.Sora.Client.DisableTLSFingerprint if c.httpUpstream != nil { accountID := int64(0) accountConcurrency := 0 @@ -670,9 +1710,14 @@ func (c *SoraDirectClient) sleepRetry(attempt int) { time.Sleep(backoff) } -func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte) error { +func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte, requestURL string) error { msg := strings.TrimSpace(extractUpstreamErrorMessage(body)) msg = sanitizeUpstreamErrorMessage(msg) + if status == http.StatusNotFound && strings.Contains(strings.ToLower(msg), "not found") { + if hint := soraBaseURLNotFoundHint(requestURL); hint != "" { + msg = strings.TrimSpace(msg + " " + hint) + } + } if msg == "" { msg = truncateForLog(body, 256) } @@ -684,10 +1729,52 @@ func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, b } } -func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken string) (string, error) { +func normalizeSoraBaseURL(raw string) string { + trimmed := strings.TrimRight(strings.TrimSpace(raw), "/") + if trimmed == "" { + return "" + } + parsed, err := url.Parse(trimmed) + if err != nil || parsed.Scheme == "" || parsed.Host == "" { + return trimmed + } + host := strings.ToLower(parsed.Hostname()) + if host != "sora.chatgpt.com" && host != "chatgpt.com" { + return trimmed + } + pathVal := strings.TrimRight(strings.TrimSpace(parsed.Path), "/") + switch pathVal { + case "", "/": + parsed.Path = "/backend" + case "/backend-api": + parsed.Path = "/backend" + } + return strings.TrimRight(parsed.String(), "/") +} + +func soraBaseURLNotFoundHint(requestURL string) string { + parsed, err := url.Parse(strings.TrimSpace(requestURL)) + if err != nil || parsed.Host == "" { + return "" + } + host := strings.ToLower(parsed.Hostname()) + if host != "sora.chatgpt.com" && host != "chatgpt.com" { + return "" + } + pathVal := strings.TrimSpace(parsed.Path) + if strings.HasPrefix(pathVal, "/backend/") || pathVal == "/backend" { + return "" + } + return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)" +} + +func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken, userAgent, proxyURL string) (string, error) { reqID := uuid.NewString() - userAgent := soraRandChoice(soraDesktopUserAgents) - powToken := soraGetPowToken(userAgent) + userAgent = strings.TrimSpace(userAgent) + if userAgent == "" { + userAgent = c.taskUserAgent() + } + powToken := soraPowTokenGenerator(userAgent) payload := map[string]any{ "p": powToken, "flow": soraSentinelFlow, @@ -708,7 +1795,7 @@ func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *A } urlStr := soraChatGPTBaseURL + "/backend-api/sentinel/req" - respBody, _, err := c.doRequest(ctx, account, http.MethodPost, urlStr, headers, bytes.NewReader(body), true) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, urlStr, headers, bytes.NewReader(body), true) if err != nil { return "", err } @@ -724,16 +1811,6 @@ func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *A return sentinel, nil } -func soraRandChoice(items []string) string { - if len(items) == 0 { - return "" - } - soraRandMu.Lock() - idx := soraRand.Intn(len(items)) - soraRandMu.Unlock() - return items[idx] -} - func soraGetPowToken(userAgent string) string { configList := soraBuildPowConfig(userAgent) seed := strconv.FormatFloat(soraRandFloat(), 'f', -1, 64) @@ -748,14 +1825,26 @@ func soraRandFloat() float64 { return soraRand.Float64() } +func soraRandInt(max int) int { + if max <= 1 { + return 0 + } + soraRandMu.Lock() + defer soraRandMu.Unlock() + return soraRand.Intn(max) +} + func soraBuildPowConfig(userAgent string) []any { - screen := soraRandChoice([]string{ - strconv.Itoa(1920 + 1080), - strconv.Itoa(2560 + 1440), - strconv.Itoa(1920 + 1200), - strconv.Itoa(2560 + 1600), - }) - screenVal, _ := strconv.Atoi(screen) + userAgent = strings.TrimSpace(userAgent) + if userAgent == "" && len(soraDesktopUserAgents) > 0 { + userAgent = soraDesktopUserAgents[0] + } + screenVal := soraStableChoiceInt([]int{ + 1920 + 1080, + 2560 + 1440, + 1920 + 1200, + 2560 + 1600, + }, userAgent+"|screen") perfMs := float64(time.Since(soraPerfStart).Milliseconds()) wallMs := float64(time.Now().UnixNano()) / 1e6 diff := wallMs - perfMs @@ -765,32 +1854,47 @@ func soraBuildPowConfig(userAgent string) []any { 4294705152, 0, userAgent, - soraRandChoice(soraPowScripts), - soraRandChoice(soraPowDPL), + soraStableChoice(soraPowScripts, userAgent+"|script"), + soraStableChoice(soraPowDPL, userAgent+"|dpl"), "en-US", "en-US,es-US,en,es", 0, - soraRandChoice(soraPowNavigatorKeys), - soraRandChoice(soraPowDocumentKeys), - soraRandChoice(soraPowWindowKeys), + soraStableChoice(soraPowNavigatorKeys, userAgent+"|navigator"), + soraStableChoice(soraPowDocumentKeys, userAgent+"|document"), + soraStableChoice(soraPowWindowKeys, userAgent+"|window"), perfMs, uuid.NewString(), "", - soraRandChoiceInt(soraPowCores), + soraStableChoiceInt(soraPowCores, userAgent+"|cores"), diff, } } -func soraRandChoiceInt(items []int) int { +func soraStableChoice(items []string, seed string) string { + if len(items) == 0 { + return "" + } + idx := soraStableIndex(seed, len(items)) + return items[idx] +} + +func soraStableChoiceInt(items []int, seed string) int { if len(items) == 0 { return 0 } - soraRandMu.Lock() - idx := soraRand.Intn(len(items)) - soraRandMu.Unlock() + idx := soraStableIndex(seed, len(items)) return items[idx] } +func soraStableIndex(seed string, size int) int { + if size <= 0 { + return 0 + } + h := fnv.New32a() + _, _ = h.Write([]byte(seed)) + return int(h.Sum32() % uint32(size)) +} + func soraPowParseTime() string { loc := time.FixedZone("EST", -5*3600) return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (Eastern Standard Time)") @@ -890,6 +1994,55 @@ func hexDecodeString(s string) ([]byte, error) { return dst, err } +func (c *SoraDirectClient) withRequestTrace(ctx context.Context, account *Account, proxyURL, userAgent string) context.Context { + if ctx == nil { + ctx = context.Background() + } + if existing, ok := ctx.Value(soraRequestTraceContextKey{}).(*soraRequestTrace); ok && existing != nil && existing.ID != "" { + return ctx + } + accountID := int64(0) + if account != nil { + accountID = account.ID + } + seed := fmt.Sprintf("%d|%s|%s|%d", accountID, normalizeSoraProxyKey(proxyURL), strings.TrimSpace(userAgent), time.Now().UnixNano()) + trace := &soraRequestTrace{ + ID: "sora-" + soraHashForLog(seed), + ProxyKey: normalizeSoraProxyKey(proxyURL), + UAHash: soraHashForLog(strings.TrimSpace(userAgent)), + } + return context.WithValue(ctx, soraRequestTraceContextKey{}, trace) +} + +func (c *SoraDirectClient) requestTraceFields(ctx context.Context, proxyURL, userAgent string) (string, string, string) { + proxyKey := normalizeSoraProxyKey(proxyURL) + uaHash := soraHashForLog(strings.TrimSpace(userAgent)) + traceID := "" + if ctx != nil { + if trace, ok := ctx.Value(soraRequestTraceContextKey{}).(*soraRequestTrace); ok && trace != nil { + if strings.TrimSpace(trace.ID) != "" { + traceID = strings.TrimSpace(trace.ID) + } + if strings.TrimSpace(trace.ProxyKey) != "" { + proxyKey = strings.TrimSpace(trace.ProxyKey) + } + if strings.TrimSpace(trace.UAHash) != "" { + uaHash = strings.TrimSpace(trace.UAHash) + } + } + } + if traceID == "" { + traceID = "sora-" + soraHashForLog(fmt.Sprintf("%s|%d", proxyKey, time.Now().UnixNano())) + } + return traceID, proxyKey, uaHash +} + +func soraHashForLog(raw string) string { + h := fnv.New32a() + _, _ = h.Write([]byte(raw)) + return fmt.Sprintf("%08x", h.Sum32()) +} + func sanitizeSoraLogURL(raw string) string { parsed, err := url.Parse(raw) if err != nil { @@ -901,3 +2054,70 @@ func sanitizeSoraLogURL(raw string) string { parsed.RawQuery = q.Encode() return parsed.String() } + +func (c *SoraDirectClient) debugEnabled() bool { + return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug +} + +func (c *SoraDirectClient) debugLogf(format string, args ...any) { + if !c.debugEnabled() { + return + } + log.Printf("[SoraClient] "+format, args...) +} + +func formatSoraHeaders(headers http.Header) string { + if len(headers) == 0 { + return "{}" + } + keys := make([]string, 0, len(headers)) + for key := range headers { + keys = append(keys, key) + } + sort.Strings(keys) + out := make(map[string]string, len(keys)) + for _, key := range keys { + values := headers.Values(key) + if len(values) == 0 { + continue + } + val := strings.Join(values, ",") + if isSensitiveHeader(key) { + out[key] = "***" + continue + } + out[key] = truncateForLog([]byte(logredact.RedactText(val)), 160) + } + encoded, err := json.Marshal(out) + if err != nil { + return "{}" + } + return string(encoded) +} + +func isSensitiveHeader(key string) bool { + k := strings.ToLower(strings.TrimSpace(key)) + switch k { + case "authorization", "openai-sentinel-token", "cookie", "set-cookie", "x-api-key": + return true + default: + return false + } +} + +func summarizeSoraResponseBody(body []byte, maxLen int) string { + if len(body) == 0 { + return "" + } + var text string + if json.Valid(body) { + text = logredact.RedactJSON(body) + } else { + text = logredact.RedactText(string(body)) + } + text = strings.TrimSpace(text) + if maxLen <= 0 || len(text) <= maxLen { + return text + } + return text[:maxLen] + "...(truncated)" +} diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go index a6bf71cd..cffe8a35 100644 --- a/backend/internal/service/sora_client_test.go +++ b/backend/internal/service/sora_client_test.go @@ -4,9 +4,16 @@ package service import ( "context" + "encoding/base64" + "encoding/json" + "errors" + "io" "net/http" "net/http/httptest" + "strings" + "sync/atomic" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" @@ -85,3 +92,984 @@ func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) { require.Equal(t, "completed", status.Status) require.Equal(t, []string{"https://example.com/a.png"}, status.URLs) } + +func TestNormalizeSoraBaseURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + raw string + want string + }{ + { + name: "empty", + raw: "", + want: "", + }, + { + name: "append_backend_for_sora_host", + raw: "https://sora.chatgpt.com", + want: "https://sora.chatgpt.com/backend", + }, + { + name: "convert_backend_api_to_backend", + raw: "https://sora.chatgpt.com/backend-api", + want: "https://sora.chatgpt.com/backend", + }, + { + name: "keep_backend", + raw: "https://sora.chatgpt.com/backend", + want: "https://sora.chatgpt.com/backend", + }, + { + name: "keep_custom_host", + raw: "https://example.com/custom-path", + want: "https://example.com/custom-path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeSoraBaseURL(tt.raw) + require.Equal(t, tt.want, got) + }) + } +} + +func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) { + t.Parallel() + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com", + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen")) +} + +func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) { + t.Parallel() + client := NewSoraDirectClient(&config.Config{}, nil, nil) + + err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen") + var upstreamErr *SoraUpstreamError + require.ErrorAs(t, err, &upstreamErr) + require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url") + + errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen") + require.ErrorAs(t, errNoHint, &upstreamErr) + require.NotContains(t, upstreamErr.Message, "请检查 sora.client.base_url") +} + +func TestFormatSoraHeaders_RedactsSensitive(t *testing.T) { + t.Parallel() + headers := http.Header{} + headers.Set("Authorization", "Bearer secret-token") + headers.Set("openai-sentinel-token", "sentinel-secret") + headers.Set("X-Test", "ok") + + out := formatSoraHeaders(headers) + require.Contains(t, out, `"Authorization":"***"`) + require.Contains(t, out, `Sentinel-Token":"***"`) + require.Contains(t, out, `"X-Test":"ok"`) + require.NotContains(t, out, "secret-token") + require.NotContains(t, out, "sentinel-secret") +} + +func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) { + t.Parallel() + body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`) + out := summarizeSoraResponseBody(body, 512) + require.Contains(t, out, `"access_token":"***"`) + require.NotContains(t, out, "abc123") +} + +func TestSummarizeSoraResponseBody_Truncates(t *testing.T) { + t.Parallel() + body := []byte(strings.Repeat("x", 100)) + out := summarizeSoraResponseBody(body, 10) + require.Contains(t, out, "(truncated)") +} + +func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) { + t.Parallel() + cache := newOpenAITokenCacheStub() + provider := NewOpenAITokenProvider(nil, cache, nil) + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, nil, provider) + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "sora-credential-token", + }, + } + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "sora-credential-token", token) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled)) +} + +func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) { + t.Parallel() + cache := newOpenAITokenCacheStub() + account := &Account{ + ID: 2, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "sora-credential-token", + }, + } + cache.tokens[OpenAITokenCacheKey(account)] = "provider-token" + provider := NewOpenAITokenProvider(nil, cache, nil) + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + UseOpenAITokenProvider: true, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, provider) + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "provider-token", token) + require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0)) +} + +func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token") + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "accessToken": "session-access-token", + "expires": "2099-01-01T00:00:00Z", + }) + })) + defer server.Close() + + origin := soraSessionAuthURL + soraSessionAuthURL = server.URL + defer func() { soraSessionAuthURL = origin }() + + client := NewSoraDirectClient(&config.Config{}, nil, nil) + account := &Account{ + ID: 10, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "session_token": "session-token", + }, + } + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "session-access-token", token) + require.Equal(t, "session-access-token", account.GetCredential("access_token")) +} + +func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "/oauth/token", r.URL.Path) + require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + require.NoError(t, r.ParseForm()) + require.Equal(t, "refresh_token", r.FormValue("grant_type")) + require.Equal(t, "refresh-token-old", r.FormValue("refresh_token")) + require.NotEmpty(t, r.FormValue("client_id")) + require.Equal(t, "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", r.FormValue("redirect_uri")) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "refresh-access-token", + "refresh_token": "refresh-token-new", + "expires_in": 3600, + }) + })) + defer server.Close() + + origin := soraOAuthTokenURL + soraOAuthTokenURL = server.URL + "/oauth/token" + defer func() { soraOAuthTokenURL = origin }() + + client := NewSoraDirectClient(&config.Config{}, nil, nil) + account := &Account{ + ID: 11, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "refresh-token-old", + }, + } + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "refresh-access-token", token) + require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token")) + require.NotNil(t, account.GetCredentialAsTime("expires_at")) +} + +func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Equal(t, "/nf/check", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "rate_limit_and_credit_balance": map[string]any{ + "estimated_num_videos_remaining": 0, + "rate_limit_reached": true, + }, + }) + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: server.URL, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{ + ID: 12, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "ok", + "expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339), + }, + } + err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"}) + require.Error(t, err) + var upstreamErr *SoraUpstreamError + require.ErrorAs(t, err, &upstreamErr) + require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode) +} + +func TestShouldAttemptSoraTokenRecover(t *testing.T) { + t.Parallel() + + require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen")) + require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen")) + require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session")) + require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token")) + require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen")) +} + +type soraClientRequestCall struct { + Path string + UserAgent string + ProxyURL string +} + +type soraClientRecordingUpstream struct { + calls []soraClientRequestCall +} + +func (u *soraClientRecordingUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return nil, errors.New("unexpected Do call") +} + +func (u *soraClientRecordingUpstream) DoWithTLS(req *http.Request, proxyURL string, _ int64, _ int, _ bool) (*http.Response, error) { + u.calls = append(u.calls, soraClientRequestCall{ + Path: req.URL.Path, + UserAgent: req.Header.Get("User-Agent"), + ProxyURL: proxyURL, + }) + switch req.URL.Path { + case "/backend-api/sentinel/req": + return newSoraClientMockResponse(http.StatusOK, `{"token":"sentinel-token","turnstile":{"dx":"ok"}}`), nil + case "/backend/nf/create": + return newSoraClientMockResponse(http.StatusOK, `{"id":"task-123"}`), nil + case "/backend/nf/create/storyboard": + return newSoraClientMockResponse(http.StatusOK, `{"id":"storyboard-123"}`), nil + case "/backend/uploads": + return newSoraClientMockResponse(http.StatusOK, `{"id":"upload-123"}`), nil + case "/backend/nf/check": + return newSoraClientMockResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":1,"rate_limit_reached":false}}`), nil + case "/backend/characters/upload": + return newSoraClientMockResponse(http.StatusOK, `{"id":"cameo-123"}`), nil + case "/backend/project_y/cameos/in_progress/cameo-123": + return newSoraClientMockResponse(http.StatusOK, `{"status":"finalized","status_message":"Completed","username_hint":"foo.bar","display_name_hint":"Bar","profile_asset_url":"https://example.com/avatar.webp"}`), nil + case "/backend/project_y/file/upload": + return newSoraClientMockResponse(http.StatusOK, `{"asset_pointer":"asset-123"}`), nil + case "/backend/characters/finalize": + return newSoraClientMockResponse(http.StatusOK, `{"character":{"character_id":"character-123"}}`), nil + case "/backend/project_y/post": + return newSoraClientMockResponse(http.StatusOK, `{"post":{"id":"s_post"}}`), nil + default: + return newSoraClientMockResponse(http.StatusOK, `{"ok":true}`), nil + } +} + +func newSoraClientMockResponse(statusCode int, body string) *http.Response { + return &http.Response{ + StatusCode: statusCode, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func TestSoraDirectClient_TaskUserAgent_DefaultMobileFallback(t *testing.T) { + client := NewSoraDirectClient(&config.Config{}, nil, nil) + ua := client.taskUserAgent() + require.NotEmpty(t, ua) + allowed := append([]string{}, soraMobileUserAgents...) + allowed = append(allowed, soraDesktopUserAgents...) + require.Contains(t, allowed, ua) +} + +func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAndCreate(t *testing.T) { + originPowTokenGenerator := soraPowTokenGenerator + soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" } + defer func() { + soraPowTokenGenerator = originPowTokenGenerator + }() + + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + proxyID := int64(9) + account := &Account{ + ID: 21, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + ProxyID: &proxyID, + Proxy: &Proxy{ + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + }, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + taskID, err := client.CreateVideoTask(context.Background(), account, SoraVideoRequest{Prompt: "test"}) + require.NoError(t, err) + require.Equal(t, "task-123", taskID) + require.Len(t, upstream.calls, 2) + + sentinelCall := upstream.calls[0] + createCall := upstream.calls[1] + require.Equal(t, "/backend-api/sentinel/req", sentinelCall.Path) + require.Equal(t, "/backend/nf/create", createCall.Path) + require.Equal(t, "http://127.0.0.1:8080", sentinelCall.ProxyURL) + require.Equal(t, sentinelCall.ProxyURL, createCall.ProxyURL) + require.NotEmpty(t, sentinelCall.UserAgent) + require.Equal(t, sentinelCall.UserAgent, createCall.UserAgent) +} + +func TestSoraDirectClient_UploadImage_UsesTaskUserAgentAndProxy(t *testing.T) { + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + proxyID := int64(3) + account := &Account{ + ID: 31, + ProxyID: &proxyID, + Proxy: &Proxy{ + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + }, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + uploadID, err := client.UploadImage(context.Background(), account, []byte("mock-image"), "a.png") + require.NoError(t, err) + require.Equal(t, "upload-123", uploadID) + require.Len(t, upstream.calls, 1) + require.Equal(t, "/backend/uploads", upstream.calls[0].Path) + require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL) + require.NotEmpty(t, upstream.calls[0].UserAgent) +} + +func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) { + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + proxyID := int64(7) + account := &Account{ + ID: 41, + ProxyID: &proxyID, + Proxy: &Proxy{ + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + }, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + err := client.PreflightCheck(context.Background(), account, "sora2", SoraModelConfig{Type: "video"}) + require.NoError(t, err) + require.Len(t, upstream.calls, 1) + require.Equal(t, "/backend/nf/check", upstream.calls[0].Path) + require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL) + require.NotEmpty(t, upstream.calls[0].UserAgent) +} + +func TestSoraDirectClient_CreateStoryboardTask(t *testing.T) { + originPowTokenGenerator := soraPowTokenGenerator + soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" } + defer func() { soraPowTokenGenerator = originPowTokenGenerator }() + + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + account := &Account{ + ID: 51, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + taskID, err := client.CreateStoryboardTask(context.Background(), account, SoraStoryboardRequest{ + Prompt: "Shot 1:\nduration: 5sec\nScene: cat", + }) + require.NoError(t, err) + require.Equal(t, "storyboard-123", taskID) + require.Len(t, upstream.calls, 2) + require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path) + require.Equal(t, "/backend/nf/create/storyboard", upstream.calls[1].Path) +} + +func TestSoraDirectClient_GetVideoTask_ReturnsGenerationID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/nf/pending/v2": + _, _ = w.Write([]byte(`[]`)) + case "/project_y/profile/drafts": + _, _ = w.Write([]byte(`{"items":[{"id":"gen_1","task_id":"task-1","kind":"video","downloadable_url":"https://example.com/v.mp4"}]}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: server.URL, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{Credentials: map[string]any{"access_token": "token"}} + + status, err := client.GetVideoTask(context.Background(), account, "task-1") + require.NoError(t, err) + require.Equal(t, "completed", status.Status) + require.Equal(t, "gen_1", status.GenerationID) + require.Equal(t, []string{"https://example.com/v.mp4"}, status.URLs) +} + +func TestSoraDirectClient_PostVideoForWatermarkFree(t *testing.T) { + originPowTokenGenerator := soraPowTokenGenerator + soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" } + defer func() { soraPowTokenGenerator = originPowTokenGenerator }() + + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + account := &Account{ + ID: 52, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + postID, err := client.PostVideoForWatermarkFree(context.Background(), account, "gen_1") + require.NoError(t, err) + require.Equal(t, "s_post", postID) + require.Len(t, upstream.calls, 2) + require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path) + require.Equal(t, "/backend/project_y/post", upstream.calls[1].Path) +} + +type soraClientFallbackUpstream struct { + doWithTLSCalls int32 + respBody string + respStatusCode int + err error +} + +func (u *soraClientFallbackUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return nil, errors.New("unexpected Do call") +} + +func (u *soraClientFallbackUpstream) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { + atomic.AddInt32(&u.doWithTLSCalls, 1) + if u.err != nil { + return nil, u.err + } + statusCode := u.respStatusCode + if statusCode <= 0 { + statusCode = http.StatusOK + } + body := u.respBody + if body == "" { + body = `{"ok":true}` + } + return newSoraClientMockResponse(statusCode, body), nil +} + +func TestSoraDirectClient_DoHTTP_UsesCurlCFFISidecarWhenEnabled(t *testing.T) { + var captured soraCurlCFFISidecarRequest + sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "/request", r.URL.Path) + raw, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(raw, &captured)) + _ = json.NewEncoder(w).Encode(map[string]any{ + "status_code": http.StatusOK, + "headers": map[string]any{ + "Content-Type": "application/json", + "X-Sidecar": []string{"yes"}, + }, + "body_base64": base64.StdEncoding.EncodeToString([]byte(`{"ok":true}`)), + }) + })) + defer sidecar.Close() + + upstream := &soraClientFallbackUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + BaseURL: sidecar.URL, + Impersonate: "chrome131", + TimeoutSeconds: 15, + SessionReuseEnabled: true, + }, + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + req, err := http.NewRequest(http.MethodPost, "https://sora.chatgpt.com/backend/me", strings.NewReader("hello-sidecar")) + require.NoError(t, err) + req.Header.Set("User-Agent", "test-ua") + + resp, err := client.doHTTP(req, "http://127.0.0.1:18080", &Account{ID: 1}) + require.NoError(t, err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + require.JSONEq(t, `{"ok":true}`, string(body)) + require.Equal(t, int32(0), atomic.LoadInt32(&upstream.doWithTLSCalls)) + require.Equal(t, "http://127.0.0.1:18080", captured.ProxyURL) + require.NotEmpty(t, captured.SessionKey) + require.Equal(t, "chrome131", captured.Impersonate) + require.Equal(t, "https://sora.chatgpt.com/backend/me", captured.URL) + decodedReqBody, err := base64.StdEncoding.DecodeString(captured.BodyBase64) + require.NoError(t, err) + require.Equal(t, "hello-sidecar", string(decodedReqBody)) +} + +func TestSoraDirectClient_DoHTTP_CurlCFFISidecarFailureReturnsError(t *testing.T) { + sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte(`{"error":"boom"}`)) + })) + defer sidecar.Close() + + upstream := &soraClientFallbackUpstream{respBody: `{"fallback":true}`} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + BaseURL: sidecar.URL, + }, + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + req, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil) + require.NoError(t, err) + + _, err = client.doHTTP(req, "", &Account{ID: 2}) + require.Error(t, err) + require.Contains(t, err.Error(), "sora curl_cffi sidecar") + require.Equal(t, int32(0), atomic.LoadInt32(&upstream.doWithTLSCalls)) +} + +func TestSoraDirectClient_DoHTTP_CurlCFFISidecarDisabledUsesLegacyStack(t *testing.T) { + upstream := &soraClientFallbackUpstream{respBody: `{"legacy":true}`} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: false, + BaseURL: "http://127.0.0.1:18080", + }, + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + req, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil) + require.NoError(t, err) + + resp, err := client.doHTTP(req, "", &Account{ID: 3}) + require.NoError(t, err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.JSONEq(t, `{"legacy":true}`, string(body)) + require.Equal(t, int32(1), atomic.LoadInt32(&upstream.doWithTLSCalls)) +} + +func TestConvertSidecarHeaderValue_NilAndSlice(t *testing.T) { + require.Nil(t, convertSidecarHeaderValue(nil)) + require.Equal(t, []string{"a", "b"}, convertSidecarHeaderValue([]any{"a", " ", "b"})) +} + +func TestSoraDirectClient_DoHTTP_SidecarSessionKeyStableForSameAccountProxy(t *testing.T) { + var captured []soraCurlCFFISidecarRequest + sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, err := io.ReadAll(r.Body) + require.NoError(t, err) + var reqPayload soraCurlCFFISidecarRequest + require.NoError(t, json.Unmarshal(raw, &reqPayload)) + captured = append(captured, reqPayload) + _ = json.NewEncoder(w).Encode(map[string]any{ + "status_code": http.StatusOK, + "headers": map[string]any{ + "Content-Type": "application/json", + }, + "body": `{"ok":true}`, + }) + })) + defer sidecar.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + BaseURL: sidecar.URL, + SessionReuseEnabled: true, + SessionTTLSeconds: 3600, + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{ID: 1001} + + req1, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil) + require.NoError(t, err) + _, err = client.doHTTP(req1, "http://127.0.0.1:18080", account) + require.NoError(t, err) + + req2, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil) + require.NoError(t, err) + _, err = client.doHTTP(req2, "http://127.0.0.1:18080", account) + require.NoError(t, err) + + require.Len(t, captured, 2) + require.NotEmpty(t, captured[0].SessionKey) + require.Equal(t, captured[0].SessionKey, captured[1].SessionKey) +} + +func TestSoraDirectClient_DoRequestWithProxy_CloudflareChallengeSetsCooldownAfterSingleRetry(t *testing.T) { + var sidecarCalls int32 + sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&sidecarCalls, 1) + _ = json.NewEncoder(w).Encode(map[string]any{ + "status_code": http.StatusForbidden, + "headers": map[string]any{ + "cf-ray": "9d05d73dec4d8c8e-GRU", + "content-type": "text/html", + }, + "body": `Just a moment...`, + }) + })) + defer sidecar.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + MaxRetries: 3, + CloudflareChallengeCooldownSeconds: 60, + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + BaseURL: sidecar.URL, + Impersonate: "chrome131", + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + headers := http.Header{} + + _, _, err := client.doRequestWithProxy( + context.Background(), + &Account{ID: 99}, + "http://127.0.0.1:18080", + http.MethodGet, + "https://sora.chatgpt.com/backend/me", + headers, + nil, + true, + ) + require.Error(t, err) + var upstreamErr *SoraUpstreamError + require.ErrorAs(t, err, &upstreamErr) + require.Equal(t, http.StatusForbidden, upstreamErr.StatusCode) + require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls), "challenge should trigger exactly one same-proxy retry") + + _, _, err = client.doRequestWithProxy( + context.Background(), + &Account{ID: 99}, + "http://127.0.0.1:18080", + http.MethodGet, + "https://sora.chatgpt.com/backend/me", + headers, + nil, + true, + ) + require.Error(t, err) + require.ErrorAs(t, err, &upstreamErr) + require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode) + require.Contains(t, upstreamErr.Message, "cooling down") + require.Contains(t, upstreamErr.Message, "cf-ray") + require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls), "cooldown should block outbound request") +} + +func TestSoraDirectClient_DoRequestWithProxy_CloudflareRetrySuccessClearsCooldown(t *testing.T) { + var sidecarCalls int32 + sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + call := atomic.AddInt32(&sidecarCalls, 1) + if call == 1 { + _ = json.NewEncoder(w).Encode(map[string]any{ + "status_code": http.StatusForbidden, + "headers": map[string]any{ + "cf-ray": "9d05d73dec4d8c8e-GRU", + "content-type": "text/html", + }, + "body": `Just a moment...`, + }) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "status_code": http.StatusOK, + "headers": map[string]any{ + "content-type": "application/json", + }, + "body": `{"ok":true}`, + }) + })) + defer sidecar.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + MaxRetries: 3, + CloudflareChallengeCooldownSeconds: 60, + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + BaseURL: sidecar.URL, + Impersonate: "chrome131", + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + headers := http.Header{} + account := &Account{ID: 109} + proxyURL := "http://127.0.0.1:18080" + + body, _, err := client.doRequestWithProxy( + context.Background(), + account, + proxyURL, + http.MethodGet, + "https://sora.chatgpt.com/backend/me", + headers, + nil, + true, + ) + require.NoError(t, err) + require.Contains(t, string(body), `"ok":true`) + require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls)) + + _, _, err = client.doRequestWithProxy( + context.Background(), + account, + proxyURL, + http.MethodGet, + "https://sora.chatgpt.com/backend/me", + headers, + nil, + true, + ) + require.NoError(t, err) + require.Equal(t, int32(3), atomic.LoadInt32(&sidecarCalls), "cooldown should be cleared after retry succeeds") +} + +func TestSoraComputeChallengeCooldownSeconds(t *testing.T) { + require.Equal(t, 0, soraComputeChallengeCooldownSeconds(0, 3)) + require.Equal(t, 10, soraComputeChallengeCooldownSeconds(10, 1)) + require.Equal(t, 20, soraComputeChallengeCooldownSeconds(10, 2)) + require.Equal(t, 40, soraComputeChallengeCooldownSeconds(10, 4)) + require.Equal(t, 40, soraComputeChallengeCooldownSeconds(10, 9), "streak should cap at x4") + require.Equal(t, 3600, soraComputeChallengeCooldownSeconds(1200, 9), "cooldown should cap at 3600s") +} + +func TestSoraDirectClient_RecordCloudflareChallengeCooldown_EscalatesStreak(t *testing.T) { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + CloudflareChallengeCooldownSeconds: 10, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{ID: 201} + proxyURL := "http://127.0.0.1:18080" + + client.recordCloudflareChallengeCooldown(account, proxyURL, http.StatusForbidden, http.Header{"Cf-Ray": []string{"9d05d73dec4d8c8e-GRU"}}, nil) + client.recordCloudflareChallengeCooldown(account, proxyURL, http.StatusForbidden, http.Header{"Cf-Ray": []string{"9d05d73dec4d8c8f-GRU"}}, nil) + + key := soraAccountProxyKey(account, proxyURL) + entry, ok := client.challengeCooldowns[key] + require.True(t, ok) + require.Equal(t, 2, entry.ConsecutiveChallenges) + require.Equal(t, "9d05d73dec4d8c8f-GRU", entry.CFRay) + remain := int(entry.Until.Sub(entry.LastChallengeAt).Seconds()) + require.GreaterOrEqual(t, remain, 19) +} + +func TestSoraDirectClient_SidecarSessionKey_SkipsWhenAccountMissing(t *testing.T) { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + SessionReuseEnabled: true, + SessionTTLSeconds: 3600, + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + require.Equal(t, "", client.sidecarSessionKey(nil, "http://127.0.0.1:18080")) + require.Empty(t, client.sidecarSessions) +} + +func TestSoraDirectClient_SidecarSessionKey_PrunesExpiredAndRecreates(t *testing.T) { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + SessionReuseEnabled: true, + SessionTTLSeconds: 3600, + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{ID: 123} + key := soraAccountProxyKey(account, "http://127.0.0.1:18080") + client.sidecarSessions[key] = soraSidecarSessionEntry{ + SessionKey: "sora-expired", + ExpiresAt: time.Now().Add(-time.Minute), + LastUsedAt: time.Now().Add(-2 * time.Minute), + } + + sessionKey := client.sidecarSessionKey(account, "http://127.0.0.1:18080") + require.NotEmpty(t, sessionKey) + require.NotEqual(t, "sora-expired", sessionKey) + require.Len(t, client.sidecarSessions, 1) +} + +func TestSoraDirectClient_SidecarSessionKey_TTLZeroKeepsLongLivedSession(t *testing.T) { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + SessionReuseEnabled: true, + SessionTTLSeconds: 0, + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{ID: 456} + + first := client.sidecarSessionKey(account, "http://127.0.0.1:18080") + second := client.sidecarSessionKey(account, "http://127.0.0.1:18080") + require.NotEmpty(t, first) + require.Equal(t, first, second) + + key := soraAccountProxyKey(account, "http://127.0.0.1:18080") + entry, ok := client.sidecarSessions[key] + require.True(t, ok) + require.True(t, entry.ExpiresAt.After(time.Now().Add(300*24*time.Hour))) +} diff --git a/backend/internal/service/sora_curl_cffi_sidecar.go b/backend/internal/service/sora_curl_cffi_sidecar.go new file mode 100644 index 00000000..40f5c017 --- /dev/null +++ b/backend/internal/service/sora_curl_cffi_sidecar.go @@ -0,0 +1,260 @@ +package service + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/util/logredact" +) + +const soraCurlCFFISidecarDefaultTimeoutSeconds = 60 + +type soraCurlCFFISidecarRequest struct { + Method string `json:"method"` + URL string `json:"url"` + Headers map[string][]string `json:"headers,omitempty"` + BodyBase64 string `json:"body_base64,omitempty"` + ProxyURL string `json:"proxy_url,omitempty"` + SessionKey string `json:"session_key,omitempty"` + Impersonate string `json:"impersonate,omitempty"` + TimeoutSeconds int `json:"timeout_seconds,omitempty"` +} + +type soraCurlCFFISidecarResponse struct { + StatusCode int `json:"status_code"` + Status int `json:"status"` + Headers map[string]any `json:"headers"` + BodyBase64 string `json:"body_base64"` + Body string `json:"body"` + Error string `json:"error"` +} + +func (c *SoraDirectClient) doHTTPViaCurlCFFISidecar(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { + if req == nil || req.URL == nil { + return nil, errors.New("request url is nil") + } + if c == nil || c.cfg == nil { + return nil, errors.New("sora curl_cffi sidecar config is nil") + } + if !c.cfg.Sora.Client.CurlCFFISidecar.Enabled { + return nil, errors.New("sora curl_cffi sidecar is disabled") + } + endpoint := c.curlCFFISidecarEndpoint() + if endpoint == "" { + return nil, errors.New("sora curl_cffi sidecar base_url is empty") + } + + bodyBytes, err := readAndRestoreRequestBody(req) + if err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar read request body failed: %w", err) + } + + headers := make(map[string][]string, len(req.Header)+1) + for key, vals := range req.Header { + copied := make([]string, len(vals)) + copy(copied, vals) + headers[key] = copied + } + if strings.TrimSpace(req.Host) != "" { + if _, ok := headers["Host"]; !ok { + headers["Host"] = []string{req.Host} + } + } + + payload := soraCurlCFFISidecarRequest{ + Method: req.Method, + URL: req.URL.String(), + Headers: headers, + ProxyURL: strings.TrimSpace(proxyURL), + SessionKey: c.sidecarSessionKey(account, proxyURL), + Impersonate: c.curlCFFIImpersonate(), + TimeoutSeconds: c.curlCFFISidecarTimeoutSeconds(), + } + if len(bodyBytes) > 0 { + payload.BodyBase64 = base64.StdEncoding.EncodeToString(bodyBytes) + } + + encoded, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar marshal request failed: %w", err) + } + + sidecarReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, endpoint, bytes.NewReader(encoded)) + if err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar build request failed: %w", err) + } + sidecarReq.Header.Set("Content-Type", "application/json") + sidecarReq.Header.Set("Accept", "application/json") + + httpClient := &http.Client{Timeout: time.Duration(payload.TimeoutSeconds) * time.Second} + sidecarResp, err := httpClient.Do(sidecarReq) + if err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar request failed: %w", err) + } + defer func() { + _ = sidecarResp.Body.Close() + }() + + sidecarRespBody, err := io.ReadAll(io.LimitReader(sidecarResp.Body, 8<<20)) + if err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar read response failed: %w", err) + } + if sidecarResp.StatusCode != http.StatusOK { + redacted := truncateForLog([]byte(logredact.RedactText(string(sidecarRespBody))), 512) + return nil, fmt.Errorf("sora curl_cffi sidecar http status=%d body=%s", sidecarResp.StatusCode, redacted) + } + + var payloadResp soraCurlCFFISidecarResponse + if err := json.Unmarshal(sidecarRespBody, &payloadResp); err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar parse response failed: %w", err) + } + if msg := strings.TrimSpace(payloadResp.Error); msg != "" { + return nil, fmt.Errorf("sora curl_cffi sidecar upstream error: %s", msg) + } + statusCode := payloadResp.StatusCode + if statusCode <= 0 { + statusCode = payloadResp.Status + } + if statusCode <= 0 { + return nil, errors.New("sora curl_cffi sidecar response missing status code") + } + + responseBody := []byte(payloadResp.Body) + if strings.TrimSpace(payloadResp.BodyBase64) != "" { + decoded, err := base64.StdEncoding.DecodeString(payloadResp.BodyBase64) + if err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar decode body failed: %w", err) + } + responseBody = decoded + } + + respHeaders := make(http.Header) + for key, rawVal := range payloadResp.Headers { + for _, v := range convertSidecarHeaderValue(rawVal) { + respHeaders.Add(key, v) + } + } + + return &http.Response{ + StatusCode: statusCode, + Header: respHeaders, + Body: io.NopCloser(bytes.NewReader(responseBody)), + ContentLength: int64(len(responseBody)), + Request: req, + }, nil +} + +func readAndRestoreRequestBody(req *http.Request) ([]byte, error) { + if req == nil || req.Body == nil { + return nil, nil + } + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + _ = req.Body.Close() + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.ContentLength = int64(len(bodyBytes)) + return bodyBytes, nil +} + +func (c *SoraDirectClient) curlCFFISidecarEndpoint() string { + if c == nil || c.cfg == nil { + return "" + } + raw := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.BaseURL) + if raw == "" { + return "" + } + parsed, err := url.Parse(raw) + if err != nil || strings.TrimSpace(parsed.Scheme) == "" || strings.TrimSpace(parsed.Host) == "" { + return raw + } + if path := strings.TrimSpace(parsed.Path); path == "" || path == "/" { + parsed.Path = "/request" + } + return parsed.String() +} + +func (c *SoraDirectClient) curlCFFISidecarTimeoutSeconds() int { + if c == nil || c.cfg == nil { + return soraCurlCFFISidecarDefaultTimeoutSeconds + } + timeoutSeconds := c.cfg.Sora.Client.CurlCFFISidecar.TimeoutSeconds + if timeoutSeconds <= 0 { + return soraCurlCFFISidecarDefaultTimeoutSeconds + } + return timeoutSeconds +} + +func (c *SoraDirectClient) curlCFFIImpersonate() string { + if c == nil || c.cfg == nil { + return "chrome131" + } + impersonate := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.Impersonate) + if impersonate == "" { + return "chrome131" + } + return impersonate +} + +func (c *SoraDirectClient) sidecarSessionReuseEnabled() bool { + if c == nil || c.cfg == nil { + return true + } + return c.cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled +} + +func (c *SoraDirectClient) sidecarSessionTTLSeconds() int { + if c == nil || c.cfg == nil { + return 3600 + } + ttl := c.cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds + if ttl < 0 { + return 3600 + } + return ttl +} + +func convertSidecarHeaderValue(raw any) []string { + switch val := raw.(type) { + case nil: + return nil + case string: + if strings.TrimSpace(val) == "" { + return nil + } + return []string{val} + case []any: + out := make([]string, 0, len(val)) + for _, item := range val { + s := strings.TrimSpace(fmt.Sprint(item)) + if s != "" { + out = append(out, s) + } + } + return out + case []string: + out := make([]string, 0, len(val)) + for _, item := range val { + if strings.TrimSpace(item) != "" { + out = append(out, item) + } + } + return out + default: + s := strings.TrimSpace(fmt.Sprint(val)) + if s == "" { + return nil + } + return []string{s} + } +} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index d7ff297c..ac29ae0d 100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -8,10 +8,12 @@ import ( "fmt" "io" "log" + "math" "mime" "net" "net/http" "net/url" + "regexp" "strconv" "strings" "time" @@ -23,6 +25,9 @@ import ( const soraImageInputMaxBytes = 20 << 20 const soraImageInputMaxRedirects = 3 const soraImageInputTimeout = 20 * time.Second +const soraVideoInputMaxBytes = 200 << 20 +const soraVideoInputMaxRedirects = 3 +const soraVideoInputTimeout = 60 * time.Second var soraImageSizeMap = map[string]string{ "gpt-image": "360", @@ -61,6 +66,36 @@ type SoraGatewayService struct { cfg *config.Config } +type soraWatermarkOptions struct { + Enabled bool + ParseMethod string + ParseURL string + ParseToken string + FallbackOnFailure bool + DeletePost bool +} + +type soraCharacterOptions struct { + SetPublic bool + DeleteAfterGenerate bool +} + +type soraCharacterFlowResult struct { + CameoID string + CharacterID string + Username string + DisplayName string +} + +var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`) +var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`) +var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`) +var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`) + +type soraPreflightChecker interface { + PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error +} + func NewSoraGatewayService( soraClient SoraClient, mediaStorage *SoraMediaStorage, @@ -112,29 +147,133 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream) return nil, fmt.Errorf("unsupported model: %s", reqModel) } - if modelCfg.Type == "prompt_enhance" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Prompt-enhance 模型暂未支持", clientStream) - return nil, fmt.Errorf("prompt-enhance not supported") - } - prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) - if strings.TrimSpace(prompt) == "" { + prompt = strings.TrimSpace(prompt) + imageInput = strings.TrimSpace(imageInput) + videoInput = strings.TrimSpace(videoInput) + remixTargetID = strings.TrimSpace(remixTargetID) + + if videoInput != "" && modelCfg.Type != "video" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream) + return nil, errors.New("video input only supports video models") + } + if videoInput != "" && imageInput != "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream) + return nil, errors.New("image input and video input cannot be used together") + } + characterOnly := videoInput != "" && prompt == "" + if modelCfg.Type == "prompt_enhance" && prompt == "" { s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) return nil, errors.New("prompt is required") } - if strings.TrimSpace(videoInput) != "" { - s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Video input is not supported yet", clientStream) - return nil, errors.New("video input not supported") + if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) + return nil, errors.New("prompt is required") } reqCtx, cancel := s.withSoraTimeout(ctx, reqStream) if cancel != nil { defer cancel() } + if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly { + if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + } + + if modelCfg.Type == "prompt_enhance" { + enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS) + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + content := strings.TrimSpace(enhancedPrompt) + if content == "" { + content = prompt + } + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel)) + } + return &ForwardResult{ + RequestID: "", + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", + }, nil + } + + characterOpts := parseSoraCharacterOptions(reqBody) + watermarkOpts := parseSoraWatermarkOptions(reqBody) + var characterResult *soraCharacterFlowResult + if videoInput != "" { + videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput) + if videoErr != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream) + return nil, videoErr + } + characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts) + if videoErr != nil { + return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream) + } + if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly { + characterID := strings.TrimSpace(characterResult.CharacterID) + defer func() { + cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second) + defer cancelCleanup() + if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil { + log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err) + } + }() + } + if characterOnly { + content := "角色创建成功" + if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { + content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username)) + } + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + resp := buildSoraNonStreamResponse(content, reqModel) + if characterResult != nil { + resp["character_id"] = characterResult.CharacterID + resp["cameo_id"] = characterResult.CameoID + resp["character_username"] = characterResult.Username + resp["character_display_name"] = characterResult.DisplayName + } + c.JSON(http.StatusOK, resp) + } + return &ForwardResult{ + RequestID: "", + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", + }, nil + } + if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { + prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt) + } + } var imageData []byte imageFilename := "" - if strings.TrimSpace(imageInput) != "" { + if imageInput != "" { decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput) if err != nil { s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream) @@ -164,15 +303,27 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun MediaID: mediaID, }) case "video": - taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{ - Prompt: prompt, - Orientation: modelCfg.Orientation, - Frames: modelCfg.Frames, - Model: modelCfg.Model, - Size: modelCfg.Size, - MediaID: mediaID, - RemixTargetID: remixTargetID, - }) + if remixTargetID == "" && isSoraStoryboardPrompt(prompt) { + taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{ + Prompt: formatSoraStoryboardPrompt(prompt), + Orientation: modelCfg.Orientation, + Frames: modelCfg.Frames, + Model: modelCfg.Model, + Size: modelCfg.Size, + MediaID: mediaID, + }) + } else { + taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{ + Prompt: prompt, + Orientation: modelCfg.Orientation, + Frames: modelCfg.Frames, + Model: modelCfg.Model, + Size: modelCfg.Size, + MediaID: mediaID, + RemixTargetID: remixTargetID, + CameoIDs: extractSoraCameoIDs(reqBody), + }) + } default: err = fmt.Errorf("unsupported model type: %s", modelCfg.Type) } @@ -185,6 +336,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun } var mediaURLs []string + videoGenerationID := "" mediaType := modelCfg.Type imageCount := 0 imageSize := "" @@ -198,15 +350,32 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun imageCount = len(urls) imageSize = soraImageSizeFromModel(reqModel) case "video": - urls, pollErr := s.pollVideoTask(reqCtx, c, account, taskID, clientStream) + videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream) if pollErr != nil { return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) } - mediaURLs = urls + if videoStatus != nil { + mediaURLs = videoStatus.URLs + videoGenerationID = strings.TrimSpace(videoStatus.GenerationID) + } default: mediaType = "prompt" } + watermarkPostID := "" + if modelCfg.Type == "video" && watermarkOpts.Enabled { + watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts) + if watermarkErr != nil { + if !watermarkOpts.FallbackOnFailure { + return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream) + } + log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr) + } else if strings.TrimSpace(watermarkURL) != "" { + mediaURLs = []string{strings.TrimSpace(watermarkURL)} + watermarkPostID = strings.TrimSpace(postID) + } + } + finalURLs := s.normalizeSoraMediaURLs(mediaURLs) if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() { stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs) @@ -217,6 +386,11 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun finalURLs = s.normalizeSoraMediaURLs(stored) } } + if watermarkPostID != "" && watermarkOpts.DeletePost { + if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil { + log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr) + } + } content := buildSoraContent(mediaType, finalURLs) var firstTokenMs *int @@ -265,9 +439,270 @@ func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) ( return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second) } +func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions { + opts := soraWatermarkOptions{ + Enabled: parseBoolWithDefault(body, "watermark_free", false), + ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))), + ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")), + ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")), + FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true), + DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false), + } + if opts.ParseMethod == "" { + opts.ParseMethod = "third_party" + } + return opts +} + +func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions { + return soraCharacterOptions{ + SetPublic: parseBoolWithDefault(body, "character_set_public", true), + DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true), + } +} + +func parseBoolWithDefault(body map[string]any, key string, def bool) bool { + if body == nil { + return def + } + val, ok := body[key] + if !ok { + return def + } + switch typed := val.(type) { + case bool: + return typed + case int: + return typed != 0 + case int32: + return typed != 0 + case int64: + return typed != 0 + case float64: + return typed != 0 + case string: + typed = strings.ToLower(strings.TrimSpace(typed)) + if typed == "true" || typed == "1" || typed == "yes" { + return true + } + if typed == "false" || typed == "0" || typed == "no" { + return false + } + } + return def +} + +func parseStringWithDefault(body map[string]any, key, def string) string { + if body == nil { + return def + } + val, ok := body[key] + if !ok { + return def + } + if str, ok := val.(string); ok { + return str + } + return def +} + +func extractSoraCameoIDs(body map[string]any) []string { + if body == nil { + return nil + } + raw, ok := body["cameo_ids"] + if !ok { + return nil + } + switch typed := raw.(type) { + case []string: + out := make([]string, 0, len(typed)) + for _, item := range typed { + item = strings.TrimSpace(item) + if item != "" { + out = append(out, item) + } + } + return out + case []any: + out := make([]string, 0, len(typed)) + for _, item := range typed { + str, ok := item.(string) + if !ok { + continue + } + str = strings.TrimSpace(str) + if str != "" { + out = append(out, str) + } + } + return out + default: + return nil + } +} + +func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) { + cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData) + if err != nil { + return nil, err + } + + cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID) + if err != nil { + return nil, err + } + username := processSoraCharacterUsername(cameoStatus.UsernameHint) + displayName := strings.TrimSpace(cameoStatus.DisplayNameHint) + if displayName == "" { + displayName = "Character" + } + profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL) + if profileAssetURL == "" { + return nil, errors.New("profile asset url not found in cameo status") + } + + avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL) + if err != nil { + return nil, err + } + assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData) + if err != nil { + return nil, err + } + instructionSet := cameoStatus.InstructionSetHint + if instructionSet == nil { + instructionSet = cameoStatus.InstructionSet + } + + characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{ + CameoID: strings.TrimSpace(cameoID), + Username: username, + DisplayName: displayName, + ProfileAssetPointer: assetPointer, + InstructionSet: instructionSet, + }) + if err != nil { + return nil, err + } + + if opts.SetPublic { + if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil { + return nil, err + } + } + + return &soraCharacterFlowResult{ + CameoID: strings.TrimSpace(cameoID), + CharacterID: strings.TrimSpace(characterID), + Username: strings.TrimSpace(username), + DisplayName: displayName, + }, nil +} + +func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { + timeout := 10 * time.Minute + interval := 5 * time.Second + maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds())) + if maxAttempts < 1 { + maxAttempts = 1 + } + + var lastErr error + consecutiveErrors := 0 + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID) + if err != nil { + lastErr = err + consecutiveErrors++ + if consecutiveErrors >= 3 { + break + } + if attempt < maxAttempts-1 { + if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { + return nil, sleepErr + } + } + continue + } + consecutiveErrors = 0 + if status == nil { + if attempt < maxAttempts-1 { + if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { + return nil, sleepErr + } + } + continue + } + currentStatus := strings.ToLower(strings.TrimSpace(status.Status)) + statusMessage := strings.TrimSpace(status.StatusMessage) + if currentStatus == "failed" { + if statusMessage == "" { + statusMessage = "character creation failed" + } + return nil, errors.New(statusMessage) + } + if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" { + return status, nil + } + if attempt < maxAttempts-1 { + if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { + return nil, sleepErr + } + } + } + if lastErr != nil { + return nil, fmt.Errorf("poll cameo status failed: %w", lastErr) + } + return nil, errors.New("cameo processing timeout") +} + +func processSoraCharacterUsername(usernameHint string) string { + usernameHint = strings.TrimSpace(usernameHint) + if usernameHint == "" { + usernameHint = "character" + } + if strings.Contains(usernameHint, ".") { + parts := strings.Split(usernameHint, ".") + usernameHint = strings.TrimSpace(parts[len(parts)-1]) + } + if usernameHint == "" { + usernameHint = "character" + } + return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100) +} + +func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) { + generationID = strings.TrimSpace(generationID) + if generationID == "" { + return "", "", errors.New("generation id is required for watermark-free mode") + } + postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID) + if err != nil { + return "", "", err + } + postID = strings.TrimSpace(postID) + if postID == "" { + return "", "", errors.New("watermark-free publish returned empty post id") + } + + switch opts.ParseMethod { + case "custom": + urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID) + if parseErr != nil { + return "", postID, parseErr + } + return strings.TrimSpace(urlVal), postID, nil + case "", "third_party": + return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil + default: + return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod) + } +} + func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { switch statusCode { - case 401, 402, 403, 429, 529: + case 401, 402, 403, 404, 429, 529: return true default: return statusCode >= 500 @@ -434,7 +869,18 @@ func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, } if stream { flusher, _ := c.Writer.(http.Flusher) - errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + errorData := map[string]any{ + "error": map[string]string{ + "type": errType, + "message": message, + }, + } + jsonBytes, err := json.Marshal(errorData) + if err != nil { + _ = c.Error(err) + return + } + errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) _, _ = fmt.Fprint(c.Writer, errorEvent) _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") if flusher != nil { @@ -460,7 +906,15 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) } if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { - return &UpstreamFailoverError{StatusCode: upstreamErr.StatusCode} + var responseHeaders http.Header + if upstreamErr.Headers != nil { + responseHeaders = upstreamErr.Headers.Clone() + } + return &UpstreamFailoverError{ + StatusCode: upstreamErr.StatusCode, + ResponseBody: upstreamErr.Body, + ResponseHeaders: responseHeaders, + } } msg := upstreamErr.Message if override := soraProErrorMessage(model, msg); override != "" { @@ -505,7 +959,7 @@ func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, return nil, errors.New("sora image generation timeout") } -func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) { +func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) { interval := s.pollInterval() maxAttempts := s.pollMaxAttempts() lastPing := time.Now() @@ -516,7 +970,7 @@ func (s *SoraGatewayService) pollVideoTask(ctx context.Context, c *gin.Context, } switch strings.ToLower(status.Status) { case "completed", "succeeded": - return status.URLs, nil + return status, nil case "failed": if status.ErrorMsg != "" { return nil, errors.New(status.ErrorMsg) @@ -620,7 +1074,7 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi return "", "", "", "" } if v, ok := body["remix_target_id"].(string); ok { - remixTargetID = v + remixTargetID = strings.TrimSpace(v) } if v, ok := body["image"].(string); ok { imageInput = v @@ -661,6 +1115,10 @@ func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remi prompt = builder.String() } } + if remixTargetID == "" { + remixTargetID = extractRemixTargetIDFromPrompt(prompt) + } + prompt = cleanRemixLinkFromPrompt(prompt) return prompt, imageInput, videoInput, remixTargetID } @@ -708,6 +1166,69 @@ func parseSoraMessageContent(content any) (text, imageInput, videoInput string) } } +func isSoraStoryboardPrompt(prompt string) bool { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return false + } + return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1 +} + +func formatSoraStoryboardPrompt(prompt string) string { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return "" + } + matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1) + if len(matches) == 0 { + return prompt + } + firstBracketPos := strings.Index(prompt, "[") + instructions := "" + if firstBracketPos > 0 { + instructions = strings.TrimSpace(prompt[:firstBracketPos]) + } + shots := make([]string, 0, len(matches)) + for i, match := range matches { + if len(match) < 3 { + continue + } + duration := strings.TrimSpace(match[1]) + scene := strings.TrimSpace(match[2]) + if scene == "" { + continue + } + shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene)) + } + if len(shots) == 0 { + return prompt + } + timeline := strings.Join(shots, "\n\n") + if instructions == "" { + return timeline + } + return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions) +} + +func extractRemixTargetIDFromPrompt(prompt string) string { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return "" + } + return strings.TrimSpace(soraRemixTargetPattern.FindString(prompt)) +} + +func cleanRemixLinkFromPrompt(prompt string) string { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return prompt + } + cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "") + cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "") + cleaned = strings.Join(strings.Fields(cleaned), " ") + return strings.TrimSpace(cleaned) +} + func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) { raw := strings.TrimSpace(input) if raw == "" { @@ -720,7 +1241,7 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er } meta := parts[0] payload := parts[1] - decoded, err := base64.StdEncoding.DecodeString(payload) + decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes) if err != nil { return nil, "", err } @@ -739,15 +1260,47 @@ func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, er if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { return downloadSoraImageInput(ctx, raw) } - decoded, err := base64.StdEncoding.DecodeString(raw) + decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes) if err != nil { return nil, "", errors.New("invalid base64 image") } return decoded, "image.png", nil } +func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) { + raw := strings.TrimSpace(input) + if raw == "" { + return nil, errors.New("empty video input") + } + if strings.HasPrefix(raw, "data:") { + parts := strings.SplitN(raw, ",", 2) + if len(parts) != 2 { + return nil, errors.New("invalid video data url") + } + decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes) + if err != nil { + return nil, errors.New("invalid base64 video") + } + if len(decoded) == 0 { + return nil, errors.New("empty video data") + } + return decoded, nil + } + if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { + return downloadSoraVideoInput(ctx, raw) + } + decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes) + if err != nil { + return nil, errors.New("invalid base64 video") + } + if len(decoded) == 0 { + return nil, errors.New("empty video data") + } + return decoded, nil +} + func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) { - parsed, err := validateSoraImageURL(rawURL) + parsed, err := validateSoraRemoteURL(rawURL) if err != nil { return nil, "", err } @@ -761,7 +1314,7 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, if len(via) >= soraImageInputMaxRedirects { return errors.New("too many redirects") } - return validateSoraImageURLValue(req.URL) + return validateSoraRemoteURLValue(req.URL) }, } resp, err := client.Do(req) @@ -784,51 +1337,103 @@ func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, return data, filename, nil } -func validateSoraImageURL(raw string) (*url.URL, error) { +func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) { + parsed, err := validateSoraRemoteURL(rawURL) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) + if err != nil { + return nil, err + } + client := &http.Client{ + Timeout: soraVideoInputTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= soraVideoInputMaxRedirects { + return errors.New("too many redirects") + } + return validateSoraRemoteURLValue(req.URL) + }, + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("download video failed: %d", resp.StatusCode) + } + data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes)) + if err != nil { + return nil, err + } + if len(data) == 0 { + return nil, errors.New("empty video content") + } + return data, nil +} + +func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) { + if maxBytes <= 0 { + return nil, errors.New("invalid max bytes limit") + } + decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + limited := io.LimitReader(decoder, maxBytes+1) + data, err := io.ReadAll(limited) + if err != nil { + return nil, err + } + if int64(len(data)) > maxBytes { + return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes) + } + return data, nil +} + +func validateSoraRemoteURL(raw string) (*url.URL, error) { if strings.TrimSpace(raw) == "" { - return nil, errors.New("empty image url") + return nil, errors.New("empty remote url") } parsed, err := url.Parse(raw) if err != nil { - return nil, fmt.Errorf("invalid image url: %w", err) + return nil, fmt.Errorf("invalid remote url: %w", err) } - if err := validateSoraImageURLValue(parsed); err != nil { + if err := validateSoraRemoteURLValue(parsed); err != nil { return nil, err } return parsed, nil } -func validateSoraImageURLValue(parsed *url.URL) error { +func validateSoraRemoteURLValue(parsed *url.URL) error { if parsed == nil { - return errors.New("invalid image url") + return errors.New("invalid remote url") } scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) if scheme != "http" && scheme != "https" { - return errors.New("only http/https image url is allowed") + return errors.New("only http/https remote url is allowed") } if parsed.User != nil { - return errors.New("image url cannot contain userinfo") + return errors.New("remote url cannot contain userinfo") } host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) if host == "" { - return errors.New("image url missing host") + return errors.New("remote url missing host") } if _, blocked := soraBlockedHostnames[host]; blocked { - return errors.New("image url is not allowed") + return errors.New("remote url is not allowed") } if ip := net.ParseIP(host); ip != nil { if isSoraBlockedIP(ip) { - return errors.New("image url is not allowed") + return errors.New("remote url is not allowed") } return nil } ips, err := net.LookupIP(host) if err != nil { - return fmt.Errorf("resolve image url failed: %w", err) + return fmt.Errorf("resolve remote url failed: %w", err) } for _, ip := range ips { if isSoraBlockedIP(ip) { - return errors.New("image url is not allowed") + return errors.New("remote url is not allowed") } } return nil diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go index d6bf9eae..5888fe92 100644 --- a/backend/internal/service/sora_gateway_service_test.go +++ b/backend/internal/service/sora_gateway_service_test.go @@ -4,10 +4,16 @@ package service import ( "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -18,6 +24,13 @@ type stubSoraClientForPoll struct { videoStatus *SoraVideoTaskStatus imageCalls int videoCalls int + enhanced string + enhanceErr error + storyboard bool + videoReq SoraVideoRequest + parseErr error + postCalls int + deleteCalls int } func (s *stubSoraClientForPoll) Enabled() bool { return true } @@ -28,8 +41,60 @@ func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Ac return "task-image", nil } func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { + s.videoReq = req return "task-video", nil } +func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { + s.storyboard = true + return "task-video", nil +} +func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) { + return "cameo-1", nil +} +func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { + return &SoraCameoStatus{ + Status: "finalized", + StatusMessage: "Completed", + DisplayNameHint: "Character", + UsernameHint: "user.character", + ProfileAssetURL: "https://example.com/avatar.webp", + }, nil +} +func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) { + return []byte("avatar"), nil +} +func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) { + return "asset-pointer", nil +} +func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) { + return "character-1", nil +} +func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error { + return nil +} +func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error { + return nil +} +func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) { + s.postCalls++ + return "s_post", nil +} +func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error { + s.deleteCalls++ + return nil +} +func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) { + if s.parseErr != nil { + return "", s.parseErr + } + return "https://example.com/no-watermark.mp4", nil +} +func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { + if s.enhanced != "" { + return s.enhanced, s.enhanceErr + } + return "enhanced prompt", s.enhanceErr +} func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { s.imageCalls++ return s.imageStatus, nil @@ -62,6 +127,136 @@ func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) { require.Equal(t, 1, client.imageCalls) } +func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { + client := &stubSoraClientForPoll{ + enhanced: "cinematic prompt", + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ + ID: 1, + Platform: PlatformSora, + Status: StatusActive, + } + body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "prompt", result.MediaType) + require.Equal(t, "prompt-enhance-short-10s", result.Model) +} + +func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/v.mp4"}, + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, client.storyboard) +} + +func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) { + client := &stubSoraClientForPoll{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "prompt", result.MediaType) + require.Equal(t, 0, client.videoCalls) +} + +func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/original.mp4"}, + GenerationID: "gen_1", + }, + parseErr: errors.New("parse failed"), + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "https://example.com/original.mp4", result.MediaURL) + require.Equal(t, 1, client.postCalls) + require.Equal(t, 0, client.deleteCalls) +} + +func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/original.mp4"}, + GenerationID: "gen_1", + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL) + require.Equal(t, 1, client.postCalls) + require.Equal(t, 1, client.deleteCalls) +} + func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { client := &stubSoraClientForPoll{ videoStatus: &SoraVideoTaskStatus{ @@ -79,9 +274,9 @@ func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { } service := NewSoraGatewayService(client, nil, nil, cfg) - urls, err := service.pollVideoTask(context.Background(), nil, &Account{ID: 1}, "task", false) + status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false) require.Error(t, err) - require.Empty(t, urls) + require.Nil(t, status) require.Contains(t, err.Error(), "reject") require.Equal(t, 1, client.videoCalls) } @@ -175,9 +370,65 @@ func TestSoraProErrorMessage(t *testing.T) { require.Empty(t, soraProErrorMessage("sora-basic", "")) } +func TestSoraGatewayService_WriteSoraError_StreamEscapesJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) + svc.writeSoraError(c, http.StatusBadGateway, "upstream_error", "invalid \"prompt\"\nline2", true) + + body := rec.Body.String() + require.Contains(t, body, "event: error\n") + require.Contains(t, body, "data: [DONE]\n\n") + + lines := strings.Split(body, "\n") + require.GreaterOrEqual(t, len(lines), 2) + require.Equal(t, "event: error", lines[0]) + require.True(t, strings.HasPrefix(lines[1], "data: ")) + + data := strings.TrimPrefix(lines[1], "data: ") + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(data), &parsed)) + errObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "upstream_error", errObj["type"]) + require.Equal(t, "invalid \"prompt\"\nline2", errObj["message"]) +} + +func TestSoraGatewayService_HandleSoraRequestError_FailoverHeadersCloned(t *testing.T) { + svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) + sourceHeaders := http.Header{} + sourceHeaders.Set("cf-ray", "9d01b0e9ecc35829-SEA") + + err := svc.handleSoraRequestError( + context.Background(), + &Account{ID: 1, Platform: PlatformSora}, + &SoraUpstreamError{ + StatusCode: http.StatusForbidden, + Message: "forbidden", + Headers: sourceHeaders, + Body: []byte(`Just a moment...`), + }, + "sora2-landscape-10s", + nil, + false, + ) + + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.NotNil(t, failoverErr.ResponseHeaders) + require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray")) + + sourceHeaders.Set("cf-ray", "mutated-after-return") + require.Equal(t, "9d01b0e9ecc35829-SEA", failoverErr.ResponseHeaders.Get("cf-ray")) +} + func TestShouldFailoverUpstreamError(t *testing.T) { svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) require.True(t, svc.shouldFailoverUpstreamError(401)) + require.True(t, svc.shouldFailoverUpstreamError(404)) require.True(t, svc.shouldFailoverUpstreamError(429)) require.True(t, svc.shouldFailoverUpstreamError(500)) require.True(t, svc.shouldFailoverUpstreamError(502)) @@ -257,3 +508,19 @@ func TestDecodeSoraImageInput_DataURL(t *testing.T) { require.NotEmpty(t, data) require.Contains(t, filename, ".png") } + +func TestDecodeBase64WithLimit_ExceedLimit(t *testing.T) { + data, err := decodeBase64WithLimit("aGVsbG8=", 3) + require.Error(t, err) + require.Nil(t, data) +} + +func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) { + body := map[string]any{ + "watermark_free": float64(1), + "watermark_fallback_on_failure": float64(0), + } + opts := parseSoraWatermarkOptions(body) + require.True(t, opts.Enabled) + require.False(t, opts.FallbackOnFailure) +} diff --git a/backend/internal/service/sora_models.go b/backend/internal/service/sora_models.go index ab095e46..80b20a4b 100644 --- a/backend/internal/service/sora_models.go +++ b/backend/internal/service/sora_models.go @@ -17,6 +17,9 @@ type SoraModelConfig struct { Model string Size string RequirePro bool + // Prompt-enhance 专用参数 + ExpansionLevel string + DurationS int } var soraModelConfigs = map[string]SoraModelConfig{ @@ -160,31 +163,49 @@ var soraModelConfigs = map[string]SoraModelConfig{ RequirePro: true, }, "prompt-enhance-short-10s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "short", + DurationS: 10, }, "prompt-enhance-short-15s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "short", + DurationS: 15, }, "prompt-enhance-short-20s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "short", + DurationS: 20, }, "prompt-enhance-medium-10s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "medium", + DurationS: 10, }, "prompt-enhance-medium-15s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "medium", + DurationS: 15, }, "prompt-enhance-medium-20s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "medium", + DurationS: 20, }, "prompt-enhance-long-10s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "long", + DurationS: 10, }, "prompt-enhance-long-15s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "long", + DurationS: 15, }, "prompt-enhance-long-20s": { - Type: "prompt_enhance", + Type: "prompt_enhance", + ExpansionLevel: "long", + DurationS: 20, }, } diff --git a/backend/internal/service/sora_request_guard.go b/backend/internal/service/sora_request_guard.go new file mode 100644 index 00000000..a118fe82 --- /dev/null +++ b/backend/internal/service/sora_request_guard.go @@ -0,0 +1,266 @@ +package service + +import ( + "fmt" + "math" + "net/http" + "net/url" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" + "github.com/google/uuid" +) + +type soraChallengeCooldownEntry struct { + Until time.Time + StatusCode int + CFRay string + ConsecutiveChallenges int + LastChallengeAt time.Time +} + +type soraSidecarSessionEntry struct { + SessionKey string + ExpiresAt time.Time + LastUsedAt time.Time +} + +func (c *SoraDirectClient) cloudflareChallengeCooldownSeconds() int { + if c == nil || c.cfg == nil { + return 900 + } + cooldown := c.cfg.Sora.Client.CloudflareChallengeCooldownSeconds + if cooldown <= 0 { + return 0 + } + return cooldown +} + +func (c *SoraDirectClient) checkCloudflareChallengeCooldown(account *Account, proxyURL string) error { + if c == nil { + return nil + } + if account == nil || account.ID <= 0 { + return nil + } + cooldownSeconds := c.cloudflareChallengeCooldownSeconds() + if cooldownSeconds <= 0 { + return nil + } + key := soraAccountProxyKey(account, proxyURL) + now := time.Now() + + c.challengeCooldownMu.RLock() + entry, ok := c.challengeCooldowns[key] + c.challengeCooldownMu.RUnlock() + if !ok { + return nil + } + if !entry.Until.After(now) { + c.challengeCooldownMu.Lock() + delete(c.challengeCooldowns, key) + c.challengeCooldownMu.Unlock() + return nil + } + + remaining := int(math.Ceil(entry.Until.Sub(now).Seconds())) + if remaining < 1 { + remaining = 1 + } + message := fmt.Sprintf("Sora request cooling down due to recent Cloudflare challenge. Retry in %d seconds.", remaining) + if entry.ConsecutiveChallenges > 1 { + message = fmt.Sprintf("%s (streak=%d)", message, entry.ConsecutiveChallenges) + } + if entry.CFRay != "" { + message = fmt.Sprintf("%s (last cf-ray: %s)", message, entry.CFRay) + } + return &SoraUpstreamError{ + StatusCode: http.StatusTooManyRequests, + Message: message, + Headers: make(http.Header), + } +} + +func (c *SoraDirectClient) recordCloudflareChallengeCooldown(account *Account, proxyURL string, statusCode int, headers http.Header, body []byte) { + if c == nil { + return + } + if account == nil || account.ID <= 0 { + return + } + cooldownSeconds := c.cloudflareChallengeCooldownSeconds() + if cooldownSeconds <= 0 { + return + } + key := soraAccountProxyKey(account, proxyURL) + now := time.Now() + cfRay := soraerror.ExtractCloudflareRayID(headers, body) + + c.challengeCooldownMu.Lock() + c.cleanupExpiredChallengeCooldownsLocked(now) + + streak := 1 + existing, ok := c.challengeCooldowns[key] + if ok && now.Sub(existing.LastChallengeAt) <= 30*time.Minute { + streak = existing.ConsecutiveChallenges + 1 + } + effectiveCooldown := soraComputeChallengeCooldownSeconds(cooldownSeconds, streak) + until := now.Add(time.Duration(effectiveCooldown) * time.Second) + if ok && existing.Until.After(until) { + until = existing.Until + if existing.ConsecutiveChallenges > streak { + streak = existing.ConsecutiveChallenges + } + if cfRay == "" { + cfRay = existing.CFRay + } + } + c.challengeCooldowns[key] = soraChallengeCooldownEntry{ + Until: until, + StatusCode: statusCode, + CFRay: cfRay, + ConsecutiveChallenges: streak, + LastChallengeAt: now, + } + c.challengeCooldownMu.Unlock() + + if c.debugEnabled() { + remain := int(math.Ceil(until.Sub(now).Seconds())) + if remain < 0 { + remain = 0 + } + c.debugLogf("cloudflare_challenge_cooldown_set key=%s status=%d remain_s=%d streak=%d cf_ray=%s", key, statusCode, remain, streak, cfRay) + } +} + +func soraComputeChallengeCooldownSeconds(baseSeconds, streak int) int { + if baseSeconds <= 0 { + return 0 + } + if streak < 1 { + streak = 1 + } + multiplier := streak + if multiplier > 4 { + multiplier = 4 + } + cooldown := baseSeconds * multiplier + if cooldown > 3600 { + cooldown = 3600 + } + return cooldown +} + +func (c *SoraDirectClient) clearCloudflareChallengeCooldown(account *Account, proxyURL string) { + if c == nil { + return + } + if account == nil || account.ID <= 0 { + return + } + key := soraAccountProxyKey(account, proxyURL) + c.challengeCooldownMu.Lock() + _, existed := c.challengeCooldowns[key] + if existed { + delete(c.challengeCooldowns, key) + } + c.challengeCooldownMu.Unlock() + if existed && c.debugEnabled() { + c.debugLogf("cloudflare_challenge_cooldown_cleared key=%s", key) + } +} + +func (c *SoraDirectClient) sidecarSessionKey(account *Account, proxyURL string) string { + if c == nil || !c.sidecarSessionReuseEnabled() { + return "" + } + if account == nil || account.ID <= 0 { + return "" + } + key := soraAccountProxyKey(account, proxyURL) + now := time.Now() + ttlSeconds := c.sidecarSessionTTLSeconds() + + c.sidecarSessionMu.Lock() + defer c.sidecarSessionMu.Unlock() + c.cleanupExpiredSidecarSessionsLocked(now) + if existing, exists := c.sidecarSessions[key]; exists { + existing.LastUsedAt = now + c.sidecarSessions[key] = existing + return existing.SessionKey + } + + expiresAt := now.Add(time.Duration(ttlSeconds) * time.Second) + if ttlSeconds <= 0 { + expiresAt = now.Add(365 * 24 * time.Hour) + } + newEntry := soraSidecarSessionEntry{ + SessionKey: "sora-" + uuid.NewString(), + ExpiresAt: expiresAt, + LastUsedAt: now, + } + c.sidecarSessions[key] = newEntry + + if c.debugEnabled() { + c.debugLogf("sidecar_session_created key=%s ttl_s=%d", key, ttlSeconds) + } + return newEntry.SessionKey +} + +func (c *SoraDirectClient) cleanupExpiredChallengeCooldownsLocked(now time.Time) { + if c == nil || len(c.challengeCooldowns) == 0 { + return + } + for key, entry := range c.challengeCooldowns { + if !entry.Until.After(now) { + delete(c.challengeCooldowns, key) + } + } +} + +func (c *SoraDirectClient) cleanupExpiredSidecarSessionsLocked(now time.Time) { + if c == nil || len(c.sidecarSessions) == 0 { + return + } + for key, entry := range c.sidecarSessions { + if !entry.ExpiresAt.After(now) { + delete(c.sidecarSessions, key) + } + } +} + +func soraAccountProxyKey(account *Account, proxyURL string) string { + accountID := int64(0) + if account != nil { + accountID = account.ID + } + return fmt.Sprintf("account:%d|proxy:%s", accountID, normalizeSoraProxyKey(proxyURL)) +} + +func normalizeSoraProxyKey(proxyURL string) string { + raw := strings.TrimSpace(proxyURL) + if raw == "" { + return "direct" + } + parsed, err := url.Parse(raw) + if err != nil { + return strings.ToLower(raw) + } + scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) + host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) + port := strings.TrimSpace(parsed.Port()) + if host == "" { + return strings.ToLower(raw) + } + if (scheme == "http" && port == "80") || (scheme == "https" && port == "443") { + port = "" + } + if port != "" { + host = host + ":" + port + } + if scheme == "" { + scheme = "proxy" + } + return scheme + "://" + host +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 9de1c164..a37e0d0a 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -43,10 +43,13 @@ func NewTokenRefreshService( stopCh: make(chan struct{}), } + openAIRefresher := NewOpenAITokenRefresher(openaiOAuthService, accountRepo) + openAIRefresher.SetSyncLinkedSoraAccounts(cfg.TokenRefresh.SyncLinkedSoraAccounts) + // 注册平台特定的刷新器 s.refreshers = []TokenRefresher{ NewClaudeTokenRefresher(oauthService), - NewOpenAITokenRefresher(openaiOAuthService, accountRepo), + openAIRefresher, NewGeminiTokenRefresher(geminiOAuthService), NewAntigravityTokenRefresher(antigravityOAuthService), } diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 46033f75..0dd3cf45 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -86,6 +86,7 @@ type OpenAITokenRefresher struct { openaiOAuthService *OpenAIOAuthService accountRepo AccountRepository soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 + syncLinkedSora bool } // NewOpenAITokenRefresher 创建 OpenAI token刷新器 @@ -103,11 +104,15 @@ func (r *OpenAITokenRefresher) SetSoraAccountRepo(repo SoraAccountRepository) { r.soraAccountRepo = repo } +// SetSyncLinkedSoraAccounts 控制是否同步覆盖关联的 Sora 账号 token。 +func (r *OpenAITokenRefresher) SetSyncLinkedSoraAccounts(enabled bool) { + r.syncLinkedSora = enabled +} + // CanRefresh 检查是否能处理此账号 -// 只处理 openai 平台的 oauth 类型账号 +// 只处理 openai 平台的 oauth 类型账号(不直接刷新 sora 平台账号) func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { - return (account.Platform == PlatformOpenAI || account.Platform == PlatformSora) && - account.Type == AccountTypeOAuth + return account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth } // NeedsRefresh 检查token是否需要刷新 @@ -141,7 +146,7 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m } // 异步同步关联的 Sora 账号(不阻塞主流程) - if r.accountRepo != nil { + if r.accountRepo != nil && r.syncLinkedSora { go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials) } diff --git a/backend/internal/service/token_refresher_test.go b/backend/internal/service/token_refresher_test.go index c7505037..264d7912 100644 --- a/backend/internal/service/token_refresher_test.go +++ b/backend/internal/service/token_refresher_test.go @@ -226,3 +226,43 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) { }) } } + +func TestOpenAITokenRefresher_CanRefresh(t *testing.T) { + refresher := &OpenAITokenRefresher{} + + tests := []struct { + name string + platform string + accType string + want bool + }{ + { + name: "openai oauth - can refresh", + platform: PlatformOpenAI, + accType: AccountTypeOAuth, + want: true, + }, + { + name: "sora oauth - cannot refresh directly", + platform: PlatformSora, + accType: AccountTypeOAuth, + want: false, + }, + { + name: "openai apikey - cannot refresh", + platform: PlatformOpenAI, + accType: AccountTypeAPIKey, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: tt.platform, + Type: tt.accType, + } + require.Equal(t, tt.want, refresher.CanRefresh(account)) + }) + } +} diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index b8808e13..f9824183 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -26,8 +26,8 @@ type UsageLog struct { CacheCreationTokens int CacheReadTokens int - CacheCreation5mTokens int - CacheCreation1hTokens int + CacheCreation5mTokens int `gorm:"column:cache_creation_5m_tokens"` + CacheCreation1hTokens int `gorm:"column:cache_creation_1h_tokens"` InputCost float64 OutputCost float64 @@ -46,6 +46,9 @@ type UsageLog struct { UserAgent *string IPAddress *string + // Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费) + CacheTTLOverridden bool + // 图片生成字段 ImageCount int ImageSize *string diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 5d712f75..652f9e00 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -206,6 +206,18 @@ func ProvideSoraMediaStorage(cfg *config.Config) *SoraMediaStorage { return NewSoraMediaStorage(cfg) } +func ProvideSoraDirectClient( + cfg *config.Config, + httpUpstream HTTPUpstream, + tokenProvider *OpenAITokenProvider, + accountRepo AccountRepository, + soraAccountRepo SoraAccountRepository, +) *SoraDirectClient { + client := NewSoraDirectClient(cfg, httpUpstream, tokenProvider) + client.SetAccountRepositories(accountRepo, soraAccountRepo) + return client +} + // ProvideSoraMediaCleanupService 创建并启动 Sora 媒体清理服务 func ProvideSoraMediaCleanupService(storage *SoraMediaStorage, cfg *config.Config) *SoraMediaCleanupService { svc := NewSoraMediaCleanupService(storage, cfg) @@ -255,7 +267,7 @@ var ProviderSet = wire.NewSet( NewGatewayService, ProvideSoraMediaStorage, ProvideSoraMediaCleanupService, - NewSoraDirectClient, + ProvideSoraDirectClient, wire.Bind(new(SoraClient), new(*SoraDirectClient)), NewSoraGatewayService, NewOpenAIGatewayService, diff --git a/backend/internal/util/soraerror/soraerror.go b/backend/internal/util/soraerror/soraerror.go new file mode 100644 index 00000000..17712c10 --- /dev/null +++ b/backend/internal/util/soraerror/soraerror.go @@ -0,0 +1,170 @@ +package soraerror + +import ( + "encoding/json" + "fmt" + "net/http" + "regexp" + "strings" +) + +var ( + cfRayPattern = regexp.MustCompile(`(?i)cf-ray[:\s=]+([a-z0-9-]+)`) + cRayPattern = regexp.MustCompile(`(?i)cRay:\s*'([a-z0-9-]+)'`) + htmlChallenge = []string{ + "window._cf_chl_opt", + "just a moment", + "enable javascript and cookies to continue", + "__cf_chl_", + "challenge-platform", + } +) + +// IsCloudflareChallengeResponse reports whether the upstream response matches Cloudflare challenge behavior. +func IsCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { + if statusCode != http.StatusForbidden && statusCode != http.StatusTooManyRequests { + return false + } + + if headers != nil && strings.EqualFold(strings.TrimSpace(headers.Get("cf-mitigated")), "challenge") { + return true + } + + preview := strings.ToLower(TruncateBody(body, 4096)) + for _, marker := range htmlChallenge { + if strings.Contains(preview, marker) { + return true + } + } + + contentType := "" + if headers != nil { + contentType = strings.ToLower(strings.TrimSpace(headers.Get("content-type"))) + } + if strings.Contains(contentType, "text/html") && + (strings.Contains(preview, "= 2 { + return strings.TrimSpace(matches[1]) + } + if matches := cRayPattern.FindStringSubmatch(preview); len(matches) >= 2 { + return strings.TrimSpace(matches[1]) + } + return "" +} + +// FormatCloudflareChallengeMessage appends cf-ray info when available. +func FormatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string { + rayID := ExtractCloudflareRayID(headers, body) + if rayID == "" { + return base + } + return fmt.Sprintf("%s (cf-ray: %s)", base, rayID) +} + +// ExtractUpstreamErrorCodeAndMessage extracts structured error code/message from common JSON layouts. +func ExtractUpstreamErrorCodeAndMessage(body []byte) (string, string) { + trimmed := strings.TrimSpace(string(body)) + if trimmed == "" { + return "", "" + } + if !json.Valid([]byte(trimmed)) { + return "", truncateMessage(trimmed, 256) + } + + var payload map[string]any + if err := json.Unmarshal([]byte(trimmed), &payload); err != nil { + return "", truncateMessage(trimmed, 256) + } + + code := firstNonEmpty( + extractNestedString(payload, "error", "code"), + extractRootString(payload, "code"), + ) + message := firstNonEmpty( + extractNestedString(payload, "error", "message"), + extractRootString(payload, "message"), + extractNestedString(payload, "error", "detail"), + extractRootString(payload, "detail"), + ) + return strings.TrimSpace(code), truncateMessage(strings.TrimSpace(message), 512) +} + +// TruncateBody truncates body text for logging/inspection. +func TruncateBody(body []byte, max int) string { + if max <= 0 { + max = 512 + } + raw := strings.TrimSpace(string(body)) + if len(raw) <= max { + return raw + } + return raw[:max] + "...(truncated)" +} + +func truncateMessage(s string, max int) string { + if max <= 0 { + return "" + } + if len(s) <= max { + return s + } + return s[:max] + "...(truncated)" +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + if strings.TrimSpace(v) != "" { + return v + } + } + return "" +} + +func extractRootString(m map[string]any, key string) string { + if m == nil { + return "" + } + v, ok := m[key] + if !ok { + return "" + } + s, _ := v.(string) + return s +} + +func extractNestedString(m map[string]any, parent, key string) string { + if m == nil { + return "" + } + node, ok := m[parent] + if !ok { + return "" + } + child, ok := node.(map[string]any) + if !ok { + return "" + } + s, _ := child[key].(string) + return s +} diff --git a/backend/internal/util/soraerror/soraerror_test.go b/backend/internal/util/soraerror/soraerror_test.go new file mode 100644 index 00000000..4cf11169 --- /dev/null +++ b/backend/internal/util/soraerror/soraerror_test.go @@ -0,0 +1,47 @@ +package soraerror + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsCloudflareChallengeResponse(t *testing.T) { + headers := make(http.Header) + headers.Set("cf-mitigated", "challenge") + require.True(t, IsCloudflareChallengeResponse(http.StatusForbidden, headers, []byte(`{"ok":false}`))) + + require.True(t, IsCloudflareChallengeResponse(http.StatusTooManyRequests, nil, []byte(`Just a moment...`))) + require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`Just a moment...`))) +} + +func TestExtractCloudflareRayID(t *testing.T) { + headers := make(http.Header) + headers.Set("cf-ray", "9d01b0e9ecc35829-SEA") + require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil)) + + body := []byte(``) + require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body)) +} + +func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) { + code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`)) + require.Equal(t, "cf_shield_429", code) + require.Equal(t, "rate limited", msg) + + code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`)) + require.Equal(t, "unsupported_country_code", code) + require.Equal(t, "not available", msg) + + code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`)) + require.Equal(t, "", code) + require.Equal(t, "plain text", msg) +} + +func TestFormatCloudflareChallengeMessage(t *testing.T) { + headers := make(http.Header) + headers.Set("cf-ray", "9d03b68c086027a1-SEA") + msg := FormatCloudflareChallengeMessage("blocked", headers, nil) + require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg) +} diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index 7f37d59c..f7ba5c9e 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -86,6 +86,7 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc { if strings.HasPrefix(path, "/api/") || strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1beta/") || + strings.HasPrefix(path, "/sora/") || strings.HasPrefix(path, "/antigravity/") || strings.HasPrefix(path, "/setup/") || path == "/health" || @@ -209,6 +210,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { if strings.HasPrefix(path, "/api/") || strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/v1beta/") || + strings.HasPrefix(path, "/sora/") || strings.HasPrefix(path, "/antigravity/") || strings.HasPrefix(path, "/setup/") || path == "/health" || diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go index 50f5a323..e2cbcf15 100644 --- a/backend/internal/web/embed_test.go +++ b/backend/internal/web/embed_test.go @@ -362,6 +362,7 @@ func TestFrontendServer_Middleware(t *testing.T) { "/api/v1/users", "/v1/models", "/v1beta/chat", + "/sora/v1/models", "/antigravity/test", "/setup/init", "/health", @@ -537,6 +538,7 @@ func TestServeEmbeddedFrontend(t *testing.T) { "/api/users", "/v1/models", "/v1beta/chat", + "/sora/v1/models", "/antigravity/test", "/setup/init", "/health", diff --git a/backend/migrations/054_drop_legacy_cache_columns.sql b/backend/migrations/054_drop_legacy_cache_columns.sql new file mode 100644 index 00000000..040828c2 --- /dev/null +++ b/backend/migrations/054_drop_legacy_cache_columns.sql @@ -0,0 +1,44 @@ +-- Drop legacy cache token columns that lack the underscore separator. +-- These were created by GORM's automatic snake_case conversion: +-- CacheCreation5mTokens → cache_creation5m_tokens (incorrect) +-- CacheCreation1hTokens → cache_creation1h_tokens (incorrect) +-- +-- The canonical columns are: +-- cache_creation_5m_tokens (defined in 001_init.sql) +-- cache_creation_1h_tokens (defined in 001_init.sql) +-- +-- Migration 009 already copied data from legacy → canonical columns. +-- But upgraded instances may still have post-009 writes in legacy columns. +-- Backfill once more before dropping to prevent data loss. + +DO $$ +BEGIN + IF EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'usage_logs' + AND column_name = 'cache_creation5m_tokens' + ) THEN + UPDATE usage_logs + SET cache_creation_5m_tokens = cache_creation5m_tokens + WHERE cache_creation_5m_tokens = 0 + AND cache_creation5m_tokens <> 0; + END IF; + + IF EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'usage_logs' + AND column_name = 'cache_creation1h_tokens' + ) THEN + UPDATE usage_logs + SET cache_creation_1h_tokens = cache_creation1h_tokens + WHERE cache_creation_1h_tokens = 0 + AND cache_creation1h_tokens <> 0; + END IF; +END $$; + +ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation5m_tokens; +ALTER TABLE usage_logs DROP COLUMN IF EXISTS cache_creation1h_tokens; diff --git a/backend/migrations/055_add_cache_ttl_overridden.sql b/backend/migrations/055_add_cache_ttl_overridden.sql new file mode 100644 index 00000000..0d42fcf7 --- /dev/null +++ b/backend/migrations/055_add_cache_ttl_overridden.sql @@ -0,0 +1,2 @@ +-- Add cache_ttl_overridden flag to usage_logs for tracking cache TTL override per account. +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS cache_ttl_overridden BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 9fd2d391..c77ab70e 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -374,6 +374,9 @@ sora: # Max retries for upstream requests # 上游请求最大重试次数 max_retries: 3 + # Account+proxy cooldown window after Cloudflare challenge (seconds, 0 to disable) + # Cloudflare challenge 后按账号+代理冷却窗口(秒,0 表示关闭) + cloudflare_challenge_cooldown_seconds: 900 # Poll interval (seconds) # 轮询间隔(秒) poll_interval_seconds: 2 @@ -388,7 +391,11 @@ sora: recent_task_limit_max: 200 # Enable debug logs for Sora upstream requests # 启用 Sora 直连调试日志 + # 调试日志会输出上游请求尝试、重试、响应摘要;Authorization/openai-sentinel-token 等敏感头会自动脱敏 debug: false + # Allow Sora client to fetch token via OpenAI token provider + # 是否允许 Sora 客户端通过 OpenAI token provider 取 token(默认 false,避免误走 OpenAI 刷新链路) + use_openai_token_provider: false # Optional custom headers (key-value) # 额外请求头(键值对) headers: {} @@ -398,6 +405,27 @@ sora: # Disable TLS fingerprint for Sora upstream # 关闭 Sora 上游 TLS 指纹伪装 disable_tls_fingerprint: false + # curl_cffi sidecar for Sora only (required) + # 仅 Sora 链路使用的 curl_cffi sidecar(必需) + curl_cffi_sidecar: + # Sora 强制通过 sidecar 请求,必须启用 + # Sora is forced to use sidecar only; keep enabled=true + enabled: true + # Sidecar base URL (default endpoint: /request) + # sidecar 基础地址(默认请求端点:/request) + base_url: "http://sora-curl-cffi-sidecar:8080" + # curl_cffi impersonate profile, e.g. chrome131/chrome124/safari18_0 + # curl_cffi 指纹伪装 profile,例如 chrome131/chrome124/safari18_0 + impersonate: "chrome131" + # Sidecar request timeout (seconds) + # sidecar 请求超时(秒) + timeout_seconds: 60 + # Reuse session key per account+proxy to let sidecar persist cookies/session + # 按账号+代理复用 session key,让 sidecar 持久化 cookies/session + session_reuse_enabled: true + # Session TTL in sidecar (seconds) + # sidecar 会话 TTL(秒) + session_ttl_seconds: 3600 storage: # Storage type (local only for now) # 存储类型(首发仅支持 local) @@ -431,6 +459,13 @@ sora: # Cron 调度表达式 schedule: "0 3 * * *" +# Token refresh behavior +# token 刷新行为控制 +token_refresh: + # Whether OpenAI refresh flow is allowed to sync linked Sora accounts + # 是否允许 OpenAI 刷新流程同步覆盖 linked_openai_account_id 关联的 Sora 账号 token + sync_linked_sora_accounts: false + # ============================================================================= # API Key Auth Cache Configuration # API Key 认证缓存配置 diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index e5c97bf8..f18a1b64 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -173,6 +173,7 @@ services: - POSTGRES_USER=${POSTGRES_USER:-sub2api} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} - POSTGRES_DB=${POSTGRES_DB:-sub2api} + - PGDATA=/var/lib/postgresql/data - TZ=${TZ:-Asia/Shanghai} networks: - sub2api-network diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 0c4856a9..89b11783 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -32,6 +32,7 @@ export async function list( platform?: string type?: string status?: string + group?: string search?: string }, options?: { @@ -271,7 +272,7 @@ export async function generateAuthUrl( */ export async function exchangeCode( endpoint: string, - exchangeData: { session_id: string; code: string; proxy_id?: number } + exchangeData: { session_id: string; code: string; state?: string; proxy_id?: number } ): Promise> { const { data } = await apiClient.post>(endpoint, exchangeData) return data @@ -493,7 +494,8 @@ export async function getAntigravityDefaultModelMapping(): Promise> { const payload: { refresh_token: string; proxy_id?: number } = { refresh_token: refreshToken @@ -501,7 +503,29 @@ export async function refreshOpenAIToken( if (proxyId) { payload.proxy_id = proxyId } - const { data } = await apiClient.post>('/admin/openai/refresh-token', payload) + const { data } = await apiClient.post>(endpoint, payload) + return data +} + +/** + * Validate Sora session token and exchange to access token + * @param sessionToken - Sora session token + * @param proxyId - Optional proxy ID + * @param endpoint - API endpoint path + * @returns Token information including access_token + */ +export async function validateSoraSessionToken( + sessionToken: string, + proxyId?: number | null, + endpoint: string = '/admin/sora/st2at' +): Promise> { + const payload: { session_token: string; proxy_id?: number } = { + session_token: sessionToken + } + if (proxyId) { + payload.proxy_id = proxyId + } + const { data } = await apiClient.post>(endpoint, payload) return data } @@ -527,6 +551,7 @@ export const accountsAPI = { generateAuthUrl, exchangeCode, refreshOpenAIToken, + validateSoraSessionToken, batchCreate, batchUpdateCredentials, bulkUpdate, diff --git a/frontend/src/api/admin/proxies.ts b/frontend/src/api/admin/proxies.ts index b6aaf595..5e31ae20 100644 --- a/frontend/src/api/admin/proxies.ts +++ b/frontend/src/api/admin/proxies.ts @@ -7,6 +7,7 @@ import { apiClient } from '../client' import type { Proxy, ProxyAccountSummary, + ProxyQualityCheckResult, CreateProxyRequest, UpdateProxyRequest, PaginatedResponse, @@ -143,6 +144,16 @@ export async function testProxy(id: number): Promise<{ return data } +/** + * Check proxy quality across common AI targets + * @param id - Proxy ID + * @returns Quality check result + */ +export async function checkProxyQuality(id: number): Promise { + const { data } = await apiClient.post(`/admin/proxies/${id}/quality-check`) + return data +} + /** * Get proxy usage statistics * @param id - Proxy ID @@ -248,6 +259,7 @@ export const proxiesAPI = { delete: deleteProxy, toggleStatus, testProxy, + checkProxyQuality, getStats, getProxyAccounts, batchCreate, diff --git a/frontend/src/components/account/AccountGroupsCell.vue b/frontend/src/components/account/AccountGroupsCell.vue index 512383a5..37771275 100644 --- a/frontend/src/components/account/AccountGroupsCell.vue +++ b/frontend/src/components/account/AccountGroupsCell.vue @@ -41,7 +41,7 @@ >
- {{ t('admin.accounts.allGroups', { count: groups.length }) }} + {{ t('admin.accounts.groupCountTotal', { count: groups.length }) }}
-
+
@@ -54,6 +54,12 @@ :placeholder="loadingModels ? t('common.loading') + '...' : t('admin.accounts.selectTestModel')" />
+
+ {{ t('admin.accounts.soraTestHint') }} +
@@ -135,12 +141,12 @@
- {{ t('admin.accounts.testModel') }} + {{ isSoraAccount ? t('admin.accounts.soraTestTarget') : t('admin.accounts.testModel') }}
- {{ t('admin.accounts.testPrompt') }} + {{ isSoraAccount ? t('admin.accounts.soraTestMode') : t('admin.accounts.testPrompt') }}
@@ -156,10 +162,10 @@ + + + +
@@ -1538,6 +1592,46 @@
+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.cacheTTLOverride.hint') }} +

+
+ +
+
+ + +

+ {{ t('admin.accounts.quotaControl.cacheTTLOverride.targetHint') }} +

+
+
@@ -1707,32 +1801,6 @@
- -
- -
-
@@ -2108,6 +2178,7 @@ interface OAuthFlowExposed { projectId: string sessionKey: string refreshToken: string + sessionToken: string inputMethod: AuthInputMethod reset: () => void } @@ -2116,7 +2187,7 @@ const { t } = useI18n() const authStore = useAuthStore() const oauthStepTitle = computed(() => { - if (form.platform === 'openai') return t('admin.accounts.oauth.openai.title') + if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.oauth.openai.title') if (form.platform === 'gemini') return t('admin.accounts.oauth.gemini.title') if (form.platform === 'antigravity') return t('admin.accounts.oauth.antigravity.title') return t('admin.accounts.oauth.title') @@ -2124,13 +2195,13 @@ const oauthStepTitle = computed(() => { // Platform-specific hints for API Key type const baseUrlHint = computed(() => { - if (form.platform === 'openai') return t('admin.accounts.openai.baseUrlHint') + if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.openai.baseUrlHint') if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint') return t('admin.accounts.baseUrlHint') }) const apiKeyHint = computed(() => { - if (form.platform === 'openai') return t('admin.accounts.openai.apiKeyHint') + if (form.platform === 'openai' || form.platform === 'sora') return t('admin.accounts.openai.apiKeyHint') if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint') return t('admin.accounts.apiKeyHint') }) @@ -2151,34 +2222,36 @@ const appStore = useAppStore() // OAuth composables const oauth = useAccountOAuth() // For Anthropic OAuth -const openaiOAuth = useOpenAIOAuth() // For OpenAI OAuth +const openaiOAuth = useOpenAIOAuth({ platform: 'openai' }) // For OpenAI OAuth +const soraOAuth = useOpenAIOAuth({ platform: 'sora' }) // For Sora OAuth const geminiOAuth = useGeminiOAuth() // For Gemini OAuth const antigravityOAuth = useAntigravityOAuth() // For Antigravity OAuth +const activeOpenAIOAuth = computed(() => (form.platform === 'sora' ? soraOAuth : openaiOAuth)) // Computed: current OAuth state for template binding const currentAuthUrl = computed(() => { - if (form.platform === 'openai') return openaiOAuth.authUrl.value + if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.authUrl.value if (form.platform === 'gemini') return geminiOAuth.authUrl.value if (form.platform === 'antigravity') return antigravityOAuth.authUrl.value return oauth.authUrl.value }) const currentSessionId = computed(() => { - if (form.platform === 'openai') return openaiOAuth.sessionId.value + if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.sessionId.value if (form.platform === 'gemini') return geminiOAuth.sessionId.value if (form.platform === 'antigravity') return antigravityOAuth.sessionId.value return oauth.sessionId.value }) const currentOAuthLoading = computed(() => { - if (form.platform === 'openai') return openaiOAuth.loading.value + if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.loading.value if (form.platform === 'gemini') return geminiOAuth.loading.value if (form.platform === 'antigravity') return antigravityOAuth.loading.value return oauth.loading.value }) const currentOAuthError = computed(() => { - if (form.platform === 'openai') return openaiOAuth.error.value + if (form.platform === 'openai' || form.platform === 'sora') return activeOpenAIOAuth.value.error.value if (form.platform === 'gemini') return geminiOAuth.error.value if (form.platform === 'antigravity') return antigravityOAuth.error.value return oauth.error.value @@ -2217,7 +2290,6 @@ const interceptWarmupRequests = ref(false) const autoPauseOnExpired = ref(true) const openaiPassthroughEnabled = ref(false) const codexCLIOnlyEnabled = ref(false) -const enableSoraOnOpenAIOAuth = ref(false) // OpenAI OAuth 时同时启用 Sora const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream const upstreamBaseUrl = ref('') // For upstream type: base URL @@ -2250,6 +2322,8 @@ const maxSessions = ref(null) const sessionIdleTimeout = ref(null) const tlsFingerprintEnabled = ref(false) const sessionIdMaskingEnabled = ref(false) +const cacheTTLOverrideEnabled = ref(false) +const cacheTTLOverrideTarget = ref('5m') // Gemini tier selection (used as fallback when auto-detection is unavailable/fails) const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free') @@ -2356,8 +2430,8 @@ const expiresAtInput = computed({ const canExchangeCode = computed(() => { const authCode = oauthFlowRef.value?.authCode || '' - if (form.platform === 'openai') { - return authCode.trim() && openaiOAuth.sessionId.value && !openaiOAuth.loading.value + if (form.platform === 'openai' || form.platform === 'sora') { + return authCode.trim() && activeOpenAIOAuth.value.sessionId.value && !activeOpenAIOAuth.value.loading.value } if (form.platform === 'gemini') { return authCode.trim() && geminiOAuth.sessionId.value && !geminiOAuth.loading.value @@ -2417,7 +2491,7 @@ watch( (newPlatform) => { // Reset base URL based on platform apiKeyBaseUrl.value = - newPlatform === 'openai' + (newPlatform === 'openai' || newPlatform === 'sora') ? 'https://api.openai.com' : newPlatform === 'gemini' ? 'https://generativelanguage.googleapis.com' @@ -2443,6 +2517,11 @@ watch( if (newPlatform !== 'anthropic') { interceptWarmupRequests.value = false } + if (newPlatform === 'sora') { + accountCategory.value = 'oauth-based' + addMethod.value = 'oauth' + form.type = 'oauth' + } if (newPlatform !== 'openai') { openaiPassthroughEnabled.value = false codexCLIOnlyEnabled.value = false @@ -2450,6 +2529,7 @@ watch( // Reset OAuth states oauth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() geminiOAuth.resetState() antigravityOAuth.resetState() } @@ -2711,7 +2791,6 @@ const resetForm = () => { autoPauseOnExpired.value = true openaiPassthroughEnabled.value = false codexCLIOnlyEnabled.value = false - enableSoraOnOpenAIOAuth.value = false // Reset quota control state windowCostEnabled.value = false windowCostLimit.value = null @@ -2721,6 +2800,8 @@ const resetForm = () => { sessionIdleTimeout.value = null tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false + cacheTTLOverrideEnabled.value = false + cacheTTLOverrideTarget.value = '5m' antigravityAccountType.value = 'oauth' upstreamBaseUrl.value = '' upstreamApiKey.value = '' @@ -2732,6 +2813,7 @@ const resetForm = () => { geminiTierAIStudio.value = 'aistudio_free' oauth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() geminiOAuth.resetState() antigravityOAuth.resetState() oauthFlowRef.value?.reset() @@ -2763,6 +2845,23 @@ const buildOpenAIExtra = (base?: Record): Record 0 ? extra : undefined } +const buildSoraExtra = ( + base?: Record, + linkedOpenAIAccountId?: string | number +): Record | undefined => { + const extra: Record = { ...(base || {}) } + if (linkedOpenAIAccountId !== undefined && linkedOpenAIAccountId !== null) { + const id = String(linkedOpenAIAccountId).trim() + if (id) { + extra.linked_openai_account_id = id + } + } + delete extra.openai_passthrough + delete extra.openai_oauth_passthrough + delete extra.codex_cli_only + return Object.keys(extra).length > 0 ? extra : undefined +} + // Helper function to create account with mixed channel warning handling const doCreateAccount = async (payload: any) => { submitting.value = true @@ -2878,7 +2977,7 @@ const handleSubmit = async () => { // Determine default base URL based on platform const defaultBaseUrl = - form.platform === 'openai' + (form.platform === 'openai' || form.platform === 'sora') ? 'https://api.openai.com' : form.platform === 'gemini' ? 'https://generativelanguage.googleapis.com' @@ -2930,14 +3029,15 @@ const goBackToBasicInfo = () => { step.value = 1 oauth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() geminiOAuth.resetState() antigravityOAuth.resetState() oauthFlowRef.value?.reset() } const handleGenerateUrl = async () => { - if (form.platform === 'openai') { - await openaiOAuth.generateAuthUrl(form.proxy_id) + if (form.platform === 'openai' || form.platform === 'sora') { + await activeOpenAIOAuth.value.generateAuthUrl(form.proxy_id) } else if (form.platform === 'gemini') { await geminiOAuth.generateAuthUrl( form.proxy_id, @@ -2953,13 +3053,19 @@ const handleGenerateUrl = async () => { } const handleValidateRefreshToken = (rt: string) => { - if (form.platform === 'openai') { + if (form.platform === 'openai' || form.platform === 'sora') { handleOpenAIValidateRT(rt) } else if (form.platform === 'antigravity') { handleAntigravityValidateRT(rt) } } +const handleValidateSessionToken = (sessionToken: string) => { + if (form.platform === 'sora') { + handleSoraValidateST(sessionToken) + } +} + const formatDateTimeLocal = formatDateTimeLocalInput const parseDateTimeLocal = parseDateTimeLocalInput @@ -2995,100 +3101,101 @@ const createAccountAndFinish = async ( // OpenAI OAuth 授权码兑换 const handleOpenAIExchange = async (authCode: string) => { - if (!authCode.trim() || !openaiOAuth.sessionId.value) return + const oauthClient = activeOpenAIOAuth.value + if (!authCode.trim() || !oauthClient.sessionId.value) return - openaiOAuth.loading.value = true - openaiOAuth.error.value = '' + oauthClient.loading.value = true + oauthClient.error.value = '' try { - const tokenInfo = await openaiOAuth.exchangeAuthCode( + const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim() + if (!stateToUse) { + oauthClient.error.value = t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) + return + } + + const tokenInfo = await oauthClient.exchangeAuthCode( authCode.trim(), - openaiOAuth.sessionId.value, + oauthClient.sessionId.value, + stateToUse, form.proxy_id ) if (!tokenInfo) return - const credentials = openaiOAuth.buildCredentials(tokenInfo) - const oauthExtra = openaiOAuth.buildExtraInfo(tokenInfo) as Record | undefined + const credentials = oauthClient.buildCredentials(tokenInfo) + const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record | undefined const extra = buildOpenAIExtra(oauthExtra) + const shouldCreateOpenAI = form.platform === 'openai' + const shouldCreateSora = form.platform === 'sora' // 应用临时不可调度配置 if (!applyTempUnschedConfig(credentials)) { return } - // 1. 创建 OpenAI 账号 - const openaiAccount = await adminAPI.accounts.create({ - name: form.name, - notes: form.notes, - platform: 'openai', - type: 'oauth', - credentials, - extra, - proxy_id: form.proxy_id, - concurrency: form.concurrency, - priority: form.priority, - rate_multiplier: form.rate_multiplier, - group_ids: form.group_ids, - expires_at: form.expires_at, - auto_pause_on_expired: autoPauseOnExpired.value - }) + let openaiAccountId: string | number | undefined - appStore.showSuccess(t('admin.accounts.accountCreated')) + if (shouldCreateOpenAI) { + const openaiAccount = await adminAPI.accounts.create({ + name: form.name, + notes: form.notes, + platform: 'openai', + type: 'oauth', + credentials, + extra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + openaiAccountId = openaiAccount.id + appStore.showSuccess(t('admin.accounts.accountCreated')) + } - // 2. 如果启用了 Sora,同时创建 Sora 账号 - if (enableSoraOnOpenAIOAuth.value) { - try { - // Sora 使用相同的 OAuth credentials - const soraCredentials = { - access_token: credentials.access_token, - refresh_token: credentials.refresh_token, - expires_at: credentials.expires_at - } - - // 建立关联关系 - const soraExtra: Record = { - ...(extra || {}), - linked_openai_account_id: String(openaiAccount.id) - } - delete soraExtra.openai_passthrough - delete soraExtra.openai_oauth_passthrough - - await adminAPI.accounts.create({ - name: `${form.name} (Sora)`, - notes: form.notes, - platform: 'sora', - type: 'oauth', - credentials: soraCredentials, - extra: soraExtra, - proxy_id: form.proxy_id, - concurrency: form.concurrency, - priority: form.priority, - rate_multiplier: form.rate_multiplier, - group_ids: form.group_ids, - expires_at: form.expires_at, - auto_pause_on_expired: autoPauseOnExpired.value - }) - - appStore.showSuccess(t('admin.accounts.soraAccountCreated')) - } catch (error: any) { - console.error('创建 Sora 账号失败:', error) - appStore.showWarning(t('admin.accounts.soraAccountFailed')) + if (shouldCreateSora) { + const soraCredentials = { + access_token: credentials.access_token, + refresh_token: credentials.refresh_token, + expires_at: credentials.expires_at } + + const soraName = shouldCreateOpenAI ? `${form.name} (Sora)` : form.name + const soraExtra = buildSoraExtra(shouldCreateOpenAI ? extra : oauthExtra, openaiAccountId) + await adminAPI.accounts.create({ + name: soraName, + notes: form.notes, + platform: 'sora', + type: 'oauth', + credentials: soraCredentials, + extra: soraExtra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + appStore.showSuccess(t('admin.accounts.accountCreated')) } emit('created') handleClose() } catch (error: any) { - openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') - appStore.showError(openaiOAuth.error.value) + oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) } finally { - openaiOAuth.loading.value = false + oauthClient.loading.value = false } } // OpenAI 手动 RT 批量验证和创建 const handleOpenAIValidateRT = async (refreshTokenInput: string) => { + const oauthClient = activeOpenAIOAuth.value if (!refreshTokenInput.trim()) return // Parse multiple refresh tokens (one per line) @@ -3098,53 +3205,86 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { .filter((rt) => rt) if (refreshTokens.length === 0) { - openaiOAuth.error.value = t('admin.accounts.oauth.openai.pleaseEnterRefreshToken') + oauthClient.error.value = t('admin.accounts.oauth.openai.pleaseEnterRefreshToken') return } - openaiOAuth.loading.value = true - openaiOAuth.error.value = '' + oauthClient.loading.value = true + oauthClient.error.value = '' let successCount = 0 let failedCount = 0 const errors: string[] = [] + const shouldCreateOpenAI = form.platform === 'openai' + const shouldCreateSora = form.platform === 'sora' try { for (let i = 0; i < refreshTokens.length; i++) { try { - const tokenInfo = await openaiOAuth.validateRefreshToken( + const tokenInfo = await oauthClient.validateRefreshToken( refreshTokens[i], form.proxy_id ) if (!tokenInfo) { failedCount++ - errors.push(`#${i + 1}: ${openaiOAuth.error.value || 'Validation failed'}`) - openaiOAuth.error.value = '' + errors.push(`#${i + 1}: ${oauthClient.error.value || 'Validation failed'}`) + oauthClient.error.value = '' continue } - const credentials = openaiOAuth.buildCredentials(tokenInfo) - const oauthExtra = openaiOAuth.buildExtraInfo(tokenInfo) as Record | undefined + const credentials = oauthClient.buildCredentials(tokenInfo) + const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record | undefined const extra = buildOpenAIExtra(oauthExtra) // Generate account name with index for batch const accountName = refreshTokens.length > 1 ? `${form.name} #${i + 1}` : form.name - await adminAPI.accounts.create({ - name: accountName, - notes: form.notes, - platform: 'openai', - type: 'oauth', - credentials, - extra, - proxy_id: form.proxy_id, - concurrency: form.concurrency, - priority: form.priority, - rate_multiplier: form.rate_multiplier, - group_ids: form.group_ids, - expires_at: form.expires_at, - auto_pause_on_expired: autoPauseOnExpired.value - }) + let openaiAccountId: string | number | undefined + + if (shouldCreateOpenAI) { + const openaiAccount = await adminAPI.accounts.create({ + name: accountName, + notes: form.notes, + platform: 'openai', + type: 'oauth', + credentials, + extra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + openaiAccountId = openaiAccount.id + } + + if (shouldCreateSora) { + const soraCredentials = { + access_token: credentials.access_token, + refresh_token: credentials.refresh_token, + expires_at: credentials.expires_at + } + const soraName = shouldCreateOpenAI ? `${accountName} (Sora)` : accountName + const soraExtra = buildSoraExtra(shouldCreateOpenAI ? extra : oauthExtra, openaiAccountId) + await adminAPI.accounts.create({ + name: soraName, + notes: form.notes, + platform: 'sora', + type: 'oauth', + credentials: soraCredentials, + extra: soraExtra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + } + successCount++ } catch (error: any) { failedCount++ @@ -3166,14 +3306,99 @@ const handleOpenAIValidateRT = async (refreshTokenInput: string) => { appStore.showWarning( t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount }) ) - openaiOAuth.error.value = errors.join('\n') + oauthClient.error.value = errors.join('\n') emit('created') } else { - openaiOAuth.error.value = errors.join('\n') + oauthClient.error.value = errors.join('\n') appStore.showError(t('admin.accounts.oauth.batchFailed')) } } finally { - openaiOAuth.loading.value = false + oauthClient.loading.value = false + } +} + +// Sora 手动 ST 批量验证和创建 +const handleSoraValidateST = async (sessionTokenInput: string) => { + const oauthClient = activeOpenAIOAuth.value + if (!sessionTokenInput.trim()) return + + const sessionTokens = sessionTokenInput + .split('\n') + .map((st) => st.trim()) + .filter((st) => st) + + if (sessionTokens.length === 0) { + oauthClient.error.value = t('admin.accounts.oauth.openai.pleaseEnterSessionToken') + return + } + + oauthClient.loading.value = true + oauthClient.error.value = '' + + let successCount = 0 + let failedCount = 0 + const errors: string[] = [] + + try { + for (let i = 0; i < sessionTokens.length; i++) { + try { + const tokenInfo = await oauthClient.validateSessionToken(sessionTokens[i], form.proxy_id) + if (!tokenInfo) { + failedCount++ + errors.push(`#${i + 1}: ${oauthClient.error.value || 'Validation failed'}`) + oauthClient.error.value = '' + continue + } + + const credentials = oauthClient.buildCredentials(tokenInfo) + credentials.session_token = sessionTokens[i] + const oauthExtra = oauthClient.buildExtraInfo(tokenInfo) as Record | undefined + const soraExtra = buildSoraExtra(oauthExtra) + + const accountName = sessionTokens.length > 1 ? `${form.name} #${i + 1}` : form.name + await adminAPI.accounts.create({ + name: accountName, + notes: form.notes, + platform: 'sora', + type: 'oauth', + credentials, + extra: soraExtra, + proxy_id: form.proxy_id, + concurrency: form.concurrency, + priority: form.priority, + rate_multiplier: form.rate_multiplier, + group_ids: form.group_ids, + expires_at: form.expires_at, + auto_pause_on_expired: autoPauseOnExpired.value + }) + successCount++ + } catch (error: any) { + failedCount++ + const errMsg = error.response?.data?.detail || error.message || 'Unknown error' + errors.push(`#${i + 1}: ${errMsg}`) + } + } + + if (successCount > 0 && failedCount === 0) { + appStore.showSuccess( + sessionTokens.length > 1 + ? t('admin.accounts.oauth.batchSuccess', { count: successCount }) + : t('admin.accounts.accountCreated') + ) + emit('created') + handleClose() + } else if (successCount > 0 && failedCount > 0) { + appStore.showWarning( + t('admin.accounts.oauth.batchPartialSuccess', { success: successCount, failed: failedCount }) + ) + oauthClient.error.value = errors.join('\n') + emit('created') + } else { + oauthClient.error.value = errors.join('\n') + appStore.showError(t('admin.accounts.oauth.batchFailed')) + } + } finally { + oauthClient.loading.value = false } } @@ -3393,6 +3618,12 @@ const handleAnthropicExchange = async (authCode: string) => { extra.session_id_masking_enabled = true } + // Add cache TTL override settings + if (cacheTTLOverrideEnabled.value) { + extra.cache_ttl_override_enabled = true + extra.cache_ttl_override_target = cacheTTLOverrideTarget.value + } + const credentials = { ...tokenInfo, ...(interceptWarmupRequests.value ? { intercept_warmup_requests: true } : {}) @@ -3412,6 +3643,7 @@ const handleExchangeCode = async () => { switch (form.platform) { case 'openai': + case 'sora': return handleOpenAIExchange(authCode) case 'gemini': return handleGeminiExchange(authCode) @@ -3486,6 +3718,12 @@ const handleCookieAuth = async (sessionKey: string) => { extra.session_id_masking_enabled = true } + // Add cache TTL override settings + if (cacheTTLOverrideEnabled.value) { + extra.cache_ttl_override_enabled = true + extra.cache_ttl_override_target = cacheTTLOverrideTarget.value + } + const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name // Merge interceptWarmupRequests into credentials diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 0b6d00c9..3842ea06 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -975,6 +975,46 @@
+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.cacheTTLOverride.hint') }} +

+
+ +
+
+ + +

+ {{ t('admin.accounts.quotaControl.cacheTTLOverride.targetHint') }} +

+
+
@@ -1177,6 +1217,8 @@ const maxSessions = ref(null) const sessionIdleTimeout = ref(null) const tlsFingerprintEnabled = ref(false) const sessionIdMaskingEnabled = ref(false) +const cacheTTLOverrideEnabled = ref(false) +const cacheTTLOverrideTarget = ref('5m') // OpenAI 自动透传开关(OAuth/API Key) const openaiPassthroughEnabled = ref(false) @@ -1581,6 +1623,8 @@ function loadQuotaControlSettings(account: Account) { sessionIdleTimeout.value = null tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false + cacheTTLOverrideEnabled.value = false + cacheTTLOverrideTarget.value = '5m' // Only applies to Anthropic OAuth/SetupToken accounts if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) { @@ -1609,6 +1653,12 @@ function loadQuotaControlSettings(account: Account) { if (account.session_id_masking_enabled === true) { sessionIdMaskingEnabled.value = true } + + // Load cache TTL override setting + if (account.cache_ttl_override_enabled === true) { + cacheTTLOverrideEnabled.value = true + cacheTTLOverrideTarget.value = account.cache_ttl_override_target || '5m' + } } function formatTempUnschedKeywords(value: unknown) { @@ -1820,6 +1870,15 @@ const handleSubmit = async () => { delete newExtra.session_id_masking_enabled } + // Cache TTL override setting + if (cacheTTLOverrideEnabled.value) { + newExtra.cache_ttl_override_enabled = true + newExtra.cache_ttl_override_target = cacheTTLOverrideTarget.value + } else { + delete newExtra.cache_ttl_override_enabled + delete newExtra.cache_ttl_override_target + } + updatePayload.extra = newExtra } diff --git a/frontend/src/components/account/OAuthAuthorizationFlow.vue b/frontend/src/components/account/OAuthAuthorizationFlow.vue index 9c4b7e4b..8e00d25b 100644 --- a/frontend/src/components/account/OAuthAuthorizationFlow.vue +++ b/frontend/src/components/account/OAuthAuthorizationFlow.vue @@ -48,6 +48,17 @@ t(getOAuthKey('refreshTokenAuth')) }} +
@@ -135,6 +146,87 @@ + +
+
+

+ {{ t(getOAuthKey('sessionTokenDesc')) }} +

+ +
+ + +

+ {{ t('admin.accounts.oauth.batchCreateAccounts', { count: parsedSessionTokenCount }) }} +

+
+ +
+

+ {{ error }} +

+
+ + +
+
+
(), { authUrl: '', @@ -540,6 +633,7 @@ const props = withDefaults(defineProps(), { methodLabel: 'Authorization Method', showCookieOption: true, showRefreshTokenOption: false, + showSessionTokenOption: false, platform: 'anthropic', showProjectId: true }) @@ -549,6 +643,7 @@ const emit = defineEmits<{ 'exchange-code': [code: string] 'cookie-auth': [sessionKey: string] 'validate-refresh-token': [refreshToken: string] + 'validate-session-token': [sessionToken: string] 'update:inputMethod': [method: AuthInputMethod] }>() @@ -587,12 +682,13 @@ const inputMethod = ref(props.showCookieOption ? 'manual' : 'ma const authCodeInput = ref('') const sessionKeyInput = ref('') const refreshTokenInput = ref('') +const sessionTokenInput = ref('') const showHelpDialog = ref(false) const oauthState = ref('') const projectId = ref('') // Computed: show method selection when either cookie or refresh token option is enabled -const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption) +const showMethodSelection = computed(() => props.showCookieOption || props.showRefreshTokenOption || props.showSessionTokenOption) // Clipboard const { copied, copyToClipboard } = useClipboard() @@ -613,6 +709,13 @@ const parsedRefreshTokenCount = computed(() => { .filter((rt) => rt).length }) +const parsedSessionTokenCount = computed(() => { + return sessionTokenInput.value + .split('\n') + .map((st) => st.trim()) + .filter((st) => st).length +}) + // Watchers watch(inputMethod, (newVal) => { emit('update:inputMethod', newVal) @@ -631,7 +734,7 @@ watch(authCodeInput, (newVal) => { const url = new URL(trimmed) const code = url.searchParams.get('code') const stateParam = url.searchParams.get('state') - if ((props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) { + if ((props.platform === 'openai' || props.platform === 'sora' || props.platform === 'gemini' || props.platform === 'antigravity') && stateParam) { oauthState.value = stateParam } if (code && code !== trimmed) { @@ -642,7 +745,7 @@ watch(authCodeInput, (newVal) => { // If URL parsing fails, try regex extraction const match = trimmed.match(/[?&]code=([^&]+)/) const stateMatch = trimmed.match(/[?&]state=([^&]+)/) - if ((props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) { + if ((props.platform === 'openai' || props.platform === 'sora' || props.platform === 'gemini' || props.platform === 'antigravity') && stateMatch && stateMatch[1]) { oauthState.value = stateMatch[1] } if (match && match[1] && match[1] !== trimmed) { @@ -680,6 +783,12 @@ const handleValidateRefreshToken = () => { } } +const handleValidateSessionToken = () => { + if (sessionTokenInput.value.trim()) { + emit('validate-session-token', sessionTokenInput.value.trim()) + } +} + // Expose methods and state defineExpose({ authCode: authCodeInput, @@ -687,6 +796,7 @@ defineExpose({ projectId, sessionKey: sessionKeyInput, refreshToken: refreshTokenInput, + sessionToken: sessionTokenInput, inputMethod, reset: () => { authCodeInput.value = '' @@ -694,6 +804,7 @@ defineExpose({ projectId.value = '' sessionKeyInput.value = '' refreshTokenInput.value = '' + sessionTokenInput.value = '' inputMethod.value = 'manual' showHelpDialog.value = false } diff --git a/frontend/src/components/account/ReAuthAccountModal.vue b/frontend/src/components/account/ReAuthAccountModal.vue index b2734b4f..aab0fe7d 100644 --- a/frontend/src/components/account/ReAuthAccountModal.vue +++ b/frontend/src/components/account/ReAuthAccountModal.vue @@ -14,7 +14,7 @@
('code_as // Computed - check platform const isOpenAI = computed(() => props.account?.platform === 'openai') +const isSora = computed(() => props.account?.platform === 'sora') +const isOpenAILike = computed(() => isOpenAI.value || isSora.value) const isGemini = computed(() => props.account?.platform === 'gemini') const isAnthropic = computed(() => props.account?.platform === 'anthropic') const isAntigravity = computed(() => props.account?.platform === 'antigravity') +const activeOpenAIOAuth = computed(() => (isSora.value ? soraOAuth : openaiOAuth)) // Computed - current OAuth state based on platform const currentAuthUrl = computed(() => { - if (isOpenAI.value) return openaiOAuth.authUrl.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.authUrl.value if (isGemini.value) return geminiOAuth.authUrl.value if (isAntigravity.value) return antigravityOAuth.authUrl.value return claudeOAuth.authUrl.value }) const currentSessionId = computed(() => { - if (isOpenAI.value) return openaiOAuth.sessionId.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.sessionId.value if (isGemini.value) return geminiOAuth.sessionId.value if (isAntigravity.value) return antigravityOAuth.sessionId.value return claudeOAuth.sessionId.value }) const currentLoading = computed(() => { - if (isOpenAI.value) return openaiOAuth.loading.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.loading.value if (isGemini.value) return geminiOAuth.loading.value if (isAntigravity.value) return antigravityOAuth.loading.value return claudeOAuth.loading.value }) const currentError = computed(() => { - if (isOpenAI.value) return openaiOAuth.error.value + if (isOpenAILike.value) return activeOpenAIOAuth.value.error.value if (isGemini.value) return geminiOAuth.error.value if (isAntigravity.value) return antigravityOAuth.error.value return claudeOAuth.error.value @@ -269,8 +275,8 @@ const currentError = computed(() => { // Computed const isManualInputMethod = computed(() => { - // OpenAI/Gemini/Antigravity always use manual input (no cookie auth option) - return isOpenAI.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual' + // OpenAI/Sora/Gemini/Antigravity always use manual input (no cookie auth option) + return isOpenAILike.value || isGemini.value || isAntigravity.value || oauthFlowRef.value?.inputMethod === 'manual' }) const canExchangeCode = computed(() => { @@ -313,6 +319,7 @@ const resetState = () => { geminiOAuthType.value = 'code_assist' claudeOAuth.resetState() openaiOAuth.resetState() + soraOAuth.resetState() geminiOAuth.resetState() antigravityOAuth.resetState() oauthFlowRef.value?.reset() @@ -325,8 +332,8 @@ const handleClose = () => { const handleGenerateUrl = async () => { if (!props.account) return - if (isOpenAI.value) { - await openaiOAuth.generateAuthUrl(props.account.proxy_id) + if (isOpenAILike.value) { + await activeOpenAIOAuth.value.generateAuthUrl(props.account.proxy_id) } else if (isGemini.value) { const creds = (props.account.credentials || {}) as Record const tierId = typeof creds.tier_id === 'string' ? creds.tier_id : undefined @@ -345,21 +352,29 @@ const handleExchangeCode = async () => { const authCode = oauthFlowRef.value?.authCode || '' if (!authCode.trim()) return - if (isOpenAI.value) { + if (isOpenAILike.value) { // OpenAI OAuth flow - const sessionId = openaiOAuth.sessionId.value + const oauthClient = activeOpenAIOAuth.value + const sessionId = oauthClient.sessionId.value if (!sessionId) return + const stateToUse = (oauthFlowRef.value?.oauthState || oauthClient.oauthState.value || '').trim() + if (!stateToUse) { + oauthClient.error.value = t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) + return + } - const tokenInfo = await openaiOAuth.exchangeAuthCode( + const tokenInfo = await oauthClient.exchangeAuthCode( authCode.trim(), sessionId, + stateToUse, props.account.proxy_id ) if (!tokenInfo) return // Build credentials and extra info - const credentials = openaiOAuth.buildCredentials(tokenInfo) - const extra = openaiOAuth.buildExtraInfo(tokenInfo) + const credentials = oauthClient.buildCredentials(tokenInfo) + const extra = oauthClient.buildExtraInfo(tokenInfo) try { // Update account with new credentials @@ -376,8 +391,8 @@ const handleExchangeCode = async () => { emit('reauthorized') handleClose() } catch (error: any) { - openaiOAuth.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') - appStore.showError(openaiOAuth.error.value) + oauthClient.error.value = error.response?.data?.detail || t('admin.accounts.oauth.authFailed') + appStore.showError(oauthClient.error.value) } } else if (isGemini.value) { const sessionId = geminiOAuth.sessionId.value @@ -490,7 +505,7 @@ const handleExchangeCode = async () => { } const handleCookieAuth = async (sessionKey: string) => { - if (!props.account || isOpenAI.value) return + if (!props.account || isOpenAILike.value) return claudeOAuth.loading.value = true claudeOAuth.error.value = '' diff --git a/frontend/src/components/admin/account/AccountTableFilters.vue b/frontend/src/components/admin/account/AccountTableFilters.vue index 01e7fcdd..5280e787 100644 --- a/frontend/src/components/admin/account/AccountTableFilters.vue +++ b/frontend/src/components/admin/account/AccountTableFilters.vue @@ -10,16 +10,21 @@
diff --git a/frontend/src/components/admin/account/AccountTestModal.vue b/frontend/src/components/admin/account/AccountTestModal.vue index feb09654..a25c25cc 100644 --- a/frontend/src/components/admin/account/AccountTestModal.vue +++ b/frontend/src/components/admin/account/AccountTestModal.vue @@ -41,7 +41,7 @@
-
+
@@ -54,6 +54,12 @@ :placeholder="loadingModels ? t('common.loading') + '...' : t('admin.accounts.selectTestModel')" />
+
+ {{ t('admin.accounts.soraTestHint') }} +
@@ -114,12 +120,12 @@
- {{ t('admin.accounts.testModel') }} + {{ isSoraAccount ? t('admin.accounts.soraTestTarget') : t('admin.accounts.testModel') }}
- {{ t('admin.accounts.testPrompt') }} + {{ isSoraAccount ? t('admin.accounts.soraTestMode') : t('admin.accounts.testPrompt') }}
@@ -135,10 +141,10 @@
@@ -157,9 +159,36 @@ {{ t('admin.usage.outputTokens') }} {{ tokenTooltipData.output_tokens.toLocaleString() }} -
- {{ t('admin.usage.cacheCreationTokens') }} - {{ tokenTooltipData.cache_creation_tokens.toLocaleString() }} +
+ + + +
+ {{ t('admin.usage.cacheCreationTokens') }} + {{ tokenTooltipData.cache_creation_tokens.toLocaleString() }} +
+
+
+ + {{ t('usage.cacheTtlOverriddenLabel') }} + R-{{ tokenTooltipData.cache_creation_1h_tokens > 0 ? '5m' : '1H' }} + + {{ tokenTooltipData.cache_creation_1h_tokens > 0 ? t('usage.cacheTtlOverridden1h') : t('usage.cacheTtlOverridden5m') }}
{{ t('admin.usage.cacheReadTokens') }} diff --git a/frontend/src/components/common/StatCard.vue b/frontend/src/components/common/StatCard.vue index 203a2fa8..d7c40a2e 100644 --- a/frontend/src/components/common/StatCard.vue +++ b/frontend/src/components/common/StatCard.vue @@ -6,7 +6,7 @@

{{ title }}

-

{{ formattedValue }}

+

{{ formattedValue }}

- Logo + Logo
@@ -167,6 +167,7 @@ const isDark = ref(document.documentElement.classList.contains('dark')) const siteName = computed(() => appStore.siteName) const siteLogo = computed(() => appStore.siteLogo) const siteVersion = computed(() => appStore.siteVersion) +const settingsLoaded = computed(() => appStore.publicSettingsLoaded) // SVG Icon Components const DashboardIcon = { diff --git a/frontend/src/composables/useAccountOAuth.ts b/frontend/src/composables/useAccountOAuth.ts index ca200cb3..6f53404c 100644 --- a/frontend/src/composables/useAccountOAuth.ts +++ b/frontend/src/composables/useAccountOAuth.ts @@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app' import { adminAPI } from '@/api/admin' export type AddMethod = 'oauth' | 'setup-token' -export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' +export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' | 'session_token' export interface OAuthState { authUrl: string diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index 1193c45d..ec40a3f1 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -39,6 +39,7 @@ export const claudeModels = [ 'claude-sonnet-4-5-20250929', 'claude-haiku-4-5-20251001', 'claude-opus-4-5-20251101', 'claude-opus-4-6', + 'claude-sonnet-4-6', 'claude-2.1', 'claude-2.0', 'claude-instant-1.2' ] @@ -250,6 +251,7 @@ export const allModels = allModelsList.map(m => ({ value: m, label: m })) const anthropicPresetMappings = [ { label: 'Sonnet 4', from: 'claude-sonnet-4-20250514', to: 'claude-sonnet-4-20250514', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' }, { label: 'Sonnet 4.5', from: 'claude-sonnet-4-5-20250929', to: 'claude-sonnet-4-5-20250929', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' }, + { label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'claude-sonnet-4-6', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' }, { label: 'Opus 4.5', from: 'claude-opus-4-5-20251101', to: 'claude-opus-4-5-20251101', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, { label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, { label: 'Haiku 3.5', from: 'claude-3-5-haiku-20241022', to: 'claude-3-5-haiku-20241022', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' }, diff --git a/frontend/src/composables/useOpenAIOAuth.ts b/frontend/src/composables/useOpenAIOAuth.ts index 82a77031..32045cbe 100644 --- a/frontend/src/composables/useOpenAIOAuth.ts +++ b/frontend/src/composables/useOpenAIOAuth.ts @@ -19,12 +19,21 @@ export interface OpenAITokenInfo { [key: string]: unknown } -export function useOpenAIOAuth() { +export type OpenAIOAuthPlatform = 'openai' | 'sora' + +interface UseOpenAIOAuthOptions { + platform?: OpenAIOAuthPlatform +} + +export function useOpenAIOAuth(options?: UseOpenAIOAuthOptions) { const appStore = useAppStore() + const oauthPlatform = options?.platform ?? 'openai' + const endpointPrefix = oauthPlatform === 'sora' ? '/admin/sora' : '/admin/openai' // State const authUrl = ref('') const sessionId = ref('') + const oauthState = ref('') const loading = ref(false) const error = ref('') @@ -32,6 +41,7 @@ export function useOpenAIOAuth() { const resetState = () => { authUrl.value = '' sessionId.value = '' + oauthState.value = '' loading.value = false error.value = '' } @@ -44,6 +54,7 @@ export function useOpenAIOAuth() { loading.value = true authUrl.value = '' sessionId.value = '' + oauthState.value = '' error.value = '' try { @@ -56,11 +67,17 @@ export function useOpenAIOAuth() { } const response = await adminAPI.accounts.generateAuthUrl( - '/admin/openai/generate-auth-url', + `${endpointPrefix}/generate-auth-url`, payload ) authUrl.value = response.auth_url sessionId.value = response.session_id + try { + const parsed = new URL(response.auth_url) + oauthState.value = parsed.searchParams.get('state') || '' + } catch { + oauthState.value = '' + } return true } catch (err: any) { error.value = err.response?.data?.detail || 'Failed to generate OpenAI auth URL' @@ -75,10 +92,11 @@ export function useOpenAIOAuth() { const exchangeAuthCode = async ( code: string, currentSessionId: string, + state: string, proxyId?: number | null ): Promise => { - if (!code.trim() || !currentSessionId) { - error.value = 'Missing auth code or session ID' + if (!code.trim() || !currentSessionId || !state.trim()) { + error.value = 'Missing auth code, session ID, or state' return null } @@ -86,15 +104,16 @@ export function useOpenAIOAuth() { error.value = '' try { - const payload: { session_id: string; code: string; proxy_id?: number } = { + const payload: { session_id: string; code: string; state: string; proxy_id?: number } = { session_id: currentSessionId, - code: code.trim() + code: code.trim(), + state: state.trim() } if (proxyId) { payload.proxy_id = proxyId } - const tokenInfo = await adminAPI.accounts.exchangeCode('/admin/openai/exchange-code', payload) + const tokenInfo = await adminAPI.accounts.exchangeCode(`${endpointPrefix}/exchange-code`, payload) return tokenInfo as OpenAITokenInfo } catch (err: any) { error.value = err.response?.data?.detail || 'Failed to exchange OpenAI auth code' @@ -120,7 +139,11 @@ export function useOpenAIOAuth() { try { // Use dedicated refresh-token endpoint - const tokenInfo = await adminAPI.accounts.refreshOpenAIToken(refreshToken.trim(), proxyId) + const tokenInfo = await adminAPI.accounts.refreshOpenAIToken( + refreshToken.trim(), + proxyId, + `${endpointPrefix}/refresh-token` + ) return tokenInfo as OpenAITokenInfo } catch (err: any) { error.value = err.response?.data?.detail || 'Failed to validate refresh token' @@ -131,6 +154,33 @@ export function useOpenAIOAuth() { } } + // Validate Sora session token and get access token + const validateSessionToken = async ( + sessionToken: string, + proxyId?: number | null + ): Promise => { + if (!sessionToken.trim()) { + error.value = 'Missing session token' + return null + } + loading.value = true + error.value = '' + try { + const tokenInfo = await adminAPI.accounts.validateSoraSessionToken( + sessionToken.trim(), + proxyId, + `${endpointPrefix}/st2at` + ) + return tokenInfo as OpenAITokenInfo + } catch (err: any) { + error.value = err.response?.data?.detail || 'Failed to validate session token' + appStore.showError(error.value) + return null + } finally { + loading.value = false + } + } + // Build credentials for OpenAI OAuth account const buildCredentials = (tokenInfo: OpenAITokenInfo): Record => { const creds: Record = { @@ -172,6 +222,7 @@ export function useOpenAIOAuth() { // State authUrl, sessionId, + oauthState, loading, error, // Methods @@ -179,6 +230,7 @@ export function useOpenAIOAuth() { generateAuthUrl, exchangeAuthCode, validateRefreshToken, + validateSessionToken, buildCredentials, buildExtraInfo } diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 967f22c9..bb292683 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -576,6 +576,10 @@ export default { description: 'View and analyze your API usage history', costDetails: 'Cost Breakdown', tokenDetails: 'Token Breakdown', + cacheTtlOverriddenHint: 'Cache TTL Override enabled', + cacheTtlOverriddenLabel: 'TTL Override', + cacheTtlOverridden5m: 'Billed as 5m', + cacheTtlOverridden1h: 'Billed as 1h', totalRequests: 'Total Requests', totalTokens: 'Total Tokens', totalCost: 'Total Cost', @@ -1346,6 +1350,7 @@ export default { allPlatforms: 'All Platforms', allTypes: 'All Types', allStatus: 'All Status', + allGroups: 'All Groups', oauthType: 'OAuth', setupToken: 'Setup Token', apiKey: 'API Key', @@ -1355,7 +1360,7 @@ export default { schedulableEnabled: 'Scheduling enabled', schedulableDisabled: 'Scheduling disabled', failedToToggleSchedulable: 'Failed to toggle scheduling status', - allGroups: '{count} groups total', + groupCountTotal: '{count} groups total', platforms: { anthropic: 'Anthropic', claude: 'Claude', @@ -1618,6 +1623,12 @@ export default { sessionIdMasking: { label: 'Session ID Masking', hint: 'When enabled, fixes the session ID in metadata.user_id for 15 minutes, making upstream think requests come from the same session' + }, + cacheTTLOverride: { + label: 'Cache TTL Override', + hint: 'Force all cache creation tokens to be billed as the selected TTL tier (5m or 1h)', + target: 'Target TTL', + targetHint: 'Select the TTL tier for billing' } }, expired: 'Expired', @@ -1731,9 +1742,13 @@ export default { refreshTokenAuth: 'Manual RT Input', refreshTokenDesc: 'Enter your existing OpenAI Refresh Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.', refreshTokenPlaceholder: 'Paste your OpenAI Refresh Token...\nSupports multiple, one per line', + sessionTokenAuth: 'Manual ST Input', + sessionTokenDesc: 'Enter your existing Sora Session Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.', + sessionTokenPlaceholder: 'Paste your Sora Session Token...\nSupports multiple, one per line', validating: 'Validating...', validateAndCreate: 'Validate & Create Account', - pleaseEnterRefreshToken: 'Please enter Refresh Token' + pleaseEnterRefreshToken: 'Please enter Refresh Token', + pleaseEnterSessionToken: 'Please enter Session Token' }, // Gemini specific gemini: { @@ -1954,6 +1969,7 @@ export default { reAuthorizeAccount: 'Re-Authorize Account', claudeCodeAccount: 'Claude Code Account', openaiAccount: 'OpenAI Account', + soraAccount: 'Sora Account', geminiAccount: 'Gemini Account', antigravityAccount: 'Antigravity Account', inputMethod: 'Input Method', @@ -1979,6 +1995,10 @@ export default { selectTestModel: 'Select Test Model', testModel: 'Test model', testPrompt: 'Prompt: "hi"', + soraTestHint: 'Sora test runs connectivity and capability checks (/backend/me, subscription, Sora2 invite and remaining quota).', + soraTestTarget: 'Target: Sora account capability', + soraTestMode: 'Mode: Connectivity + Capability checks', + soraTestingFlow: 'Running Sora connectivity and capability checks...', // Stats Modal viewStats: 'View Stats', usageStatistics: 'Usage Statistics', @@ -2085,6 +2105,8 @@ export default { actions: 'Actions' }, testConnection: 'Test Connection', + qualityCheck: 'Quality Check', + batchQualityCheck: 'Batch Quality Check', batchTest: 'Test All Proxies', testFailed: 'Failed', latencyFailed: 'Connection failed', @@ -2145,6 +2167,29 @@ export default { proxyWorking: 'Proxy is working!', proxyWorkingWithLatency: 'Proxy is working! Latency: {latency}ms', proxyTestFailed: 'Proxy test failed', + qualityCheckDone: 'Quality check completed: score {score} ({grade})', + qualityCheckFailed: 'Failed to run proxy quality check', + batchQualityDone: + 'Batch quality check completed for {count} proxies: healthy {healthy}, warn {warn}, challenge {challenge}, abnormal {failed}', + batchQualityFailed: 'Batch quality check failed', + batchQualityEmpty: 'No proxies available for quality check', + qualityReportTitle: 'Proxy Quality Report', + qualityGrade: 'Grade {grade}', + qualityExitIP: 'Exit IP', + qualityCountry: 'Exit Region', + qualityBaseLatency: 'Base Latency', + qualityCheckedAt: 'Checked At', + qualityTableTarget: 'Target', + qualityTableStatus: 'Status', + qualityTableLatency: 'Latency', + qualityTableMessage: 'Message', + qualityInline: 'Quality {grade}/{score}', + qualityStatusHealthy: 'Healthy', + qualityStatusPass: 'Pass', + qualityStatusWarn: 'Warn', + qualityStatusFail: 'Fail', + qualityStatusChallenge: 'Challenge', + qualityTargetBase: 'Base Connectivity', failedToLoad: 'Failed to load proxies', failedToCreate: 'Failed to create proxy', failedToUpdate: 'Failed to update proxy', @@ -2385,6 +2430,8 @@ export default { inputTokens: 'Input Tokens', outputTokens: 'Output Tokens', cacheCreationTokens: 'Cache Creation Tokens', + cacheCreation5mTokens: 'Cache Write', + cacheCreation1hTokens: 'Cache Write', cacheReadTokens: 'Cache Read Tokens', failedToLoad: 'Failed to load usage records', billingType: 'Billing Type', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 5e05f0bf..73588ef7 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -582,6 +582,10 @@ export default { description: '查看和分析您的 API 使用历史', costDetails: '成本明细', tokenDetails: 'Token 明细', + cacheTtlOverriddenHint: '缓存 TTL Override 已启用', + cacheTtlOverriddenLabel: 'TTL 替换', + cacheTtlOverridden5m: '按 5m 计费', + cacheTtlOverridden1h: '按 1h 计费', totalRequests: '总请求数', totalTokens: '总 Token', totalCost: '总消费', @@ -1437,6 +1441,7 @@ export default { allPlatforms: '全部平台', allTypes: '全部类型', allStatus: '全部状态', + allGroups: '全部分组', oauthType: 'OAuth', // Schedulable toggle schedulable: '参与调度', @@ -1444,7 +1449,7 @@ export default { schedulableEnabled: '调度已开启', schedulableDisabled: '调度已关闭', failedToToggleSchedulable: '切换调度状态失败', - allGroups: '共 {count} 个分组', + groupCountTotal: '共 {count} 个分组', columns: { name: '名称', platformType: '平台/类型', @@ -1763,6 +1768,12 @@ export default { sessionIdMasking: { label: '会话 ID 伪装', hint: '启用后将在 15 分钟内固定 metadata.user_id 中的 session ID,使上游认为请求来自同一会话' + }, + cacheTTLOverride: { + label: '缓存 TTL 强制替换', + hint: '将所有缓存创建 token 强制按指定的 TTL 类型(5分钟或1小时)计费', + target: '目标 TTL', + targetHint: '选择计费使用的 TTL 类型' } }, expired: '已过期', @@ -1870,9 +1881,13 @@ export default { refreshTokenAuth: '手动输入 RT', refreshTokenDesc: '输入您已有的 OpenAI Refresh Token,支持批量输入(每行一个),系统将自动验证并创建账号。', refreshTokenPlaceholder: '粘贴您的 OpenAI Refresh Token...\n支持多个,每行一个', + sessionTokenAuth: '手动输入 ST', + sessionTokenDesc: '输入您已有的 Sora Session Token,支持批量输入(每行一个),系统将自动验证并创建账号。', + sessionTokenPlaceholder: '粘贴您的 Sora Session Token...\n支持多个,每行一个', validating: '验证中...', validateAndCreate: '验证并创建账号', - pleaseEnterRefreshToken: '请输入 Refresh Token' + pleaseEnterRefreshToken: '请输入 Refresh Token', + pleaseEnterSessionToken: '请输入 Session Token' }, // Gemini specific gemini: { @@ -2088,6 +2103,7 @@ export default { reAuthorizeAccount: '重新授权账号', claudeCodeAccount: 'Claude Code 账号', openaiAccount: 'OpenAI 账号', + soraAccount: 'Sora 账号', geminiAccount: 'Gemini 账号', antigravityAccount: 'Antigravity 账号', inputMethod: '输入方式', @@ -2111,6 +2127,10 @@ export default { selectTestModel: '选择测试模型', testModel: '测试模型', testPrompt: '提示词:"hi"', + soraTestHint: 'Sora 测试将执行连通性与能力检测(/backend/me、订阅信息、Sora2 邀请码与剩余额度)。', + soraTestTarget: '检测目标:Sora 账号能力', + soraTestMode: '模式:连通性 + 能力探测', + soraTestingFlow: '执行 Sora 连通性与能力检测...', // Stats Modal viewStats: '查看统计', usageStatistics: '使用统计', @@ -2228,6 +2248,8 @@ export default { noProxiesYet: '暂无代理', createFirstProxy: '添加您的第一个代理以开始使用。', testConnection: '测试连接', + qualityCheck: '质量检测', + batchQualityCheck: '批量质量检测', batchTest: '批量测试', testFailed: '失败', latencyFailed: '链接失败', @@ -2275,6 +2297,28 @@ export default { proxyWorking: '代理连接正常', proxyWorkingWithLatency: '代理连接正常,延迟 {latency}ms', proxyTestFailed: '代理测试失败', + qualityCheckDone: '质量检测完成:评分 {score}({grade})', + qualityCheckFailed: '代理质量检测失败', + batchQualityDone: '批量质量检测完成,共检测 {count} 个;优质 {healthy} 个,告警 {warn} 个,挑战 {challenge} 个,异常 {failed} 个', + batchQualityFailed: '批量质量检测失败', + batchQualityEmpty: '暂无可检测质量的代理', + qualityReportTitle: '代理质量检测报告', + qualityGrade: '等级 {grade}', + qualityExitIP: '出口 IP', + qualityCountry: '出口地区', + qualityBaseLatency: '基础延迟', + qualityCheckedAt: '检测时间', + qualityTableTarget: '检测项', + qualityTableStatus: '状态', + qualityTableLatency: '延迟', + qualityTableMessage: '说明', + qualityInline: '质量 {grade}/{score}', + qualityStatusHealthy: '优质', + qualityStatusPass: '通过', + qualityStatusWarn: '告警', + qualityStatusFail: '失败', + qualityStatusChallenge: '挑战', + qualityTargetBase: '基础连通性', proxyCreatedSuccess: '代理添加成功', proxyUpdatedSuccess: '代理更新成功', proxyDeletedSuccess: '代理删除成功', @@ -2551,6 +2595,8 @@ export default { inputTokens: '输入 Token', outputTokens: '输出 Token', cacheCreationTokens: '缓存创建 Token', + cacheCreation5mTokens: '缓存创建', + cacheCreation1hTokens: '缓存创建', cacheReadTokens: '缓存读取 Token', failedToLoad: '加载使用记录失败', billingType: '计费类型', diff --git a/frontend/src/style.css b/frontend/src/style.css index c1ee8ea5..25631aaf 100644 --- a/frontend/src/style.css +++ b/frontend/src/style.css @@ -243,7 +243,7 @@ } .stat-value { - @apply text-2xl font-bold text-gray-900 dark:text-white; + @apply text-2xl font-bold text-gray-900 dark:text-white truncate; } .stat-label { diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index e5f71520..b6c7dd42 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -512,6 +512,11 @@ export interface Proxy { country_code?: string region?: string city?: string + quality_status?: 'healthy' | 'warn' | 'challenge' | 'failed' + quality_score?: number + quality_grade?: string + quality_summary?: string + quality_checked?: number created_at: string updated_at: string } @@ -524,6 +529,32 @@ export interface ProxyAccountSummary { notes?: string | null } +export interface ProxyQualityCheckItem { + target: string + status: 'pass' | 'warn' | 'fail' | 'challenge' + http_status?: number + latency_ms?: number + message?: string + cf_ray?: string +} + +export interface ProxyQualityCheckResult { + proxy_id: number + score: number + grade: string + summary: string + exit_ip?: string + country?: string + country_code?: string + base_latency_ms?: number + passed_count: number + warn_count: number + failed_count: number + challenge_count: number + checked_at: number + items: ProxyQualityCheckItem[] +} + // Gemini credentials structure for OAuth and API Key authentication export interface GeminiCredentials { // API Key authentication @@ -627,6 +658,10 @@ export interface Account { // 启用后将在15分钟内固定 metadata.user_id 中的 session ID session_id_masking_enabled?: boolean | null + // 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效) + cache_ttl_override_enabled?: boolean | null + cache_ttl_override_target?: string | null + // 运行时状态(仅当启用对应限制时返回) current_window_cost?: number | null // 当前窗口费用 active_sessions?: number | null // 当前活跃会话数 @@ -840,6 +875,9 @@ export interface UsageLog { // User-Agent user_agent: string | null + // Cache TTL Override + cache_ttl_overridden: boolean + created_at: string user?: User diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index dc135f4a..236c6f54 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -6,6 +6,7 @@ ({ fetchFn: adminAPI.accounts.list, - initialParams: { platform: '', type: '', status: '', search: '' } + initialParams: { platform: '', type: '', status: '', group: '', search: '' } }) const resetAutoRefreshCache = () => { diff --git a/frontend/src/views/admin/ProxiesView.vue b/frontend/src/views/admin/ProxiesView.vue index 9cbf4c58..23d73109 100644 --- a/frontend/src/views/admin/ProxiesView.vue +++ b/frontend/src/views/admin/ProxiesView.vue @@ -55,6 +55,15 @@ {{ t('admin.proxies.testConnection') }} +