新增功能: - 新增 Sora 账号管理和 OAuth 认证 - 新增 Sora 视频/图片生成 API 网关 - 新增 Sora 任务调度和缓存机制 - 新增 Sora 使用统计和计费支持 - 前端增加 Sora 平台配置界面 安全修复(代码审核): - [SEC-001] 限制媒体下载响应体大小(图片 20MB、视频 200MB),防止 DoS 攻击 - [SEC-002] 限制 SDK API 响应大小(1MB),防止内存耗尽 - [SEC-003] 修复 SSRF 风险,添加 URL 验证并强制使用代理配置 BUG 修复(代码审核): - [BUG-001] 修复 for 循环内 defer 累积导致的资源泄漏 - [BUG-002] 修复图片并发槽位获取失败时已持有锁未释放的永久泄漏 性能优化(代码审核): - [PERF-001] 添加 Sentinel Token 缓存(3 分钟有效期),减少 PoW 计算开销 技术细节: - 使用 io.LimitReader 限制所有外部输入的大小 - 添加 urlvalidator 验证防止 SSRF 攻击 - 使用 sync.Map 实现线程安全的包级缓存 - 优化并发槽位管理,添加 releaseAll 模式防止泄漏 影响范围: - 后端:新增 Sora 相关数据模型、服务、网关和管理接口 - 前端:新增 Sora 平台配置、账号管理和监控界面 - 配置:新增 Sora 相关配置项和环境变量 Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
499 lines
14 KiB
Go
499 lines
14 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/ent"
|
|
dbsoraaccount "github.com/Wei-Shaw/sub2api/ent/soraaccount"
|
|
dbsoracachefile "github.com/Wei-Shaw/sub2api/ent/soracachefile"
|
|
dbsoratask "github.com/Wei-Shaw/sub2api/ent/soratask"
|
|
dbsorausagestat "github.com/Wei-Shaw/sub2api/ent/sorausagestat"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
|
|
entsql "entgo.io/ent/dialect/sql"
|
|
)
|
|
|
|
// SoraAccount
|
|
|
|
type soraAccountRepository struct {
|
|
client *ent.Client
|
|
}
|
|
|
|
func NewSoraAccountRepository(client *ent.Client) service.SoraAccountRepository {
|
|
return &soraAccountRepository{client: client}
|
|
}
|
|
|
|
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
|
|
if accountID <= 0 {
|
|
return nil, nil
|
|
}
|
|
acc, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDEQ(accountID)).Only(ctx)
|
|
if err != nil {
|
|
if ent.IsNotFound(err) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return mapSoraAccount(acc), nil
|
|
}
|
|
|
|
func (r *soraAccountRepository) GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*service.SoraAccount, error) {
|
|
if len(accountIDs) == 0 {
|
|
return map[int64]*service.SoraAccount{}, nil
|
|
}
|
|
records, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDIn(accountIDs...)).All(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result := make(map[int64]*service.SoraAccount, len(records))
|
|
for _, record := range records {
|
|
if record == nil {
|
|
continue
|
|
}
|
|
result[record.AccountID] = mapSoraAccount(record)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
|
|
if accountID <= 0 {
|
|
return errors.New("invalid account_id")
|
|
}
|
|
acc, err := r.client.SoraAccount.Query().Where(dbsoraaccount.AccountIDEQ(accountID)).Only(ctx)
|
|
if err != nil && !ent.IsNotFound(err) {
|
|
return err
|
|
}
|
|
if acc == nil {
|
|
builder := r.client.SoraAccount.Create().SetAccountID(accountID)
|
|
applySoraAccountUpdates(builder.Mutation(), updates)
|
|
return builder.Exec(ctx)
|
|
}
|
|
updater := r.client.SoraAccount.UpdateOneID(acc.ID)
|
|
applySoraAccountUpdates(updater.Mutation(), updates)
|
|
return updater.Exec(ctx)
|
|
}
|
|
|
|
func applySoraAccountUpdates(m *ent.SoraAccountMutation, updates map[string]any) {
|
|
if updates == nil {
|
|
return
|
|
}
|
|
for key, val := range updates {
|
|
switch key {
|
|
case "access_token":
|
|
if v, ok := val.(string); ok {
|
|
m.SetAccessToken(v)
|
|
}
|
|
case "session_token":
|
|
if v, ok := val.(string); ok {
|
|
m.SetSessionToken(v)
|
|
}
|
|
case "refresh_token":
|
|
if v, ok := val.(string); ok {
|
|
m.SetRefreshToken(v)
|
|
}
|
|
case "client_id":
|
|
if v, ok := val.(string); ok {
|
|
m.SetClientID(v)
|
|
}
|
|
case "email":
|
|
if v, ok := val.(string); ok {
|
|
m.SetEmail(v)
|
|
}
|
|
case "username":
|
|
if v, ok := val.(string); ok {
|
|
m.SetUsername(v)
|
|
}
|
|
case "remark":
|
|
if v, ok := val.(string); ok {
|
|
m.SetRemark(v)
|
|
}
|
|
case "plan_type":
|
|
if v, ok := val.(string); ok {
|
|
m.SetPlanType(v)
|
|
}
|
|
case "plan_title":
|
|
if v, ok := val.(string); ok {
|
|
m.SetPlanTitle(v)
|
|
}
|
|
case "subscription_end":
|
|
if v, ok := val.(time.Time); ok {
|
|
m.SetSubscriptionEnd(v)
|
|
}
|
|
if v, ok := val.(*time.Time); ok && v != nil {
|
|
m.SetSubscriptionEnd(*v)
|
|
}
|
|
case "sora_supported":
|
|
if v, ok := val.(bool); ok {
|
|
m.SetSoraSupported(v)
|
|
}
|
|
case "sora_invite_code":
|
|
if v, ok := val.(string); ok {
|
|
m.SetSoraInviteCode(v)
|
|
}
|
|
case "sora_redeemed_count":
|
|
if v, ok := val.(int); ok {
|
|
m.SetSoraRedeemedCount(v)
|
|
}
|
|
case "sora_remaining_count":
|
|
if v, ok := val.(int); ok {
|
|
m.SetSoraRemainingCount(v)
|
|
}
|
|
case "sora_total_count":
|
|
if v, ok := val.(int); ok {
|
|
m.SetSoraTotalCount(v)
|
|
}
|
|
case "sora_cooldown_until":
|
|
if v, ok := val.(time.Time); ok {
|
|
m.SetSoraCooldownUntil(v)
|
|
}
|
|
if v, ok := val.(*time.Time); ok && v != nil {
|
|
m.SetSoraCooldownUntil(*v)
|
|
}
|
|
case "cooled_until":
|
|
if v, ok := val.(time.Time); ok {
|
|
m.SetCooledUntil(v)
|
|
}
|
|
if v, ok := val.(*time.Time); ok && v != nil {
|
|
m.SetCooledUntil(*v)
|
|
}
|
|
case "image_enabled":
|
|
if v, ok := val.(bool); ok {
|
|
m.SetImageEnabled(v)
|
|
}
|
|
case "video_enabled":
|
|
if v, ok := val.(bool); ok {
|
|
m.SetVideoEnabled(v)
|
|
}
|
|
case "image_concurrency":
|
|
if v, ok := val.(int); ok {
|
|
m.SetImageConcurrency(v)
|
|
}
|
|
case "video_concurrency":
|
|
if v, ok := val.(int); ok {
|
|
m.SetVideoConcurrency(v)
|
|
}
|
|
case "is_expired":
|
|
if v, ok := val.(bool); ok {
|
|
m.SetIsExpired(v)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func mapSoraAccount(acc *ent.SoraAccount) *service.SoraAccount {
|
|
if acc == nil {
|
|
return nil
|
|
}
|
|
return &service.SoraAccount{
|
|
AccountID: acc.AccountID,
|
|
AccessToken: derefString(acc.AccessToken),
|
|
SessionToken: derefString(acc.SessionToken),
|
|
RefreshToken: derefString(acc.RefreshToken),
|
|
ClientID: derefString(acc.ClientID),
|
|
Email: derefString(acc.Email),
|
|
Username: derefString(acc.Username),
|
|
Remark: derefString(acc.Remark),
|
|
UseCount: acc.UseCount,
|
|
PlanType: derefString(acc.PlanType),
|
|
PlanTitle: derefString(acc.PlanTitle),
|
|
SubscriptionEnd: acc.SubscriptionEnd,
|
|
SoraSupported: acc.SoraSupported,
|
|
SoraInviteCode: derefString(acc.SoraInviteCode),
|
|
SoraRedeemedCount: acc.SoraRedeemedCount,
|
|
SoraRemainingCount: acc.SoraRemainingCount,
|
|
SoraTotalCount: acc.SoraTotalCount,
|
|
SoraCooldownUntil: acc.SoraCooldownUntil,
|
|
CooledUntil: acc.CooledUntil,
|
|
ImageEnabled: acc.ImageEnabled,
|
|
VideoEnabled: acc.VideoEnabled,
|
|
ImageConcurrency: acc.ImageConcurrency,
|
|
VideoConcurrency: acc.VideoConcurrency,
|
|
IsExpired: acc.IsExpired,
|
|
CreatedAt: acc.CreatedAt,
|
|
UpdatedAt: acc.UpdatedAt,
|
|
}
|
|
}
|
|
|
|
func mapSoraUsageStat(stat *ent.SoraUsageStat) *service.SoraUsageStat {
|
|
if stat == nil {
|
|
return nil
|
|
}
|
|
return &service.SoraUsageStat{
|
|
AccountID: stat.AccountID,
|
|
ImageCount: stat.ImageCount,
|
|
VideoCount: stat.VideoCount,
|
|
ErrorCount: stat.ErrorCount,
|
|
LastErrorAt: stat.LastErrorAt,
|
|
TodayImageCount: stat.TodayImageCount,
|
|
TodayVideoCount: stat.TodayVideoCount,
|
|
TodayErrorCount: stat.TodayErrorCount,
|
|
TodayDate: stat.TodayDate,
|
|
ConsecutiveErrorCount: stat.ConsecutiveErrorCount,
|
|
CreatedAt: stat.CreatedAt,
|
|
UpdatedAt: stat.UpdatedAt,
|
|
}
|
|
}
|
|
|
|
func mapSoraCacheFile(file *ent.SoraCacheFile) *service.SoraCacheFile {
|
|
if file == nil {
|
|
return nil
|
|
}
|
|
return &service.SoraCacheFile{
|
|
ID: int64(file.ID),
|
|
TaskID: derefString(file.TaskID),
|
|
AccountID: file.AccountID,
|
|
UserID: file.UserID,
|
|
MediaType: file.MediaType,
|
|
OriginalURL: file.OriginalURL,
|
|
CachePath: file.CachePath,
|
|
CacheURL: file.CacheURL,
|
|
SizeBytes: file.SizeBytes,
|
|
CreatedAt: file.CreatedAt,
|
|
}
|
|
}
|
|
|
|
// SoraUsageStat
|
|
|
|
type soraUsageStatRepository struct {
|
|
client *ent.Client
|
|
sql sqlExecutor
|
|
}
|
|
|
|
func NewSoraUsageStatRepository(client *ent.Client, sqlDB *sql.DB) service.SoraUsageStatRepository {
|
|
return &soraUsageStatRepository{client: client, sql: sqlDB}
|
|
}
|
|
|
|
func (r *soraUsageStatRepository) RecordSuccess(ctx context.Context, accountID int64, isVideo bool) error {
|
|
if accountID <= 0 {
|
|
return nil
|
|
}
|
|
field := "image_count"
|
|
todayField := "today_image_count"
|
|
if isVideo {
|
|
field = "video_count"
|
|
todayField = "today_video_count"
|
|
}
|
|
today := time.Now().UTC().Truncate(24 * time.Hour)
|
|
query := "INSERT INTO sora_usage_stats (account_id, " + field + ", " + todayField + ", today_date, consecutive_error_count, created_at, updated_at) " +
|
|
"VALUES ($1, 1, 1, $2, 0, NOW(), NOW()) " +
|
|
"ON CONFLICT (account_id) DO UPDATE SET " +
|
|
field + " = sora_usage_stats." + field + " + 1, " +
|
|
todayField + " = CASE WHEN sora_usage_stats.today_date = $2 THEN sora_usage_stats." + todayField + " + 1 ELSE 1 END, " +
|
|
"today_date = $2, consecutive_error_count = 0, updated_at = NOW()"
|
|
_, err := r.sql.ExecContext(ctx, query, accountID, today)
|
|
return err
|
|
}
|
|
|
|
func (r *soraUsageStatRepository) RecordError(ctx context.Context, accountID int64) (int, error) {
|
|
if accountID <= 0 {
|
|
return 0, nil
|
|
}
|
|
today := time.Now().UTC().Truncate(24 * time.Hour)
|
|
query := "INSERT INTO sora_usage_stats (account_id, error_count, today_error_count, today_date, consecutive_error_count, last_error_at, created_at, updated_at) " +
|
|
"VALUES ($1, 1, 1, $2, 1, NOW(), NOW(), NOW()) " +
|
|
"ON CONFLICT (account_id) DO UPDATE SET " +
|
|
"error_count = sora_usage_stats.error_count + 1, " +
|
|
"today_error_count = CASE WHEN sora_usage_stats.today_date = $2 THEN sora_usage_stats.today_error_count + 1 ELSE 1 END, " +
|
|
"today_date = $2, consecutive_error_count = sora_usage_stats.consecutive_error_count + 1, last_error_at = NOW(), updated_at = NOW() " +
|
|
"RETURNING consecutive_error_count"
|
|
var consecutive int
|
|
err := scanSingleRow(ctx, r.sql, query, []any{accountID, today}, &consecutive)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return consecutive, nil
|
|
}
|
|
|
|
func (r *soraUsageStatRepository) ResetConsecutiveErrors(ctx context.Context, accountID int64) error {
|
|
if accountID <= 0 {
|
|
return nil
|
|
}
|
|
err := r.client.SoraUsageStat.Update().Where(dbsorausagestat.AccountIDEQ(accountID)).
|
|
SetConsecutiveErrorCount(0).
|
|
Exec(ctx)
|
|
if ent.IsNotFound(err) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (r *soraUsageStatRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraUsageStat, error) {
|
|
if accountID <= 0 {
|
|
return nil, nil
|
|
}
|
|
stat, err := r.client.SoraUsageStat.Query().Where(dbsorausagestat.AccountIDEQ(accountID)).Only(ctx)
|
|
if err != nil {
|
|
if ent.IsNotFound(err) {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return mapSoraUsageStat(stat), nil
|
|
}
|
|
|
|
func (r *soraUsageStatRepository) GetByAccountIDs(ctx context.Context, accountIDs []int64) (map[int64]*service.SoraUsageStat, error) {
|
|
if len(accountIDs) == 0 {
|
|
return map[int64]*service.SoraUsageStat{}, nil
|
|
}
|
|
stats, err := r.client.SoraUsageStat.Query().Where(dbsorausagestat.AccountIDIn(accountIDs...)).All(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result := make(map[int64]*service.SoraUsageStat, len(stats))
|
|
for _, stat := range stats {
|
|
if stat == nil {
|
|
continue
|
|
}
|
|
result[stat.AccountID] = mapSoraUsageStat(stat)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (r *soraUsageStatRepository) List(ctx context.Context, params pagination.PaginationParams) ([]*service.SoraUsageStat, *pagination.PaginationResult, error) {
|
|
query := r.client.SoraUsageStat.Query()
|
|
total, err := query.Count(ctx)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
stats, err := query.Order(ent.Desc(dbsorausagestat.FieldUpdatedAt)).
|
|
Limit(params.Limit()).
|
|
Offset(params.Offset()).
|
|
All(ctx)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
result := make([]*service.SoraUsageStat, 0, len(stats))
|
|
for _, stat := range stats {
|
|
result = append(result, mapSoraUsageStat(stat))
|
|
}
|
|
return result, paginationResultFromTotal(int64(total), params), nil
|
|
}
|
|
|
|
// SoraTask
|
|
|
|
type soraTaskRepository struct {
|
|
client *ent.Client
|
|
}
|
|
|
|
func NewSoraTaskRepository(client *ent.Client) service.SoraTaskRepository {
|
|
return &soraTaskRepository{client: client}
|
|
}
|
|
|
|
func (r *soraTaskRepository) Create(ctx context.Context, task *service.SoraTask) error {
|
|
if task == nil {
|
|
return nil
|
|
}
|
|
builder := r.client.SoraTask.Create().
|
|
SetTaskID(task.TaskID).
|
|
SetAccountID(task.AccountID).
|
|
SetModel(task.Model).
|
|
SetPrompt(task.Prompt).
|
|
SetStatus(task.Status).
|
|
SetProgress(task.Progress).
|
|
SetRetryCount(task.RetryCount)
|
|
if task.ResultURLs != "" {
|
|
builder.SetResultUrls(task.ResultURLs)
|
|
}
|
|
if task.ErrorMessage != "" {
|
|
builder.SetErrorMessage(task.ErrorMessage)
|
|
}
|
|
if task.CreatedAt.IsZero() {
|
|
builder.SetCreatedAt(time.Now())
|
|
} else {
|
|
builder.SetCreatedAt(task.CreatedAt)
|
|
}
|
|
if task.CompletedAt != nil {
|
|
builder.SetCompletedAt(*task.CompletedAt)
|
|
}
|
|
return builder.Exec(ctx)
|
|
}
|
|
|
|
func (r *soraTaskRepository) UpdateStatus(ctx context.Context, taskID string, status string, progress float64, resultURLs string, errorMessage string, completedAt *time.Time) error {
|
|
if taskID == "" {
|
|
return nil
|
|
}
|
|
builder := r.client.SoraTask.Update().Where(dbsoratask.TaskIDEQ(taskID)).
|
|
SetStatus(status).
|
|
SetProgress(progress)
|
|
if resultURLs != "" {
|
|
builder.SetResultUrls(resultURLs)
|
|
}
|
|
if errorMessage != "" {
|
|
builder.SetErrorMessage(errorMessage)
|
|
}
|
|
if completedAt != nil {
|
|
builder.SetCompletedAt(*completedAt)
|
|
}
|
|
_, err := builder.Save(ctx)
|
|
if ent.IsNotFound(err) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
// SoraCacheFile
|
|
|
|
type soraCacheFileRepository struct {
|
|
client *ent.Client
|
|
}
|
|
|
|
func NewSoraCacheFileRepository(client *ent.Client) service.SoraCacheFileRepository {
|
|
return &soraCacheFileRepository{client: client}
|
|
}
|
|
|
|
func (r *soraCacheFileRepository) Create(ctx context.Context, file *service.SoraCacheFile) error {
|
|
if file == nil {
|
|
return nil
|
|
}
|
|
builder := r.client.SoraCacheFile.Create().
|
|
SetAccountID(file.AccountID).
|
|
SetUserID(file.UserID).
|
|
SetMediaType(file.MediaType).
|
|
SetOriginalURL(file.OriginalURL).
|
|
SetCachePath(file.CachePath).
|
|
SetCacheURL(file.CacheURL).
|
|
SetSizeBytes(file.SizeBytes)
|
|
if file.TaskID != "" {
|
|
builder.SetTaskID(file.TaskID)
|
|
}
|
|
if file.CreatedAt.IsZero() {
|
|
builder.SetCreatedAt(time.Now())
|
|
} else {
|
|
builder.SetCreatedAt(file.CreatedAt)
|
|
}
|
|
return builder.Exec(ctx)
|
|
}
|
|
|
|
func (r *soraCacheFileRepository) ListOldest(ctx context.Context, limit int) ([]*service.SoraCacheFile, error) {
|
|
if limit <= 0 {
|
|
return []*service.SoraCacheFile{}, nil
|
|
}
|
|
records, err := r.client.SoraCacheFile.Query().
|
|
Order(dbsoracachefile.ByCreatedAt(entsql.OrderAsc())).
|
|
Limit(limit).
|
|
All(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result := make([]*service.SoraCacheFile, 0, len(records))
|
|
for _, record := range records {
|
|
if record == nil {
|
|
continue
|
|
}
|
|
result = append(result, mapSoraCacheFile(record))
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (r *soraCacheFileRepository) DeleteByIDs(ctx context.Context, ids []int64) error {
|
|
if len(ids) == 0 {
|
|
return nil
|
|
}
|
|
_, err := r.client.SoraCacheFile.Delete().Where(dbsoracachefile.IDIn(ids...)).Exec(ctx)
|
|
return err
|
|
}
|