feat(sora): 新增 Sora 平台支持并修复高危安全和性能问题
新增功能: - 新增 Sora 账号管理和 OAuth 认证 - 新增 Sora 视频/图片生成 API 网关 - 新增 Sora 任务调度和缓存机制 - 新增 Sora 使用统计和计费支持 - 前端增加 Sora 平台配置界面 安全修复(代码审核): - [SEC-001] 限制媒体下载响应体大小(图片 20MB、视频 200MB),防止 DoS 攻击 - [SEC-002] 限制 SDK API 响应大小(1MB),防止内存耗尽 - [SEC-003] 修复 SSRF 风险,添加 URL 验证并强制使用代理配置 BUG 修复(代码审核): - [BUG-001] 修复 for 循环内 defer 累积导致的资源泄漏 - [BUG-002] 修复图片并发槽位获取失败时已持有锁未释放的永久泄漏 性能优化(代码审核): - [PERF-001] 添加 Sentinel Token 缓存(3 分钟有效期),减少 PoW 计算开销 技术细节: - 使用 io.LimitReader 限制所有外部输入的大小 - 添加 urlvalidator 验证防止 SSRF 攻击 - 使用 sync.Map 实现线程安全的包级缓存 - 优化并发槽位管理,添加 releaseAll 模式防止泄漏 影响范围: - 后端:新增 Sora 相关数据模型、服务、网关和管理接口 - 前端:新增 Sora 平台配置、账号管理和监控界面 - 配置:新增 Sora 相关配置项和环境变量 Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
498
backend/internal/repository/sora_repo.go
Normal file
498
backend/internal/repository/sora_repo.go
Normal file
@@ -0,0 +1,498 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
dbsoraaccount "github.com/Wei-Shaw/sub2api/ent/soraaccount"
|
||||
dbsoracachefile "github.com/Wei-Shaw/sub2api/ent/soracachefile"
|
||||
dbsoratask "github.com/Wei-Shaw/sub2api/ent/soratask"
|
||||
dbsorausagestat "github.com/Wei-Shaw/sub2api/ent/sorausagestat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
// SoraAccount
|
||||
|
||||
type soraAccountRepository struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
func NewSoraAccountRepository(client *ent.Client) service.SoraAccountRepository {
|
||||
return &soraAccountRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
|
||||
if accountID <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
acc, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDEQ(accountID)).Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return mapSoraAccount(acc), nil
|
||||
}
|
||||
|
||||
func (r *soraAccountRepository) GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*service.SoraAccount, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return map[int64]*service.SoraAccount{}, nil
|
||||
}
|
||||
records, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDIn(accountIDs...)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make(map[int64]*service.SoraAccount, len(records))
|
||||
for _, record := range records {
|
||||
if record == nil {
|
||||
continue
|
||||
}
|
||||
result[record.AccountID] = mapSoraAccount(record)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
|
||||
if accountID <= 0 {
|
||||
return errors.New("invalid account_id")
|
||||
}
|
||||
acc, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDEQ(accountID)).Only(ctx)
|
||||
if err != nil && !ent.IsNotFound(err) {
|
||||
return err
|
||||
}
|
||||
if acc == nil {
|
||||
builder := r.client.SoraAccount.Create().SetAccountID(accountID)
|
||||
applySoraAccountUpdates(builder.Mutation(), updates)
|
||||
return builder.Exec(ctx)
|
||||
}
|
||||
updater := r.client.SoraAccount.UpdateOneID(acc.ID)
|
||||
applySoraAccountUpdates(updater.Mutation(), updates)
|
||||
return updater.Exec(ctx)
|
||||
}
|
||||
|
||||
func applySoraAccountUpdates(m *ent.SoraAccountMutation, updates map[string]any) {
|
||||
if updates == nil {
|
||||
return
|
||||
}
|
||||
for key, val := range updates {
|
||||
switch key {
|
||||
case "access_token":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetAccessToken(v)
|
||||
}
|
||||
case "session_token":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetSessionToken(v)
|
||||
}
|
||||
case "refresh_token":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetRefreshToken(v)
|
||||
}
|
||||
case "client_id":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetClientID(v)
|
||||
}
|
||||
case "email":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetEmail(v)
|
||||
}
|
||||
case "username":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetUsername(v)
|
||||
}
|
||||
case "remark":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetRemark(v)
|
||||
}
|
||||
case "plan_type":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetPlanType(v)
|
||||
}
|
||||
case "plan_title":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetPlanTitle(v)
|
||||
}
|
||||
case "subscription_end":
|
||||
if v, ok := val.(time.Time); ok {
|
||||
m.SetSubscriptionEnd(v)
|
||||
}
|
||||
if v, ok := val.(*time.Time); ok && v != nil {
|
||||
m.SetSubscriptionEnd(*v)
|
||||
}
|
||||
case "sora_supported":
|
||||
if v, ok := val.(bool); ok {
|
||||
m.SetSoraSupported(v)
|
||||
}
|
||||
case "sora_invite_code":
|
||||
if v, ok := val.(string); ok {
|
||||
m.SetSoraInviteCode(v)
|
||||
}
|
||||
case "sora_redeemed_count":
|
||||
if v, ok := val.(int); ok {
|
||||
m.SetSoraRedeemedCount(v)
|
||||
}
|
||||
case "sora_remaining_count":
|
||||
if v, ok := val.(int); ok {
|
||||
m.SetSoraRemainingCount(v)
|
||||
}
|
||||
case "sora_total_count":
|
||||
if v, ok := val.(int); ok {
|
||||
m.SetSoraTotalCount(v)
|
||||
}
|
||||
case "sora_cooldown_until":
|
||||
if v, ok := val.(time.Time); ok {
|
||||
m.SetSoraCooldownUntil(v)
|
||||
}
|
||||
if v, ok := val.(*time.Time); ok && v != nil {
|
||||
m.SetSoraCooldownUntil(*v)
|
||||
}
|
||||
case "cooled_until":
|
||||
if v, ok := val.(time.Time); ok {
|
||||
m.SetCooledUntil(v)
|
||||
}
|
||||
if v, ok := val.(*time.Time); ok && v != nil {
|
||||
m.SetCooledUntil(*v)
|
||||
}
|
||||
case "image_enabled":
|
||||
if v, ok := val.(bool); ok {
|
||||
m.SetImageEnabled(v)
|
||||
}
|
||||
case "video_enabled":
|
||||
if v, ok := val.(bool); ok {
|
||||
m.SetVideoEnabled(v)
|
||||
}
|
||||
case "image_concurrency":
|
||||
if v, ok := val.(int); ok {
|
||||
m.SetImageConcurrency(v)
|
||||
}
|
||||
case "video_concurrency":
|
||||
if v, ok := val.(int); ok {
|
||||
m.SetVideoConcurrency(v)
|
||||
}
|
||||
case "is_expired":
|
||||
if v, ok := val.(bool); ok {
|
||||
m.SetIsExpired(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mapSoraAccount(acc *ent.SoraAccount) *service.SoraAccount {
|
||||
if acc == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.SoraAccount{
|
||||
AccountID: acc.AccountID,
|
||||
AccessToken: derefString(acc.AccessToken),
|
||||
SessionToken: derefString(acc.SessionToken),
|
||||
RefreshToken: derefString(acc.RefreshToken),
|
||||
ClientID: derefString(acc.ClientID),
|
||||
Email: derefString(acc.Email),
|
||||
Username: derefString(acc.Username),
|
||||
Remark: derefString(acc.Remark),
|
||||
UseCount: acc.UseCount,
|
||||
PlanType: derefString(acc.PlanType),
|
||||
PlanTitle: derefString(acc.PlanTitle),
|
||||
SubscriptionEnd: acc.SubscriptionEnd,
|
||||
SoraSupported: acc.SoraSupported,
|
||||
SoraInviteCode: derefString(acc.SoraInviteCode),
|
||||
SoraRedeemedCount: acc.SoraRedeemedCount,
|
||||
SoraRemainingCount: acc.SoraRemainingCount,
|
||||
SoraTotalCount: acc.SoraTotalCount,
|
||||
SoraCooldownUntil: acc.SoraCooldownUntil,
|
||||
CooledUntil: acc.CooledUntil,
|
||||
ImageEnabled: acc.ImageEnabled,
|
||||
VideoEnabled: acc.VideoEnabled,
|
||||
ImageConcurrency: acc.ImageConcurrency,
|
||||
VideoConcurrency: acc.VideoConcurrency,
|
||||
IsExpired: acc.IsExpired,
|
||||
CreatedAt: acc.CreatedAt,
|
||||
UpdatedAt: acc.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func mapSoraUsageStat(stat *ent.SoraUsageStat) *service.SoraUsageStat {
|
||||
if stat == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.SoraUsageStat{
|
||||
AccountID: stat.AccountID,
|
||||
ImageCount: stat.ImageCount,
|
||||
VideoCount: stat.VideoCount,
|
||||
ErrorCount: stat.ErrorCount,
|
||||
LastErrorAt: stat.LastErrorAt,
|
||||
TodayImageCount: stat.TodayImageCount,
|
||||
TodayVideoCount: stat.TodayVideoCount,
|
||||
TodayErrorCount: stat.TodayErrorCount,
|
||||
TodayDate: stat.TodayDate,
|
||||
ConsecutiveErrorCount: stat.ConsecutiveErrorCount,
|
||||
CreatedAt: stat.CreatedAt,
|
||||
UpdatedAt: stat.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func mapSoraCacheFile(file *ent.SoraCacheFile) *service.SoraCacheFile {
|
||||
if file == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.SoraCacheFile{
|
||||
ID: int64(file.ID),
|
||||
TaskID: derefString(file.TaskID),
|
||||
AccountID: file.AccountID,
|
||||
UserID: file.UserID,
|
||||
MediaType: file.MediaType,
|
||||
OriginalURL: file.OriginalURL,
|
||||
CachePath: file.CachePath,
|
||||
CacheURL: file.CacheURL,
|
||||
SizeBytes: file.SizeBytes,
|
||||
CreatedAt: file.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// SoraUsageStat
|
||||
|
||||
type soraUsageStatRepository struct {
|
||||
client *ent.Client
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
func NewSoraUsageStatRepository(client *ent.Client, sqlDB *sql.DB) service.SoraUsageStatRepository {
|
||||
return &soraUsageStatRepository{client: client, sql: sqlDB}
|
||||
}
|
||||
|
||||
func (r *soraUsageStatRepository) RecordSuccess(ctx context.Context, accountID int64, isVideo bool) error {
|
||||
if accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
field := "image_count"
|
||||
todayField := "today_image_count"
|
||||
if isVideo {
|
||||
field = "video_count"
|
||||
todayField = "today_video_count"
|
||||
}
|
||||
today := time.Now().UTC().Truncate(24 * time.Hour)
|
||||
query := "INSERT INTO sora_usage_stats (account_id, " + field + ", " + todayField + ", today_date, consecutive_error_count, created_at, updated_at) " +
|
||||
"VALUES ($1, 1, 1, $2, 0, NOW(), NOW()) " +
|
||||
"ON CONFLICT (account_id) DO UPDATE SET " +
|
||||
field + " = sora_usage_stats." + field + " + 1, " +
|
||||
todayField + " = CASE WHEN sora_usage_stats.today_date = $2 THEN sora_usage_stats." + todayField + " + 1 ELSE 1 END, " +
|
||||
"today_date = $2, consecutive_error_count = 0, updated_at = NOW()"
|
||||
_, err := r.sql.ExecContext(ctx, query, accountID, today)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *soraUsageStatRepository) RecordError(ctx context.Context, accountID int64) (int, error) {
|
||||
if accountID <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
today := time.Now().UTC().Truncate(24 * time.Hour)
|
||||
query := "INSERT INTO sora_usage_stats (account_id, error_count, today_error_count, today_date, consecutive_error_count, last_error_at, created_at, updated_at) " +
|
||||
"VALUES ($1, 1, 1, $2, 1, NOW(), NOW(), NOW()) " +
|
||||
"ON CONFLICT (account_id) DO UPDATE SET " +
|
||||
"error_count = sora_usage_stats.error_count + 1, " +
|
||||
"today_error_count = CASE WHEN sora_usage_stats.today_date = $2 THEN sora_usage_stats.today_error_count + 1 ELSE 1 END, " +
|
||||
"today_date = $2, consecutive_error_count = sora_usage_stats.consecutive_error_count + 1, last_error_at = NOW(), updated_at = NOW() " +
|
||||
"RETURNING consecutive_error_count"
|
||||
var consecutive int
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{accountID, today}, &consecutive)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return consecutive, nil
|
||||
}
|
||||
|
||||
func (r *soraUsageStatRepository) ResetConsecutiveErrors(ctx context.Context, accountID int64) error {
|
||||
if accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
err := r.client.SoraUsageStat.Update().Where(dbsorausagestat.AccountIDEQ(accountID)).
|
||||
SetConsecutiveErrorCount(0).
|
||||
Exec(ctx)
|
||||
if ent.IsNotFound(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *soraUsageStatRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraUsageStat, error) {
|
||||
if accountID <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
stat, err := r.client.SoraUsageStat.Query().Where(dbsorausagestat.AccountIDEQ(accountID)).Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return mapSoraUsageStat(stat), nil
|
||||
}
|
||||
|
||||
func (r *soraUsageStatRepository) GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*service.SoraUsageStat, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return map[int64]*service.SoraUsageStat{}, nil
|
||||
}
|
||||
stats, err := r.client.SoraUsageStat.Query().Where(dbsorausagestat.AccountIDIn(accountIDs...)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make(map[int64]*service.SoraUsageStat, len(stats))
|
||||
for _, stat := range stats {
|
||||
if stat == nil {
|
||||
continue
|
||||
}
|
||||
result[stat.AccountID] = mapSoraUsageStat(stat)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *soraUsageStatRepository) List(ctx context.Context, params pagination.PaginationParams) ([]*service.SoraUsageStat, *pagination.PaginationResult, error) {
|
||||
query := r.client.SoraUsageStat.Query()
|
||||
total, err := query.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
stats, err := query.Order(ent.Desc(dbsorausagestat.FieldUpdatedAt)).
|
||||
Limit(params.Limit()).
|
||||
Offset(params.Offset()).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
result := make([]*service.SoraUsageStat, 0, len(stats))
|
||||
for _, stat := range stats {
|
||||
result = append(result, mapSoraUsageStat(stat))
|
||||
}
|
||||
return result, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SoraTask
|
||||
|
||||
type soraTaskRepository struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
func NewSoraTaskRepository(client *ent.Client) service.SoraTaskRepository {
|
||||
return &soraTaskRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *soraTaskRepository) Create(ctx context.Context, task *service.SoraTask) error {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
builder := r.client.SoraTask.Create().
|
||||
SetTaskID(task.TaskID).
|
||||
SetAccountID(task.AccountID).
|
||||
SetModel(task.Model).
|
||||
SetPrompt(task.Prompt).
|
||||
SetStatus(task.Status).
|
||||
SetProgress(task.Progress).
|
||||
SetRetryCount(task.RetryCount)
|
||||
if task.ResultURLs != "" {
|
||||
builder.SetResultUrls(task.ResultURLs)
|
||||
}
|
||||
if task.ErrorMessage != "" {
|
||||
builder.SetErrorMessage(task.ErrorMessage)
|
||||
}
|
||||
if task.CreatedAt.IsZero() {
|
||||
builder.SetCreatedAt(time.Now())
|
||||
} else {
|
||||
builder.SetCreatedAt(task.CreatedAt)
|
||||
}
|
||||
if task.CompletedAt != nil {
|
||||
builder.SetCompletedAt(*task.CompletedAt)
|
||||
}
|
||||
return builder.Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *soraTaskRepository) UpdateStatus(ctx context.Context, taskID string, status string, progress float64, resultURLs string, errorMessage string, completedAt *time.Time) error {
|
||||
if taskID == "" {
|
||||
return nil
|
||||
}
|
||||
builder := r.client.SoraTask.Update().Where(dbsoratask.TaskIDEQ(taskID)).
|
||||
SetStatus(status).
|
||||
SetProgress(progress)
|
||||
if resultURLs != "" {
|
||||
builder.SetResultUrls(resultURLs)
|
||||
}
|
||||
if errorMessage != "" {
|
||||
builder.SetErrorMessage(errorMessage)
|
||||
}
|
||||
if completedAt != nil {
|
||||
builder.SetCompletedAt(*completedAt)
|
||||
}
|
||||
_, err := builder.Save(ctx)
|
||||
if ent.IsNotFound(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// SoraCacheFile
|
||||
|
||||
type soraCacheFileRepository struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
func NewSoraCacheFileRepository(client *ent.Client) service.SoraCacheFileRepository {
|
||||
return &soraCacheFileRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *soraCacheFileRepository) Create(ctx context.Context, file *service.SoraCacheFile) error {
|
||||
if file == nil {
|
||||
return nil
|
||||
}
|
||||
builder := r.client.SoraCacheFile.Create().
|
||||
SetAccountID(file.AccountID).
|
||||
SetUserID(file.UserID).
|
||||
SetMediaType(file.MediaType).
|
||||
SetOriginalURL(file.OriginalURL).
|
||||
SetCachePath(file.CachePath).
|
||||
SetCacheURL(file.CacheURL).
|
||||
SetSizeBytes(file.SizeBytes)
|
||||
if file.TaskID != "" {
|
||||
builder.SetTaskID(file.TaskID)
|
||||
}
|
||||
if file.CreatedAt.IsZero() {
|
||||
builder.SetCreatedAt(time.Now())
|
||||
} else {
|
||||
builder.SetCreatedAt(file.CreatedAt)
|
||||
}
|
||||
return builder.Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *soraCacheFileRepository) ListOldest(ctx context.Context, limit int) ([]*service.SoraCacheFile, error) {
|
||||
if limit <= 0 {
|
||||
return []*service.SoraCacheFile{}, nil
|
||||
}
|
||||
records, err := r.client.SoraCacheFile.Query().
|
||||
Order(dbsoracachefile.ByCreatedAt(entsql.OrderAsc())).
|
||||
Limit(limit).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]*service.SoraCacheFile, 0, len(records))
|
||||
for _, record := range records {
|
||||
if record == nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, mapSoraCacheFile(record))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *soraCacheFileRepository) DeleteByIDs(ctx context.Context, ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := r.client.SoraCacheFile.Delete().Where(dbsoracachefile.IDIn(ids...)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
@@ -64,6 +64,10 @@ var ProviderSet = wire.NewSet(
|
||||
NewUserSubscriptionRepository,
|
||||
NewUserAttributeDefinitionRepository,
|
||||
NewUserAttributeValueRepository,
|
||||
NewSoraAccountRepository,
|
||||
NewSoraUsageStatRepository,
|
||||
NewSoraTaskRepository,
|
||||
NewSoraCacheFileRepository,
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
|
||||
Reference in New Issue
Block a user