fix(merge): 解决与 main 分支的配置冲突
- 合并 main 分支的上游错误日志配置 - 保留调度配置 - 合并 beta header 和 failover 配置
This commit is contained in:
32
backend/internal/repository/db_pool.go
Normal file
32
backend/internal/repository/db_pool.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
type dbPoolSettings struct {
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
ConnMaxIdleTime time.Duration
|
||||
}
|
||||
|
||||
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
|
||||
return dbPoolSettings{
|
||||
MaxOpenConns: cfg.Database.MaxOpenConns,
|
||||
MaxIdleConns: cfg.Database.MaxIdleConns,
|
||||
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
|
||||
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
|
||||
settings := buildDBPoolSettings(cfg)
|
||||
db.SetMaxOpenConns(settings.MaxOpenConns)
|
||||
db.SetMaxIdleConns(settings.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
|
||||
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
|
||||
}
|
||||
50
backend/internal/repository/db_pool_test.go
Normal file
50
backend/internal/repository/db_pool_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func TestBuildDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 50,
|
||||
MaxIdleConns: 10,
|
||||
ConnMaxLifetimeMinutes: 30,
|
||||
ConnMaxIdleTimeMinutes: 5,
|
||||
},
|
||||
}
|
||||
|
||||
settings := buildDBPoolSettings(cfg)
|
||||
require.Equal(t, 50, settings.MaxOpenConns)
|
||||
require.Equal(t, 10, settings.MaxIdleConns)
|
||||
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
|
||||
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
|
||||
}
|
||||
|
||||
func TestApplyDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 40,
|
||||
MaxIdleConns: 8,
|
||||
ConnMaxLifetimeMinutes: 15,
|
||||
ConnMaxIdleTimeMinutes: 3,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = db.Close()
|
||||
})
|
||||
|
||||
applyDBPoolSettings(db, cfg)
|
||||
stats := db.Stats()
|
||||
require.Equal(t, 40, stats.MaxOpenConnections)
|
||||
}
|
||||
69
backend/internal/repository/ent.go
Normal file
69
backend/internal/repository/ent.go
Normal file
@@ -0,0 +1,69 @@
|
||||
// Package infrastructure 提供应用程序的基础设施层组件。
|
||||
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/migrations"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "github.com/lib/pq" // PostgreSQL 驱动,通过副作用导入注册驱动
|
||||
)
|
||||
|
||||
// InitEnt 初始化 Ent ORM 客户端并返回客户端实例和底层的 *sql.DB。
|
||||
//
|
||||
// 该函数执行以下操作:
|
||||
// 1. 初始化全局时区设置,确保时间处理一致性
|
||||
// 2. 建立 PostgreSQL 数据库连接
|
||||
// 3. 自动执行数据库迁移,确保 schema 与代码同步
|
||||
// 4. 创建并返回 Ent 客户端实例
|
||||
//
|
||||
// 重要提示:调用者必须负责关闭返回的 ent.Client(关闭时会自动关闭底层的 driver/db)。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 应用程序配置,包含数据库连接信息和时区设置
|
||||
//
|
||||
// 返回:
|
||||
// - *ent.Client: Ent ORM 客户端,用于执行数据库操作
|
||||
// - *sql.DB: 底层的 SQL 数据库连接,可用于直接执行原生 SQL
|
||||
// - error: 初始化过程中的错误
|
||||
func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
// 优先初始化时区设置,确保所有时间操作使用统一的时区。
|
||||
// 这对于跨时区部署和日志时间戳的一致性至关重要。
|
||||
if err := timezone.Init(cfg.Timezone); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建包含时区信息的数据库连接字符串 (DSN)。
|
||||
// 时区信息会传递给 PostgreSQL,确保数据库层面的时间处理正确。
|
||||
dsn := cfg.Database.DSNWithTimezone(cfg.Timezone)
|
||||
|
||||
// 使用 Ent 的 SQL 驱动打开 PostgreSQL 连接。
|
||||
// dialect.Postgres 指定使用 PostgreSQL 方言进行 SQL 生成。
|
||||
drv, err := entsql.Open(dialect.Postgres, dsn)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
applyDBPoolSettings(drv.DB(), cfg)
|
||||
|
||||
// 确保数据库 schema 已准备就绪。
|
||||
// SQL 迁移文件是 schema 的权威来源(source of truth)。
|
||||
// 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。
|
||||
migrationCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
if err := applyMigrationsFS(migrationCtx, drv.DB(), migrations.FS); err != nil {
|
||||
_ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
|
||||
client := ent.NewClient(ent.Driver(drv))
|
||||
return client, drv.DB(), nil
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -97,7 +96,7 @@ func TestMain(m *testing.M) {
|
||||
log.Printf("failed to open sql db: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := infrastructure.ApplyMigrations(ctx, integrationDB); err != nil {
|
||||
if err := ApplyMigrations(ctx, integrationDB); err != nil {
|
||||
log.Printf("failed to apply db migrations: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
198
backend/internal/repository/migrations_runner.go
Normal file
198
backend/internal/repository/migrations_runner.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/migrations"
|
||||
)
|
||||
|
||||
// schemaMigrationsTableDDL 定义迁移记录表的 DDL。
|
||||
// 该表用于跟踪已应用的迁移文件及其校验和。
|
||||
// - filename: 迁移文件名,作为主键唯一标识每个迁移
|
||||
// - checksum: 文件内容的 SHA256 哈希值,用于检测迁移文件是否被篡改
|
||||
// - applied_at: 迁移应用时间戳
|
||||
const schemaMigrationsTableDDL = `
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
filename TEXT PRIMARY KEY,
|
||||
checksum TEXT NOT NULL,
|
||||
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
`
|
||||
|
||||
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
|
||||
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
|
||||
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
|
||||
const migrationsAdvisoryLockID int64 = 694208311321144027
|
||||
const migrationsLockRetryInterval = 500 * time.Millisecond
|
||||
|
||||
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
|
||||
//
|
||||
// 该函数可以在每次应用启动时安全调用:
|
||||
// - 已应用的迁移会被自动跳过(通过校验 filename 判断)
|
||||
// - 如果迁移文件内容被修改(checksum 不匹配),会返回错误
|
||||
// - 使用 PostgreSQL Advisory Lock 确保多实例并发安全
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文,用于超时控制和取消
|
||||
// - db: 数据库连接
|
||||
//
|
||||
// 返回:
|
||||
// - error: 迁移过程中的任何错误
|
||||
func ApplyMigrations(ctx context.Context, db *sql.DB) error {
|
||||
if db == nil {
|
||||
return errors.New("nil sql db")
|
||||
}
|
||||
return applyMigrationsFS(ctx, db, migrations.FS)
|
||||
}
|
||||
|
||||
// applyMigrationsFS 是迁移执行的核心实现。
|
||||
// 它从指定的文件系统读取 SQL 迁移文件并按顺序应用。
|
||||
//
|
||||
// 迁移执行流程:
|
||||
// 1. 获取 PostgreSQL Advisory Lock,防止多实例并发迁移
|
||||
// 2. 确保 schema_migrations 表存在
|
||||
// 3. 按文件名排序读取所有 .sql 文件
|
||||
// 4. 对于每个迁移文件:
|
||||
// - 计算文件内容的 SHA256 校验和
|
||||
// - 检查该迁移是否已应用(通过 filename 查询)
|
||||
// - 如果已应用,验证校验和是否匹配
|
||||
// - 如果未应用,在事务中执行迁移并记录
|
||||
// 5. 释放 Advisory Lock
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - db: 数据库连接
|
||||
// - fsys: 包含迁移文件的文件系统(通常是 embed.FS)
|
||||
func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||
if db == nil {
|
||||
return errors.New("nil sql db")
|
||||
}
|
||||
|
||||
// 获取分布式锁,确保多实例部署时只有一个实例执行迁移。
|
||||
// 这是 PostgreSQL 特有的 Advisory Lock 机制。
|
||||
if err := pgAdvisoryLock(ctx, db); err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
// 无论迁移是否成功,都要释放锁。
|
||||
// 使用 context.Background() 确保即使原 ctx 已取消也能释放锁。
|
||||
_ = pgAdvisoryUnlock(context.Background(), db)
|
||||
}()
|
||||
|
||||
// 创建迁移记录表(如果不存在)。
|
||||
// 该表记录所有已应用的迁移及其校验和。
|
||||
if _, err := db.ExecContext(ctx, schemaMigrationsTableDDL); err != nil {
|
||||
return fmt.Errorf("create schema_migrations: %w", err)
|
||||
}
|
||||
|
||||
// 获取所有 .sql 迁移文件并按文件名排序。
|
||||
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql)。
|
||||
files, err := fs.Glob(fsys, "*.sql")
|
||||
if err != nil {
|
||||
return fmt.Errorf("list migrations: %w", err)
|
||||
}
|
||||
sort.Strings(files) // 确保按文件名顺序执行迁移
|
||||
|
||||
for _, name := range files {
|
||||
// 读取迁移文件内容
|
||||
contentBytes, err := fs.ReadFile(fsys, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(string(contentBytes))
|
||||
if content == "" {
|
||||
continue // 跳过空文件
|
||||
}
|
||||
|
||||
// 计算文件内容的 SHA256 校验和,用于检测文件是否被修改。
|
||||
// 这是一种防篡改机制:如果有人修改了已应用的迁移文件,系统会拒绝启动。
|
||||
sum := sha256.Sum256([]byte(content))
|
||||
checksum := hex.EncodeToString(sum[:])
|
||||
|
||||
// 检查该迁移是否已经应用
|
||||
var existing string
|
||||
rowErr := db.QueryRowContext(ctx, "SELECT checksum FROM schema_migrations WHERE filename = $1", name).Scan(&existing)
|
||||
if rowErr == nil {
|
||||
// 迁移已应用,验证校验和是否匹配
|
||||
if existing != checksum {
|
||||
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
|
||||
// 正确的做法是创建新的迁移文件来进行变更。
|
||||
return fmt.Errorf("migration %s checksum mismatch (db=%s file=%s)", name, existing, checksum)
|
||||
}
|
||||
continue // 迁移已应用且校验和匹配,跳过
|
||||
}
|
||||
if !errors.Is(rowErr, sql.ErrNoRows) {
|
||||
return fmt.Errorf("check migration %s: %w", name, rowErr)
|
||||
}
|
||||
|
||||
// 迁移未应用,在事务中执行。
|
||||
// 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
// 执行迁移 SQL
|
||||
if _, err := tx.ExecContext(ctx, content); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("apply migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
// 记录迁移已完成,保存文件名和校验和
|
||||
if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("record migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit(); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("commit migration %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
|
||||
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
|
||||
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
|
||||
func pgAdvisoryLock(ctx context.Context, db *sql.DB) error {
|
||||
ticker := time.NewTicker(migrationsLockRetryInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
var locked bool
|
||||
if err := db.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", migrationsAdvisoryLockID).Scan(&locked); err != nil {
|
||||
return fmt.Errorf("acquire migrations lock: %w", err)
|
||||
}
|
||||
if locked {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("acquire migrations lock: %w", ctx.Err())
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pgAdvisoryUnlock 释放 PostgreSQL Advisory Lock。
|
||||
// 必须在获取锁后确保释放,否则会阻塞其他实例的迁移操作。
|
||||
func pgAdvisoryUnlock(ctx context.Context, db *sql.DB) error {
|
||||
_, err := db.ExecContext(ctx, "SELECT pg_advisory_unlock($1)", migrationsAdvisoryLockID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("release migrations lock: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
|
||||
// Re-apply migrations to verify idempotency (no errors, no duplicate rows).
|
||||
require.NoError(t, infrastructure.ApplyMigrations(context.Background(), integrationDB))
|
||||
require.NoError(t, ApplyMigrations(context.Background(), integrationDB))
|
||||
|
||||
// schema_migrations should have at least the current migration set.
|
||||
var applied int
|
||||
|
||||
39
backend/internal/repository/redis.go
Normal file
39
backend/internal/repository/redis.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// InitRedis 初始化 Redis 客户端
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现使用 go-redis 默认配置,未设置连接池和超时参数:
|
||||
// 1. 默认连接池大小可能不足以支撑高并发
|
||||
// 2. 无超时控制可能导致慢操作阻塞
|
||||
//
|
||||
// 新实现支持可配置的连接池和超时参数:
|
||||
// 1. PoolSize: 控制最大并发连接数(默认 128)
|
||||
// 2. MinIdleConns: 保持最小空闲连接,减少冷启动延迟(默认 10)
|
||||
// 3. DialTimeout/ReadTimeout/WriteTimeout: 精确控制各阶段超时
|
||||
func InitRedis(cfg *config.Config) *redis.Client {
|
||||
return redis.NewClient(buildRedisOptions(cfg))
|
||||
}
|
||||
|
||||
// buildRedisOptions 构建 Redis 连接选项
|
||||
// 从配置文件读取连接池和超时参数,支持生产环境调优
|
||||
func buildRedisOptions(cfg *config.Config) *redis.Options {
|
||||
return &redis.Options{
|
||||
Addr: cfg.Redis.Address(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
DialTimeout: time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second, // 建连超时
|
||||
ReadTimeout: time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second, // 读取超时
|
||||
WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // 写入超时
|
||||
PoolSize: cfg.Redis.PoolSize, // 连接池大小
|
||||
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
|
||||
}
|
||||
}
|
||||
35
backend/internal/repository/redis_test.go
Normal file
35
backend/internal/repository/redis_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildRedisOptions(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Redis: config.RedisConfig{
|
||||
Host: "localhost",
|
||||
Port: 6379,
|
||||
Password: "secret",
|
||||
DB: 2,
|
||||
DialTimeoutSeconds: 5,
|
||||
ReadTimeoutSeconds: 3,
|
||||
WriteTimeoutSeconds: 4,
|
||||
PoolSize: 100,
|
||||
MinIdleConns: 10,
|
||||
},
|
||||
}
|
||||
|
||||
opts := buildRedisOptions(cfg)
|
||||
require.Equal(t, "localhost:6379", opts.Addr)
|
||||
require.Equal(t, "secret", opts.Password)
|
||||
require.Equal(t, 2, opts.DB)
|
||||
require.Equal(t, 5*time.Second, opts.DialTimeout)
|
||||
require.Equal(t, 3*time.Second, opts.ReadTimeout)
|
||||
require.Equal(t, 4*time.Second, opts.WriteTimeout)
|
||||
require.Equal(t, 100, opts.PoolSize)
|
||||
require.Equal(t, 10, opts.MinIdleConns)
|
||||
}
|
||||
@@ -291,13 +291,11 @@ func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
|
||||
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
|
||||
}
|
||||
|
||||
// IncrementUsage 原子性地累加用量并校验限额。
|
||||
// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。
|
||||
// 当更新失败时,会执行额外查询确定具体超出的限额类型。
|
||||
// IncrementUsage 原子性地累加订阅用量。
|
||||
// 限额检查已在请求前由 BillingCacheService.CheckBillingEligibility 完成,
|
||||
// 此处仅负责记录实际消费,确保消费数据的完整性。
|
||||
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
// 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加
|
||||
// NULL 限额表示无限制
|
||||
const atomicUpdateSQL = `
|
||||
const updateSQL = `
|
||||
UPDATE user_subscriptions us
|
||||
SET
|
||||
daily_usage_usd = us.daily_usage_usd + $1,
|
||||
@@ -309,13 +307,10 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
|
||||
AND us.deleted_at IS NULL
|
||||
AND us.group_id = g.id
|
||||
AND g.deleted_at IS NULL
|
||||
AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd)
|
||||
AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd)
|
||||
AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd)
|
||||
`
|
||||
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id)
|
||||
result, err := client.ExecContext(ctx, updateSQL, costUSD, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -326,64 +321,11 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
|
||||
}
|
||||
|
||||
if affected > 0 {
|
||||
return nil // 更新成功
|
||||
return nil
|
||||
}
|
||||
|
||||
// affected == 0:可能是订阅不存在、分组已删除、或限额超出
|
||||
// 执行额外查询确定具体原因
|
||||
return r.checkIncrementFailureReason(ctx, id, costUSD)
|
||||
}
|
||||
|
||||
// checkIncrementFailureReason 查询更新失败的具体原因
|
||||
func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error {
|
||||
const checkSQL = `
|
||||
SELECT
|
||||
CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted'
|
||||
WHEN g.id IS NULL THEN 'subscription_not_found'
|
||||
WHEN g.deleted_at IS NOT NULL THEN 'group_deleted'
|
||||
WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded'
|
||||
WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded'
|
||||
WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded'
|
||||
ELSE 'unknown'
|
||||
END AS reason
|
||||
FROM user_subscriptions us
|
||||
LEFT JOIN groups g ON us.group_id = g.id
|
||||
WHERE us.id = $2
|
||||
`
|
||||
|
||||
client := clientFromContext(ctx, r.client)
|
||||
rows, err := client.QueryContext(ctx, checkSQL, costUSD, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
return service.ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
var reason string
|
||||
if err := rows.Scan(&reason); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch reason {
|
||||
case "subscription_not_found", "subscription_deleted", "group_deleted":
|
||||
return service.ErrSubscriptionNotFound
|
||||
case "daily_exceeded":
|
||||
return service.ErrDailyLimitExceeded
|
||||
case "weekly_exceeded":
|
||||
return service.ErrWeeklyLimitExceeded
|
||||
case "monthly_exceeded":
|
||||
return service.ErrMonthlyLimitExceeded
|
||||
default:
|
||||
// unknown 情况理论上不应发生,但作为兜底返回
|
||||
return service.ErrSubscriptionNotFound
|
||||
}
|
||||
// affected == 0:订阅不存在或已删除
|
||||
return service.ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
|
||||
@@ -633,112 +633,7 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
|
||||
}
|
||||
|
||||
// --- 限额检查与软删除过滤测试 ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) mustCreateGroupWithLimits(name string, daily, weekly, monthly *float64) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
create := s.client.Group.Create().
|
||||
SetName(name).
|
||||
SetStatus(service.StatusActive).
|
||||
SetSubscriptionType(service.SubscriptionTypeSubscription)
|
||||
|
||||
if daily != nil {
|
||||
create.SetDailyLimitUsd(*daily)
|
||||
}
|
||||
if weekly != nil {
|
||||
create.SetWeeklyLimitUsd(*weekly)
|
||||
}
|
||||
if monthly != nil {
|
||||
create.SetMonthlyLimitUsd(*monthly)
|
||||
}
|
||||
|
||||
g, err := create.Save(s.ctx)
|
||||
s.Require().NoError(err, "create group with limits")
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_DailyLimitExceeded() {
|
||||
user := s.mustCreateUser("dailylimit@test.com", service.RoleUser)
|
||||
dailyLimit := 10.0
|
||||
group := s.mustCreateGroupWithLimits("g-dailylimit", &dailyLimit, nil, nil)
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
// 先增加 9.0,应该成功
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, 9.0)
|
||||
s.Require().NoError(err, "first increment should succeed")
|
||||
|
||||
// 再增加 2.0,会超过 10.0 限额,应该失败
|
||||
err = s.repo.IncrementUsage(s.ctx, sub.ID, 2.0)
|
||||
s.Require().Error(err, "should fail when daily limit exceeded")
|
||||
s.Require().ErrorIs(err, service.ErrDailyLimitExceeded)
|
||||
|
||||
// 验证用量没有变化
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(9.0, got.DailyUsageUSD, 1e-6, "usage should not change after failed increment")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_WeeklyLimitExceeded() {
|
||||
user := s.mustCreateUser("weeklylimit@test.com", service.RoleUser)
|
||||
weeklyLimit := 50.0
|
||||
group := s.mustCreateGroupWithLimits("g-weeklylimit", nil, &weeklyLimit, nil)
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
// 增加 45.0,应该成功
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, 45.0)
|
||||
s.Require().NoError(err, "first increment should succeed")
|
||||
|
||||
// 再增加 10.0,会超过 50.0 限额,应该失败
|
||||
err = s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
|
||||
s.Require().Error(err, "should fail when weekly limit exceeded")
|
||||
s.Require().ErrorIs(err, service.ErrWeeklyLimitExceeded)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_MonthlyLimitExceeded() {
|
||||
user := s.mustCreateUser("monthlylimit@test.com", service.RoleUser)
|
||||
monthlyLimit := 100.0
|
||||
group := s.mustCreateGroupWithLimits("g-monthlylimit", nil, nil, &monthlyLimit)
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
// 增加 90.0,应该成功
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, 90.0)
|
||||
s.Require().NoError(err, "first increment should succeed")
|
||||
|
||||
// 再增加 20.0,会超过 100.0 限额,应该失败
|
||||
err = s.repo.IncrementUsage(s.ctx, sub.ID, 20.0)
|
||||
s.Require().Error(err, "should fail when monthly limit exceeded")
|
||||
s.Require().ErrorIs(err, service.ErrMonthlyLimitExceeded)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NoLimits() {
|
||||
user := s.mustCreateUser("nolimits@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroupWithLimits("g-nolimits", nil, nil, nil) // 无限额
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
// 应该可以增加任意金额
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1000000.0)
|
||||
s.Require().NoError(err, "should succeed without limits")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(1000000.0, got.DailyUsageUSD, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_AtExactLimit() {
|
||||
user := s.mustCreateUser("exactlimit@test.com", service.RoleUser)
|
||||
dailyLimit := 10.0
|
||||
group := s.mustCreateGroupWithLimits("g-exactlimit", &dailyLimit, nil, nil)
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
// 正好达到限额应该成功
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
|
||||
s.Require().NoError(err, "should succeed at exact limit")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(10.0, got.DailyUsageUSD, 1e-6)
|
||||
}
|
||||
// --- 软删除过滤测试 ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() {
|
||||
user := s.mustCreateUser("softdeleted@test.com", service.RoleUser)
|
||||
@@ -779,7 +674,7 @@ func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() {
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
|
||||
user := s.mustCreateUser("concurrent@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroupWithLimits("g-concurrent", nil, nil, nil) // 无限额
|
||||
group := s.mustCreateGroup("g-concurrent")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
const numGoroutines = 10
|
||||
@@ -808,34 +703,6 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
|
||||
s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_ConcurrentWithLimit() {
|
||||
user := s.mustCreateUser("concurrentlimit@test.com", service.RoleUser)
|
||||
dailyLimit := 5.0
|
||||
group := s.mustCreateGroupWithLimits("g-concurrentlimit", &dailyLimit, nil, nil)
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
// 注意:事务内的操作是串行的,所以这里改为顺序执行以验证限额逻辑
|
||||
// 尝试增加 10 次,每次 1.0,但限额只有 5.0
|
||||
const numAttempts = 10
|
||||
const incrementPerAttempt = 1.0
|
||||
|
||||
successCount := 0
|
||||
for i := 0; i < numAttempts; i++ {
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerAttempt)
|
||||
if err == nil {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
// 验证:应该有 5 次成功(不超过限额),5 次失败(超出限额)
|
||||
s.Require().Equal(5, successCount, "exactly 5 increments should succeed (limit=5, increment=1)")
|
||||
|
||||
// 验证最终用量等于限额
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(dailyLimit, got.DailyUsageUSD, 1e-6, "daily usage should equal limit")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() {
|
||||
baseClient := testEntClient(s.T())
|
||||
tx, err := baseClient.Tx(context.Background())
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/google/wire"
|
||||
@@ -54,4 +59,58 @@ var ProviderSet = wire.NewSet(
|
||||
NewOpenAIOAuthClient,
|
||||
NewGeminiOAuthClient,
|
||||
NewGeminiCliCodeAssistClient,
|
||||
|
||||
ProvideEnt,
|
||||
ProvideSQLDB,
|
||||
ProvideRedis,
|
||||
)
|
||||
|
||||
// ProvideEnt 为依赖注入提供 Ent 客户端。
|
||||
//
|
||||
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
|
||||
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
|
||||
//
|
||||
// 依赖:config.Config
|
||||
// 提供:*ent.Client
|
||||
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
|
||||
client, _, err := InitEnt(cfg)
|
||||
return client, err
|
||||
}
|
||||
|
||||
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
|
||||
//
|
||||
// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询),
|
||||
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
|
||||
//
|
||||
// 设计说明:
|
||||
// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问
|
||||
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
|
||||
//
|
||||
// 依赖:*ent.Client
|
||||
// 提供:*sql.DB
|
||||
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("nil ent client")
|
||||
}
|
||||
// 从 Ent 客户端获取底层驱动
|
||||
drv, ok := client.Driver().(*entsql.Driver)
|
||||
if !ok {
|
||||
return nil, errors.New("ent driver does not expose *sql.DB")
|
||||
}
|
||||
// 返回驱动持有的 sql.DB 实例
|
||||
return drv.DB(), nil
|
||||
}
|
||||
|
||||
// ProvideRedis 为依赖注入提供 Redis 客户端。
|
||||
//
|
||||
// Redis 用于:
|
||||
// - 分布式锁(如并发控制)
|
||||
// - 缓存(如用户会话、API 响应缓存)
|
||||
// - 速率限制
|
||||
// - 实时统计数据
|
||||
//
|
||||
// 依赖:config.Config
|
||||
// 提供:*redis.Client
|
||||
func ProvideRedis(cfg *config.Config) *redis.Client {
|
||||
return InitRedis(cfg)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user