revert: completely remove all Sora functionality
This commit is contained in:
@@ -1692,20 +1692,13 @@ func itoa(v int) string {
|
||||
}
|
||||
|
||||
// FindByExtraField 根据 extra 字段中的键值对查找账号。
|
||||
// 该方法限定 platform='sora',避免误查询其他平台的账号。
|
||||
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
|
||||
//
|
||||
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
|
||||
//
|
||||
// FindByExtraField finds accounts by key-value pairs in the extra field.
|
||||
// Limited to platform='sora' to avoid querying accounts from other platforms.
|
||||
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
|
||||
//
|
||||
// Use case: Finding Sora accounts linked via linked_openai_account_id.
|
||||
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||
accounts, err := r.client.Account.Query().
|
||||
Where(
|
||||
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
|
||||
dbaccount.DeletedAtIsNil(),
|
||||
func(s *entsql.Selector) {
|
||||
path := sqljson.Path(key)
|
||||
|
||||
@@ -155,10 +155,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldImagePrice1k,
|
||||
group.FieldImagePrice2k,
|
||||
group.FieldImagePrice4k,
|
||||
group.FieldSoraImagePrice360,
|
||||
group.FieldSoraImagePrice540,
|
||||
group.FieldSoraVideoPricePerRequest,
|
||||
group.FieldSoraVideoPricePerRequestHd,
|
||||
group.FieldClaudeCodeOnly,
|
||||
group.FieldFallbackGroupID,
|
||||
group.FieldFallbackGroupIDOnInvalidRequest,
|
||||
@@ -617,8 +613,6 @@ func userEntityToService(u *dbent.User) *service.User {
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
|
||||
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
|
||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||
TotpEnabled: u.TotpEnabled,
|
||||
TotpEnabledAt: u.TotpEnabledAt,
|
||||
@@ -647,11 +641,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
ImagePrice1K: g.ImagePrice1k,
|
||||
ImagePrice2K: g.ImagePrice2k,
|
||||
ImagePrice4K: g.ImagePrice4k,
|
||||
SoraImagePrice360: g.SoraImagePrice360,
|
||||
SoraImagePrice540: g.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
|
||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
|
||||
@@ -49,17 +49,12 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
|
||||
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
|
||||
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
|
||||
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
||||
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
@@ -122,15 +117,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
|
||||
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
|
||||
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
|
||||
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
|
||||
@@ -158,30 +158,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
|
||||
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
|
||||
var seenClientIDs []string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
seenClientIDs = append(seenClientIDs, clientID)
|
||||
if clientID == openai.SoraClientID {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
|
||||
require.NoError(s.T(), err, "RefreshTokenWithClientID")
|
||||
require.Equal(s.T(), "at-sora", resp.AccessToken)
|
||||
require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
|
||||
const customClientID = "custom-client-id"
|
||||
var seenClientIDs []string
|
||||
@@ -276,7 +252,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
|
||||
wantClientID := openai.SoraClientID
|
||||
wantClientID := "custom-exchange-client-id"
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// soraAccountRepository 实现 service.SoraAccountRepository 接口。
|
||||
// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
|
||||
//
|
||||
// 设计说明:
|
||||
// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
|
||||
// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
|
||||
// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
|
||||
type soraAccountRepository struct {
|
||||
sql *sql.DB
|
||||
}
|
||||
|
||||
// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
|
||||
func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository {
|
||||
return &soraAccountRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
// Upsert 创建或更新 Sora 账号扩展信息
|
||||
// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
|
||||
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
|
||||
accessToken, accessOK := updates["access_token"].(string)
|
||||
refreshToken, refreshOK := updates["refresh_token"].(string)
|
||||
sessionToken, sessionOK := updates["session_token"].(string)
|
||||
|
||||
if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" {
|
||||
if !sessionOK {
|
||||
return errors.New("缺少 access_token/refresh_token,且未提供可更新字段")
|
||||
}
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_accounts
|
||||
SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
|
||||
updated_at = NOW()
|
||||
WHERE account_id = $1
|
||||
`, accountID, sessionToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rows == 0 {
|
||||
return errors.New("sora_accounts 记录不存在,无法仅更新 session_token")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, NOW(), NOW())
|
||||
ON CONFLICT (account_id) DO UPDATE SET
|
||||
access_token = EXCLUDED.access_token,
|
||||
refresh_token = EXCLUDED.refresh_token,
|
||||
session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
|
||||
updated_at = NOW()
|
||||
`, accountID, accessToken, refreshToken, sessionToken)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
|
||||
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
|
||||
FROM sora_accounts
|
||||
WHERE account_id = $1
|
||||
`, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
return nil, nil // 记录不存在
|
||||
}
|
||||
|
||||
var sa service.SoraAccount
|
||||
if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sa, nil
|
||||
}
|
||||
|
||||
// Delete 删除 Sora 账号扩展信息
|
||||
func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
DELETE FROM sora_accounts WHERE account_id = $1
|
||||
`, accountID)
|
||||
return err
|
||||
}
|
||||
@@ -1,419 +0,0 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
|
||||
// 使用原生 SQL 操作 sora_generations 表。
|
||||
type soraGenerationRepository struct {
|
||||
sql *sql.DB
|
||||
}
|
||||
|
||||
// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
|
||||
func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
|
||||
return &soraGenerationRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
|
||||
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
|
||||
|
||||
err := r.sql.QueryRowContext(ctx, `
|
||||
INSERT INTO sora_generations (
|
||||
user_id, api_key_id, model, prompt, media_type,
|
||||
status, media_url, media_urls, file_size_bytes,
|
||||
storage_type, s3_object_keys, upstream_task_id, error_message
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
RETURNING id, created_at
|
||||
`,
|
||||
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
|
||||
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
|
||||
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
|
||||
).Scan(&gen.ID, &gen.CreatedAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
|
||||
func (r *soraGenerationRepository) CreatePendingWithLimit(
|
||||
ctx context.Context,
|
||||
gen *service.SoraGeneration,
|
||||
activeStatuses []string,
|
||||
maxActive int64,
|
||||
) error {
|
||||
if gen == nil {
|
||||
return fmt.Errorf("generation is nil")
|
||||
}
|
||||
if maxActive <= 0 {
|
||||
return r.Create(ctx, gen)
|
||||
}
|
||||
if len(activeStatuses) == 0 {
|
||||
activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
|
||||
}
|
||||
|
||||
tx, err := r.sql.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
|
||||
if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
placeholders := make([]string, len(activeStatuses))
|
||||
args := make([]any, 0, 1+len(activeStatuses))
|
||||
args = append(args, gen.UserID)
|
||||
for i, s := range activeStatuses {
|
||||
placeholders[i] = fmt.Sprintf("$%d", i+2)
|
||||
args = append(args, s)
|
||||
}
|
||||
countQuery := fmt.Sprintf(
|
||||
`SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
|
||||
strings.Join(placeholders, ","),
|
||||
)
|
||||
var activeCount int64
|
||||
if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
|
||||
return err
|
||||
}
|
||||
if activeCount >= maxActive {
|
||||
return service.ErrSoraGenerationConcurrencyLimit
|
||||
}
|
||||
|
||||
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
|
||||
if err := tx.QueryRowContext(ctx, `
|
||||
INSERT INTO sora_generations (
|
||||
user_id, api_key_id, model, prompt, media_type,
|
||||
status, media_url, media_urls, file_size_bytes,
|
||||
storage_type, s3_object_keys, upstream_task_id, error_message
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
RETURNING id, created_at
|
||||
`,
|
||||
gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
|
||||
gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
|
||||
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
|
||||
).Scan(&gen.ID, &gen.CreatedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
|
||||
gen := &service.SoraGeneration{}
|
||||
var mediaURLsJSON, s3KeysJSON []byte
|
||||
var completedAt sql.NullTime
|
||||
var apiKeyID sql.NullInt64
|
||||
|
||||
err := r.sql.QueryRowContext(ctx, `
|
||||
SELECT id, user_id, api_key_id, model, prompt, media_type,
|
||||
status, media_url, media_urls, file_size_bytes,
|
||||
storage_type, s3_object_keys, upstream_task_id, error_message,
|
||||
created_at, completed_at
|
||||
FROM sora_generations WHERE id = $1
|
||||
`, id).Scan(
|
||||
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
|
||||
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
|
||||
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
|
||||
&gen.CreatedAt, &completedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("生成记录不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if apiKeyID.Valid {
|
||||
gen.APIKeyID = &apiKeyID.Int64
|
||||
}
|
||||
if completedAt.Valid {
|
||||
gen.CompletedAt = &completedAt.Time
|
||||
}
|
||||
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
|
||||
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
|
||||
return gen, nil
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
|
||||
mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
|
||||
|
||||
var completedAt *time.Time
|
||||
if gen.CompletedAt != nil {
|
||||
completedAt = gen.CompletedAt
|
||||
}
|
||||
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations SET
|
||||
status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
|
||||
storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
|
||||
error_message = $9, completed_at = $10
|
||||
WHERE id = $1
|
||||
`,
|
||||
gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
|
||||
gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
|
||||
gen.ErrorMessage, completedAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
|
||||
func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET status = $2, upstream_task_id = $3
|
||||
WHERE id = $1 AND status = $4
|
||||
`,
|
||||
id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
|
||||
func (r *soraGenerationRepository) UpdateCompletedIfActive(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
mediaURL string,
|
||||
mediaURLs []string,
|
||||
storageType string,
|
||||
s3Keys []string,
|
||||
fileSizeBytes int64,
|
||||
completedAt time.Time,
|
||||
) (bool, error) {
|
||||
mediaURLsJSON, _ := json.Marshal(mediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(s3Keys)
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET status = $2,
|
||||
media_url = $3,
|
||||
media_urls = $4,
|
||||
file_size_bytes = $5,
|
||||
storage_type = $6,
|
||||
s3_object_keys = $7,
|
||||
error_message = '',
|
||||
completed_at = $8
|
||||
WHERE id = $1 AND status IN ($9, $10)
|
||||
`,
|
||||
id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
|
||||
storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
|
||||
func (r *soraGenerationRepository) UpdateFailedIfActive(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
errMsg string,
|
||||
completedAt time.Time,
|
||||
) (bool, error) {
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET status = $2,
|
||||
error_message = $3,
|
||||
completed_at = $4
|
||||
WHERE id = $1 AND status IN ($5, $6)
|
||||
`,
|
||||
id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
|
||||
func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET status = $2, completed_at = $3
|
||||
WHERE id = $1 AND status IN ($4, $5)
|
||||
`,
|
||||
id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
|
||||
func (r *soraGenerationRepository) UpdateStorageIfCompleted(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
mediaURL string,
|
||||
mediaURLs []string,
|
||||
storageType string,
|
||||
s3Keys []string,
|
||||
fileSizeBytes int64,
|
||||
) (bool, error) {
|
||||
mediaURLsJSON, _ := json.Marshal(mediaURLs)
|
||||
s3KeysJSON, _ := json.Marshal(s3Keys)
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_generations
|
||||
SET media_url = $2,
|
||||
media_urls = $3,
|
||||
file_size_bytes = $4,
|
||||
storage_type = $5,
|
||||
s3_object_keys = $6
|
||||
WHERE id = $1 AND status = $7
|
||||
`,
|
||||
id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
|
||||
// 构建 WHERE 条件
|
||||
conditions := []string{"user_id = $1"}
|
||||
args := []any{params.UserID}
|
||||
argIdx := 2
|
||||
|
||||
if params.Status != "" {
|
||||
// 支持逗号分隔的多状态
|
||||
statuses := strings.Split(params.Status, ",")
|
||||
placeholders := make([]string, len(statuses))
|
||||
for i, s := range statuses {
|
||||
placeholders[i] = fmt.Sprintf("$%d", argIdx)
|
||||
args = append(args, strings.TrimSpace(s))
|
||||
argIdx++
|
||||
}
|
||||
conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
|
||||
}
|
||||
if params.StorageType != "" {
|
||||
storageTypes := strings.Split(params.StorageType, ",")
|
||||
placeholders := make([]string, len(storageTypes))
|
||||
for i, s := range storageTypes {
|
||||
placeholders[i] = fmt.Sprintf("$%d", argIdx)
|
||||
args = append(args, strings.TrimSpace(s))
|
||||
argIdx++
|
||||
}
|
||||
conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
|
||||
}
|
||||
if params.MediaType != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
|
||||
args = append(args, params.MediaType)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
whereClause := "WHERE " + strings.Join(conditions, " AND ")
|
||||
|
||||
// 计数
|
||||
var total int64
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
|
||||
if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (params.Page - 1) * params.PageSize
|
||||
listQuery := fmt.Sprintf(`
|
||||
SELECT id, user_id, api_key_id, model, prompt, media_type,
|
||||
status, media_url, media_urls, file_size_bytes,
|
||||
storage_type, s3_object_keys, upstream_task_id, error_message,
|
||||
created_at, completed_at
|
||||
FROM sora_generations %s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, argIdx, argIdx+1)
|
||||
args = append(args, params.PageSize, offset)
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, listQuery, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
var results []*service.SoraGeneration
|
||||
for rows.Next() {
|
||||
gen := &service.SoraGeneration{}
|
||||
var mediaURLsJSON, s3KeysJSON []byte
|
||||
var completedAt sql.NullTime
|
||||
var apiKeyID sql.NullInt64
|
||||
|
||||
if err := rows.Scan(
|
||||
&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
|
||||
&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
|
||||
&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
|
||||
&gen.CreatedAt, &completedAt,
|
||||
); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if apiKeyID.Valid {
|
||||
gen.APIKeyID = &apiKeyID.Int64
|
||||
}
|
||||
if completedAt.Valid {
|
||||
gen.CompletedAt = &completedAt.Time
|
||||
}
|
||||
_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
|
||||
_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
|
||||
results = append(results, gen)
|
||||
}
|
||||
|
||||
return results, total, rows.Err()
|
||||
}
|
||||
|
||||
func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
|
||||
if len(statuses) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
placeholders := make([]string, len(statuses))
|
||||
args := []any{userID}
|
||||
for i, s := range statuses {
|
||||
placeholders[i] = fmt.Sprintf("$%d", i+2)
|
||||
args = append(args, s)
|
||||
}
|
||||
|
||||
var count int64
|
||||
query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
|
||||
err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
@@ -62,7 +62,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
|
||||
SetBalance(userIn.Balance).
|
||||
SetConcurrency(userIn.Concurrency).
|
||||
SetStatus(userIn.Status).
|
||||
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, nil, service.ErrEmailExists)
|
||||
@@ -145,8 +144,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
||||
SetBalance(userIn.Balance).
|
||||
SetConcurrency(userIn.Concurrency).
|
||||
SetStatus(userIn.Status).
|
||||
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
|
||||
SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||
@@ -376,65 +373,6 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。
|
||||
func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
|
||||
if deltaBytes <= 0 {
|
||||
user, err := r.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return user.SoraStorageUsedBytes, nil
|
||||
}
|
||||
var newUsed int64
|
||||
err := scanSingleRow(ctx, r.sql, `
|
||||
UPDATE users
|
||||
SET sora_storage_used_bytes = sora_storage_used_bytes + $2
|
||||
WHERE id = $1
|
||||
AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
|
||||
RETURNING sora_storage_used_bytes
|
||||
`, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
|
||||
if err == nil {
|
||||
return newUsed, nil
|
||||
}
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
// 区分用户不存在和配额冲突
|
||||
exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
|
||||
if existsErr != nil {
|
||||
return 0, existsErr
|
||||
}
|
||||
if !exists {
|
||||
return 0, service.ErrUserNotFound
|
||||
}
|
||||
return 0, service.ErrSoraStorageQuotaExceeded
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。
|
||||
func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
|
||||
if deltaBytes <= 0 {
|
||||
user, err := r.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return user.SoraStorageUsedBytes, nil
|
||||
}
|
||||
var newUsed int64
|
||||
err := scanSingleRow(ctx, r.sql, `
|
||||
UPDATE users
|
||||
SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
|
||||
WHERE id = $1
|
||||
RETURNING sora_storage_used_bytes
|
||||
`, []any{userID, deltaBytes}, &newUsed)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, service.ErrUserNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return newUsed, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
|
||||
}
|
||||
|
||||
@@ -53,7 +53,6 @@ var ProviderSet = wire.NewSet(
|
||||
NewAPIKeyRepository,
|
||||
NewGroupRepository,
|
||||
NewAccountRepository,
|
||||
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
||||
NewScheduledTestPlanRepository, // 定时测试计划仓储
|
||||
NewScheduledTestResultRepository, // 定时测试结果仓储
|
||||
NewProxyRepository,
|
||||
|
||||
Reference in New Issue
Block a user