feat(sora): 新增 Sora 平台支持并修复高危安全和性能问题

新增功能:
- 新增 Sora 账号管理和 OAuth 认证
- 新增 Sora 视频/图片生成 API 网关
- 新增 Sora 任务调度和缓存机制
- 新增 Sora 使用统计和计费支持
- 前端增加 Sora 平台配置界面

安全修复(代码审核):
- [SEC-001] 限制媒体下载响应体大小(图片 20MB、视频 200MB),防止 DoS 攻击
- [SEC-002] 限制 SDK API 响应大小(1MB),防止内存耗尽
- [SEC-003] 修复 SSRF 风险,添加 URL 验证并强制使用代理配置

BUG 修复(代码审核):
- [BUG-001] 修复 for 循环内 defer 累积导致的资源泄漏
- [BUG-002] 修复图片并发槽位获取失败时已持有锁未释放的永久泄漏

性能优化(代码审核):
- [PERF-001] 添加 Sentinel Token 缓存(3 分钟有效期),减少 PoW 计算开销

技术细节:
- 使用 io.LimitReader 限制所有外部输入的大小
- 添加 urlvalidator 验证防止 SSRF 攻击
- 使用 sync.Map 实现线程安全的包级缓存
- 优化并发槽位管理,添加 releaseAll 模式防止泄漏

影响范围:
- 后端:新增 Sora 相关数据模型、服务、网关和管理接口
- 前端:新增 Sora 平台配置、账号管理和监控界面
- 配置:新增 Sora 相关配置项和环境变量

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-01-29 16:18:38 +08:00
parent bece1b5201
commit 13262a5698
97 changed files with 29541 additions and 68 deletions

View File

@@ -0,0 +1,498 @@
package repository
import (
"context"
"database/sql"
"errors"
"time"
"github.com/Wei-Shaw/sub2api/ent"
dbsoraaccount "github.com/Wei-Shaw/sub2api/ent/soraaccount"
dbsoracachefile "github.com/Wei-Shaw/sub2api/ent/soracachefile"
dbsoratask "github.com/Wei-Shaw/sub2api/ent/soratask"
dbsorausagestat "github.com/Wei-Shaw/sub2api/ent/sorausagestat"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
)
// SoraAccount
type soraAccountRepository struct {
client *ent.Client
}
func NewSoraAccountRepository(client *ent.Client) service.SoraAccountRepository {
return &soraAccountRepository{client: client}
}
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
if accountID <= 0 {
return nil, nil
}
acc, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDEQ(accountID)).Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return mapSoraAccount(acc), nil
}
func (r *soraAccountRepository) GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*service.SoraAccount, error) {
if len(accountIDs) == 0 {
return map[int64]*service.SoraAccount{}, nil
}
records, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDIn(accountIDs...)).All(ctx)
if err != nil {
return nil, err
}
result := make(map[int64]*service.SoraAccount, len(records))
for _, record := range records {
if record == nil {
continue
}
result[record.AccountID] = mapSoraAccount(record)
}
return result, nil
}
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
if accountID <= 0 {
return errors.New("invalid account_id")
}
acc, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDEQ(accountID)).Only(ctx)
if err != nil && !ent.IsNotFound(err) {
return err
}
if acc == nil {
builder := r.client.SoraAccount.Create().SetAccountID(accountID)
applySoraAccountUpdates(builder.Mutation(), updates)
return builder.Exec(ctx)
}
updater := r.client.SoraAccount.UpdateOneID(acc.ID)
applySoraAccountUpdates(updater.Mutation(), updates)
return updater.Exec(ctx)
}
func applySoraAccountUpdates(m *ent.SoraAccountMutation, updates map[string]any) {
if updates == nil {
return
}
for key, val := range updates {
switch key {
case "access_token":
if v, ok := val.(string); ok {
m.SetAccessToken(v)
}
case "session_token":
if v, ok := val.(string); ok {
m.SetSessionToken(v)
}
case "refresh_token":
if v, ok := val.(string); ok {
m.SetRefreshToken(v)
}
case "client_id":
if v, ok := val.(string); ok {
m.SetClientID(v)
}
case "email":
if v, ok := val.(string); ok {
m.SetEmail(v)
}
case "username":
if v, ok := val.(string); ok {
m.SetUsername(v)
}
case "remark":
if v, ok := val.(string); ok {
m.SetRemark(v)
}
case "plan_type":
if v, ok := val.(string); ok {
m.SetPlanType(v)
}
case "plan_title":
if v, ok := val.(string); ok {
m.SetPlanTitle(v)
}
case "subscription_end":
if v, ok := val.(time.Time); ok {
m.SetSubscriptionEnd(v)
}
if v, ok := val.(*time.Time); ok && v != nil {
m.SetSubscriptionEnd(*v)
}
case "sora_supported":
if v, ok := val.(bool); ok {
m.SetSoraSupported(v)
}
case "sora_invite_code":
if v, ok := val.(string); ok {
m.SetSoraInviteCode(v)
}
case "sora_redeemed_count":
if v, ok := val.(int); ok {
m.SetSoraRedeemedCount(v)
}
case "sora_remaining_count":
if v, ok := val.(int); ok {
m.SetSoraRemainingCount(v)
}
case "sora_total_count":
if v, ok := val.(int); ok {
m.SetSoraTotalCount(v)
}
case "sora_cooldown_until":
if v, ok := val.(time.Time); ok {
m.SetSoraCooldownUntil(v)
}
if v, ok := val.(*time.Time); ok && v != nil {
m.SetSoraCooldownUntil(*v)
}
case "cooled_until":
if v, ok := val.(time.Time); ok {
m.SetCooledUntil(v)
}
if v, ok := val.(*time.Time); ok && v != nil {
m.SetCooledUntil(*v)
}
case "image_enabled":
if v, ok := val.(bool); ok {
m.SetImageEnabled(v)
}
case "video_enabled":
if v, ok := val.(bool); ok {
m.SetVideoEnabled(v)
}
case "image_concurrency":
if v, ok := val.(int); ok {
m.SetImageConcurrency(v)
}
case "video_concurrency":
if v, ok := val.(int); ok {
m.SetVideoConcurrency(v)
}
case "is_expired":
if v, ok := val.(bool); ok {
m.SetIsExpired(v)
}
}
}
}
func mapSoraAccount(acc *ent.SoraAccount) *service.SoraAccount {
if acc == nil {
return nil
}
return &service.SoraAccount{
AccountID: acc.AccountID,
AccessToken: derefString(acc.AccessToken),
SessionToken: derefString(acc.SessionToken),
RefreshToken: derefString(acc.RefreshToken),
ClientID: derefString(acc.ClientID),
Email: derefString(acc.Email),
Username: derefString(acc.Username),
Remark: derefString(acc.Remark),
UseCount: acc.UseCount,
PlanType: derefString(acc.PlanType),
PlanTitle: derefString(acc.PlanTitle),
SubscriptionEnd: acc.SubscriptionEnd,
SoraSupported: acc.SoraSupported,
SoraInviteCode: derefString(acc.SoraInviteCode),
SoraRedeemedCount: acc.SoraRedeemedCount,
SoraRemainingCount: acc.SoraRemainingCount,
SoraTotalCount: acc.SoraTotalCount,
SoraCooldownUntil: acc.SoraCooldownUntil,
CooledUntil: acc.CooledUntil,
ImageEnabled: acc.ImageEnabled,
VideoEnabled: acc.VideoEnabled,
ImageConcurrency: acc.ImageConcurrency,
VideoConcurrency: acc.VideoConcurrency,
IsExpired: acc.IsExpired,
CreatedAt: acc.CreatedAt,
UpdatedAt: acc.UpdatedAt,
}
}
func mapSoraUsageStat(stat *ent.SoraUsageStat) *service.SoraUsageStat {
if stat == nil {
return nil
}
return &service.SoraUsageStat{
AccountID: stat.AccountID,
ImageCount: stat.ImageCount,
VideoCount: stat.VideoCount,
ErrorCount: stat.ErrorCount,
LastErrorAt: stat.LastErrorAt,
TodayImageCount: stat.TodayImageCount,
TodayVideoCount: stat.TodayVideoCount,
TodayErrorCount: stat.TodayErrorCount,
TodayDate: stat.TodayDate,
ConsecutiveErrorCount: stat.ConsecutiveErrorCount,
CreatedAt: stat.CreatedAt,
UpdatedAt: stat.UpdatedAt,
}
}
func mapSoraCacheFile(file *ent.SoraCacheFile) *service.SoraCacheFile {
if file == nil {
return nil
}
return &service.SoraCacheFile{
ID: int64(file.ID),
TaskID: derefString(file.TaskID),
AccountID: file.AccountID,
UserID: file.UserID,
MediaType: file.MediaType,
OriginalURL: file.OriginalURL,
CachePath: file.CachePath,
CacheURL: file.CacheURL,
SizeBytes: file.SizeBytes,
CreatedAt: file.CreatedAt,
}
}
// SoraUsageStat
type soraUsageStatRepository struct {
client *ent.Client
sql sqlExecutor
}
func NewSoraUsageStatRepository(client *ent.Client, sqlDB *sql.DB) service.SoraUsageStatRepository {
return &soraUsageStatRepository{client: client, sql: sqlDB}
}
func (r *soraUsageStatRepository) RecordSuccess(ctx context.Context, accountID int64, isVideo bool) error {
if accountID <= 0 {
return nil
}
field := "image_count"
todayField := "today_image_count"
if isVideo {
field = "video_count"
todayField = "today_video_count"
}
today := time.Now().UTC().Truncate(24 * time.Hour)
query := "INSERT INTO sora_usage_stats (account_id, " + field + ", " + todayField + ", today_date, consecutive_error_count, created_at, updated_at) " +
"VALUES ($1, 1, 1, $2, 0, NOW(), NOW()) " +
"ON CONFLICT (account_id) DO UPDATE SET " +
field + " = sora_usage_stats." + field + " + 1, " +
todayField + " = CASE WHEN sora_usage_stats.today_date = $2 THEN sora_usage_stats." + todayField + " + 1 ELSE 1 END, " +
"today_date = $2, consecutive_error_count = 0, updated_at = NOW()"
_, err := r.sql.ExecContext(ctx, query, accountID, today)
return err
}
func (r *soraUsageStatRepository) RecordError(ctx context.Context, accountID int64) (int, error) {
if accountID <= 0 {
return 0, nil
}
today := time.Now().UTC().Truncate(24 * time.Hour)
query := "INSERT INTO sora_usage_stats (account_id, error_count, today_error_count, today_date, consecutive_error_count, last_error_at, created_at, updated_at) " +
"VALUES ($1, 1, 1, $2, 1, NOW(), NOW(), NOW()) " +
"ON CONFLICT (account_id) DO UPDATE SET " +
"error_count = sora_usage_stats.error_count + 1, " +
"today_error_count = CASE WHEN sora_usage_stats.today_date = $2 THEN sora_usage_stats.today_error_count + 1 ELSE 1 END, " +
"today_date = $2, consecutive_error_count = sora_usage_stats.consecutive_error_count + 1, last_error_at = NOW(), updated_at = NOW() " +
"RETURNING consecutive_error_count"
var consecutive int
err := scanSingleRow(ctx, r.sql, query, []any{accountID, today}, &consecutive)
if err != nil {
return 0, err
}
return consecutive, nil
}
func (r *soraUsageStatRepository) ResetConsecutiveErrors(ctx context.Context, accountID int64) error {
if accountID <= 0 {
return nil
}
err := r.client.SoraUsageStat.Update().Where(dbsorausagestat.AccountIDEQ(accountID)).
SetConsecutiveErrorCount(0).
Exec(ctx)
if ent.IsNotFound(err) {
return nil
}
return err
}
func (r *soraUsageStatRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraUsageStat, error) {
if accountID <= 0 {
return nil, nil
}
stat, err := r.client.SoraUsageStat.Query().Where(dbsorausagestat.AccountIDEQ(accountID)).Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return mapSoraUsageStat(stat), nil
}
func (r *soraUsageStatRepository) GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*service.SoraUsageStat, error) {
if len(accountIDs) == 0 {
return map[int64]*service.SoraUsageStat{}, nil
}
stats, err := r.client.SoraUsageStat.Query().Where(dbsorausagestat.AccountIDIn(accountIDs...)).All(ctx)
if err != nil {
return nil, err
}
result := make(map[int64]*service.SoraUsageStat, len(stats))
for _, stat := range stats {
if stat == nil {
continue
}
result[stat.AccountID] = mapSoraUsageStat(stat)
}
return result, nil
}
func (r *soraUsageStatRepository) List(ctx context.Context, params pagination.PaginationParams) ([]*service.SoraUsageStat, *pagination.PaginationResult, error) {
query := r.client.SoraUsageStat.Query()
total, err := query.Count(ctx)
if err != nil {
return nil, nil, err
}
stats, err := query.Order(ent.Desc(dbsorausagestat.FieldUpdatedAt)).
Limit(params.Limit()).
Offset(params.Offset()).
All(ctx)
if err != nil {
return nil, nil, err
}
result := make([]*service.SoraUsageStat, 0, len(stats))
for _, stat := range stats {
result = append(result, mapSoraUsageStat(stat))
}
return result, paginationResultFromTotal(int64(total), params), nil
}
// SoraTask
type soraTaskRepository struct {
client *ent.Client
}
func NewSoraTaskRepository(client *ent.Client) service.SoraTaskRepository {
return &soraTaskRepository{client: client}
}
func (r *soraTaskRepository) Create(ctx context.Context, task *service.SoraTask) error {
if task == nil {
return nil
}
builder := r.client.SoraTask.Create().
SetTaskID(task.TaskID).
SetAccountID(task.AccountID).
SetModel(task.Model).
SetPrompt(task.Prompt).
SetStatus(task.Status).
SetProgress(task.Progress).
SetRetryCount(task.RetryCount)
if task.ResultURLs != "" {
builder.SetResultUrls(task.ResultURLs)
}
if task.ErrorMessage != "" {
builder.SetErrorMessage(task.ErrorMessage)
}
if task.CreatedAt.IsZero() {
builder.SetCreatedAt(time.Now())
} else {
builder.SetCreatedAt(task.CreatedAt)
}
if task.CompletedAt != nil {
builder.SetCompletedAt(*task.CompletedAt)
}
return builder.Exec(ctx)
}
func (r *soraTaskRepository) UpdateStatus(ctx context.Context, taskID string, status string, progress float64, resultURLs string, errorMessage string, completedAt *time.Time) error {
if taskID == "" {
return nil
}
builder := r.client.SoraTask.Update().Where(dbsoratask.TaskIDEQ(taskID)).
SetStatus(status).
SetProgress(progress)
if resultURLs != "" {
builder.SetResultUrls(resultURLs)
}
if errorMessage != "" {
builder.SetErrorMessage(errorMessage)
}
if completedAt != nil {
builder.SetCompletedAt(*completedAt)
}
_, err := builder.Save(ctx)
if ent.IsNotFound(err) {
return nil
}
return err
}
// SoraCacheFile
type soraCacheFileRepository struct {
client *ent.Client
}
func NewSoraCacheFileRepository(client *ent.Client) service.SoraCacheFileRepository {
return &soraCacheFileRepository{client: client}
}
func (r *soraCacheFileRepository) Create(ctx context.Context, file *service.SoraCacheFile) error {
if file == nil {
return nil
}
builder := r.client.SoraCacheFile.Create().
SetAccountID(file.AccountID).
SetUserID(file.UserID).
SetMediaType(file.MediaType).
SetOriginalURL(file.OriginalURL).
SetCachePath(file.CachePath).
SetCacheURL(file.CacheURL).
SetSizeBytes(file.SizeBytes)
if file.TaskID != "" {
builder.SetTaskID(file.TaskID)
}
if file.CreatedAt.IsZero() {
builder.SetCreatedAt(time.Now())
} else {
builder.SetCreatedAt(file.CreatedAt)
}
return builder.Exec(ctx)
}
func (r *soraCacheFileRepository) ListOldest(ctx context.Context, limit int) ([]*service.SoraCacheFile, error) {
if limit <= 0 {
return []*service.SoraCacheFile{}, nil
}
records, err := r.client.SoraCacheFile.Query().
Order(dbsoracachefile.ByCreatedAt(entsql.OrderAsc())).
Limit(limit).
All(ctx)
if err != nil {
return nil, err
}
result := make([]*service.SoraCacheFile, 0, len(records))
for _, record := range records {
if record == nil {
continue
}
result = append(result, mapSoraCacheFile(record))
}
return result, nil
}
func (r *soraCacheFileRepository) DeleteByIDs(ctx context.Context, ids []int64) error {
if len(ids) == 0 {
return nil
}
_, err := r.client.SoraCacheFile.Delete().Where(dbsoracachefile.IDIn(ids...)).Exec(ctx)
return err
}

View File

@@ -64,6 +64,10 @@ var ProviderSet = wire.NewSet(
NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository,
NewUserAttributeValueRepository,
NewSoraAccountRepository,
NewSoraUsageStatRepository,
NewSoraTaskRepository,
NewSoraCacheFileRepository,
// Cache implementations
NewGatewayCache,