merge: 合并 test 分支到 test-dev,解决冲突

解决的冲突文件:
- wire_gen.go: 合并 ConcurrencyService/CRSSyncService 参数和 userAttributeHandler
- gateway_handler.go: 合并 pkg/errors 和 antigravity 导入
- gateway_service.go: 合并 validateUpstreamBaseURL 和 GetAvailableModels
- config.example.yaml: 合并 billing/turnstile 配置和额外 gateway 选项

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-01-03 11:36:31 +08:00
176 changed files with 27680 additions and 1952 deletions

View File

@@ -124,6 +124,90 @@ func (r *accountRepository) GetByID(ctx context.Context, id int64) (*service.Acc
return &accounts[0], nil
}
func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
if len(ids) == 0 {
return []*service.Account{}, nil
}
// De-duplicate while preserving order of first occurrence.
uniqueIDs := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
uniqueIDs = append(uniqueIDs, id)
}
if len(uniqueIDs) == 0 {
return []*service.Account{}, nil
}
entAccounts, err := r.client.Account.
Query().
Where(dbaccount.IDIn(uniqueIDs...)).
WithProxy().
All(ctx)
if err != nil {
return nil, err
}
if len(entAccounts) == 0 {
return []*service.Account{}, nil
}
accountIDs := make([]int64, 0, len(entAccounts))
entByID := make(map[int64]*dbent.Account, len(entAccounts))
for _, acc := range entAccounts {
entByID[acc.ID] = acc
accountIDs = append(accountIDs, acc.ID)
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
}
outByID := make(map[int64]*service.Account, len(entAccounts))
for _, entAcc := range entAccounts {
out := accountEntityToService(entAcc)
if out == nil {
continue
}
// Prefer the preloaded proxy edge when available.
if entAcc.Edges.Proxy != nil {
out.Proxy = proxyEntityToService(entAcc.Edges.Proxy)
}
if groups, ok := groupsByAccount[entAcc.ID]; ok {
out.Groups = groups
}
if groupIDs, ok := groupIDsByAccount[entAcc.ID]; ok {
out.GroupIDs = groupIDs
}
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
out.AccountGroups = ags
}
outByID[entAcc.ID] = out
}
// Preserve input order (first occurrence), and ignore missing IDs.
out := make([]*service.Account, 0, len(uniqueIDs))
for _, id := range uniqueIDs {
if _, ok := entByID[id]; !ok {
continue
}
if acc, ok := outByID[id]; ok && acc != nil {
out = append(out, acc)
}
}
return out, nil
}
// ExistsByID 检查指定 ID 的账号是否存在。
// 相比 GetByID此方法性能更优因为
// - 使用 Exist() 方法生成 SELECT EXISTS 查询,只返回布尔值

View File

@@ -294,7 +294,6 @@ func userEntityToService(u *dbent.User) *service.User {
ID: u.ID,
Email: u.Email,
Username: u.Username,
Wechat: u.Wechat,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,

View File

@@ -2,7 +2,9 @@ package repository
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
@@ -27,6 +29,8 @@ const (
userSlotKeyPrefix = "concurrency:user:"
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"
// 默认槽位过期时间(分钟),可通过配置覆盖
defaultSlotTTLMinutes = 15
@@ -112,33 +116,112 @@ var (
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
return 1
`)
// incrementAccountWaitScript - account-level wait queue count
incrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
return 1
`)
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
// getAccountsLoadBatchScript - batch load query (read-only)
// ARGV[1] = slot TTL (seconds, retained for compatibility)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
getAccountsLoadBatchScript = redis.NewScript(`
local result = {}
local i = 2
while i <= #ARGV do
local accountID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:account:' .. accountID
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'wait:account:' .. accountID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, accountID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
cleanupExpiredSlotsScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`)
)
type concurrencyCache struct {
rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒)
rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒)
waitQueueTTLSeconds int // 等待队列过期时间(秒)
}
// NewConcurrencyCache 创建并发控制缓存
// slotTTLMinutes: 槽位过期时间分钟0 或负数使用默认值 15 分钟
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int) service.ConcurrencyCache {
// waitQueueTTLSeconds: 等待队列过期时间0 或负数使用 slot TTL
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
if slotTTLMinutes <= 0 {
slotTTLMinutes = defaultSlotTTLMinutes
}
if waitQueueTTLSeconds <= 0 {
waitQueueTTLSeconds = slotTTLMinutes * 60
}
return &concurrencyCache{
rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60,
rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60,
waitQueueTTLSeconds: waitQueueTTLSeconds,
}
}
@@ -155,6 +238,10 @@ func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
func accountWaitKey(accountID int64) string {
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
}
// Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
@@ -225,3 +312,75 @@ func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
// Account wait queue operations
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
key := accountWaitKey(accountID)
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
key := accountWaitKey(accountID)
val, err := c.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) {
return 0, err
}
if errors.Is(err, redis.Nil) {
return 0, nil
}
return val, nil
}
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
if len(accounts) == 0 {
return map[int64]*service.AccountLoadInfo{}, nil
}
args := []any{c.slotTTLSeconds}
for _, acc := range accounts {
args = append(args, acc.ID, acc.MaxConcurrency)
}
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
if err != nil {
return nil, err
}
loadMap := make(map[int64]*service.AccountLoadInfo)
for i := 0; i < len(result); i += 4 {
if i+3 >= len(result) {
break
}
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
loadMap[accountID] = &service.AccountLoadInfo{
AccountID: accountID,
CurrentConcurrency: currentConcurrency,
WaitingCount: waitingCount,
LoadRate: loadRate,
}
}
return loadMap, nil
}
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
key := accountSlotKey(accountID)
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
return err
}

View File

@@ -22,7 +22,7 @@ func BenchmarkAccountConcurrency(b *testing.B) {
_ = rdb.Close()
}()
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes).(*concurrencyCache)
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache)
ctx := context.Background()
for _, size := range []int{10, 100, 1000} {

View File

@@ -27,7 +27,7 @@ type ConcurrencyCacheSuite struct {
func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes)
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
@@ -218,6 +218,48 @@ func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
accountID := int64(30)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
require.False(s.T(), ok, "expected account wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL account waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected account wait count 1")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
accountID := int64(301)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
@@ -232,6 +274,139 @@ func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
// Setup: Create accounts with different load states
account1 := int64(100)
account2 := int64(101)
account3 := int64(102)
// Account 1: 2/3 slots used, 1 waiting
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 2: 1/2 slots used, 0 waiting
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 3: 0/1 slots used, 0 waiting (idle)
// Query batch load
accounts := []service.AccountWithConcurrency{
{ID: account1, MaxConcurrency: 3},
{ID: account2, MaxConcurrency: 2},
{ID: account3, MaxConcurrency: 1},
}
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
require.NoError(s.T(), err)
require.Len(s.T(), loadMap, 3)
// Verify account1: (2 + 1) / 3 = 100%
load1 := loadMap[account1]
require.NotNil(s.T(), load1)
require.Equal(s.T(), account1, load1.AccountID)
require.Equal(s.T(), 2, load1.CurrentConcurrency)
require.Equal(s.T(), 1, load1.WaitingCount)
require.Equal(s.T(), 100, load1.LoadRate)
// Verify account2: (1 + 0) / 2 = 50%
load2 := loadMap[account2]
require.NotNil(s.T(), load2)
require.Equal(s.T(), account2, load2.AccountID)
require.Equal(s.T(), 1, load2.CurrentConcurrency)
require.Equal(s.T(), 0, load2.WaitingCount)
require.Equal(s.T(), 50, load2.LoadRate)
// Verify account3: (0 + 0) / 1 = 0%
load3 := loadMap[account3]
require.NotNil(s.T(), load3)
require.Equal(s.T(), account3, load3.AccountID)
require.Equal(s.T(), 0, load3.CurrentConcurrency)
require.Equal(s.T(), 0, load3.WaitingCount)
require.Equal(s.T(), 0, load3.LoadRate)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
// Test with empty account list
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
require.NoError(s.T(), err)
require.Empty(s.T(), loadMap)
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
accountID := int64(200)
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
// Acquire 3 slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Verify 3 slots exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, cur)
// Manually set old timestamps for req1 and req2 (simulate expired slots)
now := time.Now().Unix()
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
require.NoError(s.T(), err)
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
require.NoError(s.T(), err)
// Run cleanup
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify only 1 slot remains (req3)
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur)
// Verify req3 still exists
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Len(s.T(), members, 1)
require.Equal(s.T(), "req3", members[0])
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
accountID := int64(201)
// Acquire 2 fresh slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Run cleanup (should not remove anything)
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify both slots still exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 2, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}

View File

@@ -0,0 +1,32 @@
package repository
import (
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
type dbPoolSettings struct {
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
ConnMaxIdleTime time.Duration
}
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
return dbPoolSettings{
MaxOpenConns: cfg.Database.MaxOpenConns,
MaxIdleConns: cfg.Database.MaxIdleConns,
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
}
}
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
settings := buildDBPoolSettings(cfg)
db.SetMaxOpenConns(settings.MaxOpenConns)
db.SetMaxIdleConns(settings.MaxIdleConns)
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
}

View File

@@ -0,0 +1,50 @@
package repository
import (
"database/sql"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
_ "github.com/lib/pq"
)
func TestBuildDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 50,
MaxIdleConns: 10,
ConnMaxLifetimeMinutes: 30,
ConnMaxIdleTimeMinutes: 5,
},
}
settings := buildDBPoolSettings(cfg)
require.Equal(t, 50, settings.MaxOpenConns)
require.Equal(t, 10, settings.MaxIdleConns)
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
}
func TestApplyDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 40,
MaxIdleConns: 8,
ConnMaxLifetimeMinutes: 15,
ConnMaxIdleTimeMinutes: 3,
},
}
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
applyDBPoolSettings(db, cfg)
stats := db.Stats()
require.Equal(t, 40, stats.MaxOpenConnections)
}

View File

@@ -0,0 +1,69 @@
// Package infrastructure 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package repository
import (
"context"
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/migrations"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "github.com/lib/pq" // PostgreSQL 驱动,通过副作用导入注册驱动
)
// InitEnt 初始化 Ent ORM 客户端并返回客户端实例和底层的 *sql.DB。
//
// 该函数执行以下操作:
// 1. 初始化全局时区设置,确保时间处理一致性
// 2. 建立 PostgreSQL 数据库连接
// 3. 自动执行数据库迁移,确保 schema 与代码同步
// 4. 创建并返回 Ent 客户端实例
//
// 重要提示:调用者必须负责关闭返回的 ent.Client关闭时会自动关闭底层的 driver/db
//
// 参数:
// - cfg: 应用程序配置,包含数据库连接信息和时区设置
//
// 返回:
// - *ent.Client: Ent ORM 客户端,用于执行数据库操作
// - *sql.DB: 底层的 SQL 数据库连接,可用于直接执行原生 SQL
// - error: 初始化过程中的错误
func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 优先初始化时区设置,确保所有时间操作使用统一的时区。
// 这对于跨时区部署和日志时间戳的一致性至关重要。
if err := timezone.Init(cfg.Timezone); err != nil {
return nil, nil, err
}
// 构建包含时区信息的数据库连接字符串 (DSN)。
// 时区信息会传递给 PostgreSQL确保数据库层面的时间处理正确。
dsn := cfg.Database.DSNWithTimezone(cfg.Timezone)
// 使用 Ent 的 SQL 驱动打开 PostgreSQL 连接。
// dialect.Postgres 指定使用 PostgreSQL 方言进行 SQL 生成。
drv, err := entsql.Open(dialect.Postgres, dsn)
if err != nil {
return nil, nil, err
}
applyDBPoolSettings(drv.DB(), cfg)
// 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源source of truth
// 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。
migrationCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
if err := applyMigrationsFS(migrationCtx, drv.DB(), migrations.FS); err != nil {
_ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露
return nil, nil, err
}
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
client := ent.NewClient(ent.Driver(drv))
return client, drv.DB(), nil
}

View File

@@ -7,7 +7,7 @@ import (
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/lib/pq"
)

View File

@@ -40,7 +40,6 @@ func mustCreateUser(t *testing.T, client *dbent.Client, u *service.User) *servic
SetBalance(u.Balance).
SetConcurrency(u.Concurrency).
SetUsername(u.Username).
SetWechat(u.Wechat).
SetNotes(u.Notes)
if !u.CreatedAt.IsZero() {
create.SetCreatedAt(u.CreatedAt)

View File

@@ -17,7 +17,6 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@@ -97,7 +96,7 @@ func TestMain(m *testing.M) {
log.Printf("failed to open sql db: %v", err)
os.Exit(1)
}
if err := infrastructure.ApplyMigrations(ctx, integrationDB); err != nil {
if err := ApplyMigrations(ctx, integrationDB); err != nil {
log.Printf("failed to apply db migrations: %v", err)
os.Exit(1)
}

View File

@@ -0,0 +1,206 @@
package repository
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"io/fs"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/migrations"
)
// schemaMigrationsTableDDL 定义迁移记录表的 DDL。
// 该表用于跟踪已应用的迁移文件及其校验和。
// - filename: 迁移文件名,作为主键唯一标识每个迁移
// - checksum: 文件内容的 SHA256 哈希值,用于检测迁移文件是否被篡改
// - applied_at: 迁移应用时间戳
const schemaMigrationsTableDDL = `
CREATE TABLE IF NOT EXISTS schema_migrations (
filename TEXT PRIMARY KEY,
checksum TEXT NOT NULL,
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
`
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
const migrationsAdvisoryLockID int64 = 694208311321144027
const migrationsLockRetryInterval = 500 * time.Millisecond
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
//
// 该函数可以在每次应用启动时安全调用:
// - 已应用的迁移会被自动跳过(通过校验 filename 判断)
// - 如果迁移文件内容被修改checksum 不匹配),会返回错误
// - 使用 PostgreSQL Advisory Lock 确保多实例并发安全
//
// 参数:
// - ctx: 上下文,用于超时控制和取消
// - db: 数据库连接
//
// 返回:
// - error: 迁移过程中的任何错误
func ApplyMigrations(ctx context.Context, db *sql.DB) error {
if db == nil {
return errors.New("nil sql db")
}
return applyMigrationsFS(ctx, db, migrations.FS)
}
// applyMigrationsFS 是迁移执行的核心实现。
// 它从指定的文件系统读取 SQL 迁移文件并按顺序应用。
//
// 迁移执行流程:
// 1. 获取 PostgreSQL Advisory Lock防止多实例并发迁移
// 2. 确保 schema_migrations 表存在
// 3. 按文件名排序读取所有 .sql 文件
// 4. 对于每个迁移文件:
// - 计算文件内容的 SHA256 校验和
// - 检查该迁移是否已应用(通过 filename 查询)
// - 如果已应用,验证校验和是否匹配
// - 如果未应用,在事务中执行迁移并记录
// 5. 释放 Advisory Lock
//
// 参数:
// - ctx: 上下文
// - db: 数据库连接
// - fsys: 包含迁移文件的文件系统(通常是 embed.FS
func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if db == nil {
return errors.New("nil sql db")
}
// 获取分布式锁,确保多实例部署时只有一个实例执行迁移。
// 这是 PostgreSQL 特有的 Advisory Lock 机制。
if err := pgAdvisoryLock(ctx, db); err != nil {
return err
}
defer func() {
// 无论迁移是否成功,都要释放锁。
// 使用 context.Background() 确保即使原 ctx 已取消也能释放锁。
_ = pgAdvisoryUnlock(context.Background(), db)
}()
// 创建迁移记录表(如果不存在)。
// 该表记录所有已应用的迁移及其校验和。
if _, err := db.ExecContext(ctx, schemaMigrationsTableDDL); err != nil {
return fmt.Errorf("create schema_migrations: %w", err)
}
// 获取所有 .sql 迁移文件并按文件名排序。
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql
files, err := fs.Glob(fsys, "*.sql")
if err != nil {
return fmt.Errorf("list migrations: %w", err)
}
sort.Strings(files) // 确保按文件名顺序执行迁移
for _, name := range files {
// 读取迁移文件内容
contentBytes, err := fs.ReadFile(fsys, name)
if err != nil {
return fmt.Errorf("read migration %s: %w", name, err)
}
content := strings.TrimSpace(string(contentBytes))
if content == "" {
continue // 跳过空文件
}
// 计算文件内容的 SHA256 校验和,用于检测文件是否被修改。
// 这是一种防篡改机制:如果有人修改了已应用的迁移文件,系统会拒绝启动。
sum := sha256.Sum256([]byte(content))
checksum := hex.EncodeToString(sum[:])
// 检查该迁移是否已经应用
var existing string
rowErr := db.QueryRowContext(ctx, "SELECT checksum FROM schema_migrations WHERE filename = $1", name).Scan(&existing)
if rowErr == nil {
// 迁移已应用,验证校验和是否匹配
if existing != checksum {
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf(
"migration %s checksum mismatch (db=%s file=%s)\n"+
"This means the migration file was modified after being applied to the database.\n"+
"Solutions:\n"+
" 1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s\n"+
" 2. For new changes, create a new migration file instead of modifying existing ones\n"+
"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments",
name, existing, checksum, name, name,
)
}
continue // 迁移已应用且校验和匹配,跳过
}
if !errors.Is(rowErr, sql.ErrNoRows) {
return fmt.Errorf("check migration %s: %w", name, rowErr)
}
// 迁移未应用,在事务中执行。
// 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin migration %s: %w", name, err)
}
// 执行迁移 SQL
if _, err := tx.ExecContext(ctx, content); err != nil {
_ = tx.Rollback()
return fmt.Errorf("apply migration %s: %w", name, err)
}
// 记录迁移已完成,保存文件名和校验和
if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
_ = tx.Rollback()
return fmt.Errorf("record migration %s: %w", name, err)
}
// 提交事务
if err := tx.Commit(); err != nil {
_ = tx.Rollback()
return fmt.Errorf("commit migration %s: %w", name, err)
}
}
return nil
}
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
func pgAdvisoryLock(ctx context.Context, db *sql.DB) error {
ticker := time.NewTicker(migrationsLockRetryInterval)
defer ticker.Stop()
for {
var locked bool
if err := db.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", migrationsAdvisoryLockID).Scan(&locked); err != nil {
return fmt.Errorf("acquire migrations lock: %w", err)
}
if locked {
return nil
}
select {
case <-ctx.Done():
return fmt.Errorf("acquire migrations lock: %w", ctx.Err())
case <-ticker.C:
}
}
}
// pgAdvisoryUnlock 释放 PostgreSQL Advisory Lock。
// 必须在获取锁后确保释放,否则会阻塞其他实例的迁移操作。
func pgAdvisoryUnlock(ctx context.Context, db *sql.DB) error {
_, err := db.ExecContext(ctx, "SELECT pg_advisory_unlock($1)", migrationsAdvisoryLockID)
if err != nil {
return fmt.Errorf("release migrations lock: %w", err)
}
return nil
}

View File

@@ -7,7 +7,6 @@ import (
"database/sql"
"testing"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/stretchr/testify/require"
)
@@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
tx := testTx(t)
// Re-apply migrations to verify idempotency (no errors, no duplicate rows).
require.NoError(t, infrastructure.ApplyMigrations(context.Background(), integrationDB))
require.NoError(t, ApplyMigrations(context.Background(), integrationDB))
// schema_migrations should have at least the current migration set.
var applied int
@@ -24,7 +23,6 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
// users: columns required by repository queries
requireColumn(t, tx, "users", "username", "character varying", 100, false)
requireColumn(t, tx, "users", "wechat", "character varying", 100, false)
requireColumn(t, tx, "users", "notes", "text", 0, false)
// accounts: schedulable and rate-limit fields

View File

@@ -0,0 +1,39 @@
package repository
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/redis/go-redis/v9"
)
// InitRedis 初始化 Redis 客户端
//
// 性能优化说明:
// 原实现使用 go-redis 默认配置,未设置连接池和超时参数:
// 1. 默认连接池大小可能不足以支撑高并发
// 2. 无超时控制可能导致慢操作阻塞
//
// 新实现支持可配置的连接池和超时参数:
// 1. PoolSize: 控制最大并发连接数(默认 128
// 2. MinIdleConns: 保持最小空闲连接,减少冷启动延迟(默认 10
// 3. DialTimeout/ReadTimeout/WriteTimeout: 精确控制各阶段超时
func InitRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(buildRedisOptions(cfg))
}
// buildRedisOptions 构建 Redis 连接选项
// 从配置文件读取连接池和超时参数,支持生产环境调优
func buildRedisOptions(cfg *config.Config) *redis.Options {
return &redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
DialTimeout: time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second, // 建连超时
ReadTimeout: time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second, // 读取超时
WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // 写入超时
PoolSize: cfg.Redis.PoolSize, // 连接池大小
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
}
}

View File

@@ -0,0 +1,35 @@
package repository
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestBuildRedisOptions(t *testing.T) {
cfg := &config.Config{
Redis: config.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "secret",
DB: 2,
DialTimeoutSeconds: 5,
ReadTimeoutSeconds: 3,
WriteTimeoutSeconds: 4,
PoolSize: 100,
MinIdleConns: 10,
},
}
opts := buildRedisOptions(cfg)
require.Equal(t, "localhost:6379", opts.Addr)
require.Equal(t, "secret", opts.Password)
require.Equal(t, 2, opts.DB)
require.Equal(t, 5*time.Second, opts.DialTimeout)
require.Equal(t, 3*time.Second, opts.ReadTimeout)
require.Equal(t, 4*time.Second, opts.WriteTimeout)
require.Equal(t, 100, opts.PoolSize)
require.Equal(t, 10, opts.MinIdleConns)
}

View File

@@ -0,0 +1,385 @@
package repository
import (
"context"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// UserAttributeDefinitionRepository implementation
type userAttributeDefinitionRepository struct {
client *dbent.Client
}
// NewUserAttributeDefinitionRepository creates a new repository instance
func NewUserAttributeDefinitionRepository(client *dbent.Client) service.UserAttributeDefinitionRepository {
return &userAttributeDefinitionRepository{client: client}
}
func (r *userAttributeDefinitionRepository) Create(ctx context.Context, def *service.UserAttributeDefinition) error {
client := clientFromContext(ctx, r.client)
created, err := client.UserAttributeDefinition.Create().
SetKey(def.Key).
SetName(def.Name).
SetDescription(def.Description).
SetType(string(def.Type)).
SetOptions(toEntOptions(def.Options)).
SetRequired(def.Required).
SetValidation(toEntValidation(def.Validation)).
SetPlaceholder(def.Placeholder).
SetEnabled(def.Enabled).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrAttributeKeyExists)
}
def.ID = created.ID
def.DisplayOrder = created.DisplayOrder
def.CreatedAt = created.CreatedAt
def.UpdatedAt = created.UpdatedAt
return nil
}
func (r *userAttributeDefinitionRepository) GetByID(ctx context.Context, id int64) (*service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
e, err := client.UserAttributeDefinition.Query().
Where(userattributedefinition.IDEQ(id)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
return defEntityToService(e), nil
}
func (r *userAttributeDefinitionRepository) GetByKey(ctx context.Context, key string) (*service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
e, err := client.UserAttributeDefinition.Query().
Where(userattributedefinition.KeyEQ(key)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
return defEntityToService(e), nil
}
func (r *userAttributeDefinitionRepository) Update(ctx context.Context, def *service.UserAttributeDefinition) error {
client := clientFromContext(ctx, r.client)
updated, err := client.UserAttributeDefinition.UpdateOneID(def.ID).
SetName(def.Name).
SetDescription(def.Description).
SetType(string(def.Type)).
SetOptions(toEntOptions(def.Options)).
SetRequired(def.Required).
SetValidation(toEntValidation(def.Validation)).
SetPlaceholder(def.Placeholder).
SetDisplayOrder(def.DisplayOrder).
SetEnabled(def.Enabled).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, service.ErrAttributeKeyExists)
}
def.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *userAttributeDefinitionRepository) Delete(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeDefinition.Delete().
Where(userattributedefinition.IDEQ(id)).
Exec(ctx)
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
func (r *userAttributeDefinitionRepository) List(ctx context.Context, enabledOnly bool) ([]service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
q := client.UserAttributeDefinition.Query()
if enabledOnly {
q = q.Where(userattributedefinition.EnabledEQ(true))
}
entities, err := q.Order(dbent.Asc(userattributedefinition.FieldDisplayOrder)).All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeDefinition, 0, len(entities))
for _, e := range entities {
result = append(result, *defEntityToService(e))
}
return result, nil
}
func (r *userAttributeDefinitionRepository) UpdateDisplayOrders(ctx context.Context, orders map[int64]int) error {
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for id, order := range orders {
if _, err := tx.UserAttributeDefinition.UpdateOneID(id).
SetDisplayOrder(order).
Save(ctx); err != nil {
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
}
return tx.Commit()
}
func (r *userAttributeDefinitionRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
client := clientFromContext(ctx, r.client)
return client.UserAttributeDefinition.Query().
Where(userattributedefinition.KeyEQ(key)).
Exist(ctx)
}
// UserAttributeValueRepository implementation
type userAttributeValueRepository struct {
client *dbent.Client
}
// NewUserAttributeValueRepository creates a new repository instance
func NewUserAttributeValueRepository(client *dbent.Client) service.UserAttributeValueRepository {
return &userAttributeValueRepository{client: client}
}
func (r *userAttributeValueRepository) GetByUserID(ctx context.Context, userID int64) ([]service.UserAttributeValue, error) {
client := clientFromContext(ctx, r.client)
entities, err := client.UserAttributeValue.Query().
Where(userattributevalue.UserIDEQ(userID)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeValue, 0, len(entities))
for _, e := range entities {
result = append(result, service.UserAttributeValue{
ID: e.ID,
UserID: e.UserID,
AttributeID: e.AttributeID,
Value: e.Value,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
})
}
return result, nil
}
func (r *userAttributeValueRepository) GetByUserIDs(ctx context.Context, userIDs []int64) ([]service.UserAttributeValue, error) {
if len(userIDs) == 0 {
return []service.UserAttributeValue{}, nil
}
client := clientFromContext(ctx, r.client)
entities, err := client.UserAttributeValue.Query().
Where(userattributevalue.UserIDIn(userIDs...)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeValue, 0, len(entities))
for _, e := range entities {
result = append(result, service.UserAttributeValue{
ID: e.ID,
UserID: e.UserID,
AttributeID: e.AttributeID,
Value: e.Value,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
})
}
return result, nil
}
func (r *userAttributeValueRepository) UpsertBatch(ctx context.Context, userID int64, inputs []service.UpdateUserAttributeInput) error {
if len(inputs) == 0 {
return nil
}
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for _, input := range inputs {
// Use upsert (ON CONFLICT DO UPDATE)
err := tx.UserAttributeValue.Create().
SetUserID(userID).
SetAttributeID(input.AttributeID).
SetValue(input.Value).
OnConflictColumns(userattributevalue.FieldUserID, userattributevalue.FieldAttributeID).
UpdateValue().
UpdateUpdatedAt().
Exec(ctx)
if err != nil {
return err
}
}
return tx.Commit()
}
func (r *userAttributeValueRepository) DeleteByAttributeID(ctx context.Context, attributeID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeValue.Delete().
Where(userattributevalue.AttributeIDEQ(attributeID)).
Exec(ctx)
return err
}
func (r *userAttributeValueRepository) DeleteByUserID(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeValue.Delete().
Where(userattributevalue.UserIDEQ(userID)).
Exec(ctx)
return err
}
// Helper functions for entity to service conversion
func defEntityToService(e *dbent.UserAttributeDefinition) *service.UserAttributeDefinition {
if e == nil {
return nil
}
return &service.UserAttributeDefinition{
ID: e.ID,
Key: e.Key,
Name: e.Name,
Description: e.Description,
Type: service.UserAttributeType(e.Type),
Options: toServiceOptions(e.Options),
Required: e.Required,
Validation: toServiceValidation(e.Validation),
Placeholder: e.Placeholder,
DisplayOrder: e.DisplayOrder,
Enabled: e.Enabled,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
}
}
// Type conversion helpers (map types <-> service types)
func toEntOptions(opts []service.UserAttributeOption) []map[string]any {
if opts == nil {
return []map[string]any{}
}
result := make([]map[string]any, len(opts))
for i, o := range opts {
result[i] = map[string]any{"value": o.Value, "label": o.Label}
}
return result
}
func toServiceOptions(opts []map[string]any) []service.UserAttributeOption {
if opts == nil {
return []service.UserAttributeOption{}
}
result := make([]service.UserAttributeOption, len(opts))
for i, o := range opts {
result[i] = service.UserAttributeOption{
Value: getString(o, "value"),
Label: getString(o, "label"),
}
}
return result
}
func toEntValidation(v service.UserAttributeValidation) map[string]any {
result := map[string]any{}
if v.MinLength != nil {
result["min_length"] = *v.MinLength
}
if v.MaxLength != nil {
result["max_length"] = *v.MaxLength
}
if v.Min != nil {
result["min"] = *v.Min
}
if v.Max != nil {
result["max"] = *v.Max
}
if v.Pattern != nil {
result["pattern"] = *v.Pattern
}
if v.Message != nil {
result["message"] = *v.Message
}
return result
}
func toServiceValidation(v map[string]any) service.UserAttributeValidation {
result := service.UserAttributeValidation{}
if val := getInt(v, "min_length"); val != nil {
result.MinLength = val
}
if val := getInt(v, "max_length"); val != nil {
result.MaxLength = val
}
if val := getInt(v, "min"); val != nil {
result.Min = val
}
if val := getInt(v, "max"); val != nil {
result.Max = val
}
if val := getStringPtr(v, "pattern"); val != nil {
result.Pattern = val
}
if val := getStringPtr(v, "message"); val != nil {
result.Message = val
}
return result
}
// Helper functions for type conversion
func getString(m map[string]any, key string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func getStringPtr(m map[string]any, key string) *string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return &s
}
}
return nil
}
func getInt(m map[string]any, key string) *int {
if v, ok := m[key]; ok {
switch n := v.(type) {
case int:
return &n
case int64:
i := int(n)
return &i
case float64:
i := int(n)
return &i
}
}
return nil
}

View File

@@ -9,6 +9,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -50,7 +51,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
created, err := txClient.User.Create().
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetWechat(userIn.Wechat).
SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role).
@@ -133,7 +133,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
updated, err := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetWechat(userIn.Wechat).
SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role).
@@ -171,28 +170,38 @@ func (r *userRepository) Delete(ctx context.Context, id int64) error {
}
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
return r.ListWithFilters(ctx, params, service.UserListFilters{})
}
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]service.User, *pagination.PaginationResult, error) {
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
q := r.client.User.Query()
if status != "" {
q = q.Where(dbuser.StatusEQ(status))
if filters.Status != "" {
q = q.Where(dbuser.StatusEQ(filters.Status))
}
if role != "" {
q = q.Where(dbuser.RoleEQ(role))
if filters.Role != "" {
q = q.Where(dbuser.RoleEQ(filters.Role))
}
if search != "" {
if filters.Search != "" {
q = q.Where(
dbuser.Or(
dbuser.EmailContainsFold(search),
dbuser.UsernameContainsFold(search),
dbuser.WechatContainsFold(search),
dbuser.EmailContainsFold(filters.Search),
dbuser.UsernameContainsFold(filters.Search),
),
)
}
// If attribute filters are specified, we need to filter by user IDs first
var allowedUserIDs []int64
if len(filters.Attributes) > 0 {
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
if len(allowedUserIDs) == 0 {
// No users match the attribute filters
return []service.User{}, paginationResultFromTotal(0, params), nil
}
q = q.Where(dbuser.IDIn(allowedUserIDs...))
}
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
@@ -252,6 +261,59 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return outUsers, paginationResultFromTotal(int64(total), params), nil
}
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
if len(attrs) == 0 {
return nil
}
// For each attribute filter, get the set of matching user IDs
// Then intersect all sets to get users matching ALL filters
var resultSet map[int64]struct{}
first := true
for attrID, value := range attrs {
// Query user_attribute_values for this attribute
values, err := r.client.UserAttributeValue.Query().
Where(
userattributevalue.AttributeIDEQ(attrID),
userattributevalue.ValueContainsFold(value),
).
All(ctx)
if err != nil {
continue
}
currentSet := make(map[int64]struct{}, len(values))
for _, v := range values {
currentSet[v.UserID] = struct{}{}
}
if first {
resultSet = currentSet
first = false
} else {
// Intersect with previous results
for userID := range resultSet {
if _, ok := currentSet[userID]; !ok {
delete(resultSet, userID)
}
}
}
// Early exit if no users match
if len(resultSet) == 0 {
return nil
}
}
result := make([]int64, 0, len(resultSet))
for userID := range resultSet {
result = append(result, userID)
}
return result
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)

View File

@@ -166,7 +166,7 @@ func (s *UserRepoSuite) TestListWithFilters_Status() {
s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive})
s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, "", "")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(service.StatusActive, users[0].Status)
@@ -176,7 +176,7 @@ func (s *UserRepoSuite) TestListWithFilters_Role() {
s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser})
s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.RoleAdmin, "")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Role: service.RoleAdmin})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(service.RoleAdmin, users[0].Role)
@@ -186,7 +186,7 @@ func (s *UserRepoSuite) TestListWithFilters_Search() {
s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"})
s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "alice"})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Contains(users[0].Email, "alice")
@@ -196,22 +196,12 @@ func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"})
s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "john"})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("JohnDoe", users[0].Username)
}
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
s.mustCreateUser(&service.User{Email: "w1@test.com", Wechat: "wx_hello"})
s.mustCreateUser(&service.User{Email: "w2@test.com", Wechat: "wx_world"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("wx_hello", users[0].Wechat)
}
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive})
groupActive := s.mustCreateGroup("g-sub-active")
@@ -226,7 +216,7 @@ func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
c.SetExpiresAt(time.Now().Add(-1 * time.Hour))
})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "sub@")
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "sub@"})
s.Require().NoError(err, "ListWithFilters")
s.Require().Len(users, 1, "expected 1 user")
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
@@ -238,7 +228,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
s.mustCreateUser(&service.User{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
@@ -246,7 +235,6 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
target := s.mustCreateUser(&service.User{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 1,
@@ -257,7 +245,7 @@ func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
Status: service.StatusDisabled,
})
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.StatusActive, service.RoleAdmin, "b@")
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
@@ -448,7 +436,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := s.mustCreateUser(&service.User{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
@@ -456,7 +443,6 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user2 := s.mustCreateUser(&service.User{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 1,
@@ -501,7 +487,7 @@ func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
s.Require().Equal(user1.Concurrency+3, got5.Concurrency)
params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.StatusActive, service.RoleAdmin, "b@")
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")

View File

@@ -291,13 +291,11 @@ func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
}
// IncrementUsage 原子性地累加用量并校验限额
// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。
// 当更新失败时,会执行额外查询确定具体超出的限额类型
// IncrementUsage 原子性地累加订阅用量。
// 限额检查已在请求前由 BillingCacheService.CheckBillingEligibility 完成,
// 此处仅负责记录实际消费,确保消费数据的完整性
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
// 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加
// NULL 限额表示无限制
const atomicUpdateSQL = `
const updateSQL = `
UPDATE user_subscriptions us
SET
daily_usage_usd = us.daily_usage_usd + $1,
@@ -309,13 +307,10 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
AND us.deleted_at IS NULL
AND us.group_id = g.id
AND g.deleted_at IS NULL
AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd)
AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd)
AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd)
`
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id)
result, err := client.ExecContext(ctx, updateSQL, costUSD, id)
if err != nil {
return err
}
@@ -326,64 +321,11 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
}
if affected > 0 {
return nil // 更新成功
return nil
}
// affected == 0可能是订阅不存在、分组已删除、或限额超出
// 执行额外查询确定具体原因
return r.checkIncrementFailureReason(ctx, id, costUSD)
}
// checkIncrementFailureReason 查询更新失败的具体原因
func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error {
const checkSQL = `
SELECT
CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted'
WHEN g.id IS NULL THEN 'subscription_not_found'
WHEN g.deleted_at IS NOT NULL THEN 'group_deleted'
WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded'
WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded'
WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded'
ELSE 'unknown'
END AS reason
FROM user_subscriptions us
LEFT JOIN groups g ON us.group_id = g.id
WHERE us.id = $2
`
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx, checkSQL, costUSD, id)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return service.ErrSubscriptionNotFound
}
var reason string
if err := rows.Scan(&reason); err != nil {
return err
}
if err := rows.Err(); err != nil {
return err
}
switch reason {
case "subscription_not_found", "subscription_deleted", "group_deleted":
return service.ErrSubscriptionNotFound
case "daily_exceeded":
return service.ErrDailyLimitExceeded
case "weekly_exceeded":
return service.ErrWeeklyLimitExceeded
case "monthly_exceeded":
return service.ErrMonthlyLimitExceeded
default:
// unknown 情况理论上不应发生,但作为兜底返回
return service.ErrSubscriptionNotFound
}
// affected == 0订阅不存在或已删除
return service.ErrSubscriptionNotFound
}
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {

View File

@@ -633,112 +633,7 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
}
// --- 限额检查与软删除过滤测试 ---
func (s *UserSubscriptionRepoSuite) mustCreateGroupWithLimits(name string, daily, weekly, monthly *float64) *service.Group {
s.T().Helper()
create := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
SetSubscriptionType(service.SubscriptionTypeSubscription)
if daily != nil {
create.SetDailyLimitUsd(*daily)
}
if weekly != nil {
create.SetWeeklyLimitUsd(*weekly)
}
if monthly != nil {
create.SetMonthlyLimitUsd(*monthly)
}
g, err := create.Save(s.ctx)
s.Require().NoError(err, "create group with limits")
return groupEntityToService(g)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_DailyLimitExceeded() {
user := s.mustCreateUser("dailylimit@test.com", service.RoleUser)
dailyLimit := 10.0
group := s.mustCreateGroupWithLimits("g-dailylimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 先增加 9.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 9.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 2.0,会超过 10.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 2.0)
s.Require().Error(err, "should fail when daily limit exceeded")
s.Require().ErrorIs(err, service.ErrDailyLimitExceeded)
// 验证用量没有变化
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(9.0, got.DailyUsageUSD, 1e-6, "usage should not change after failed increment")
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_WeeklyLimitExceeded() {
user := s.mustCreateUser("weeklylimit@test.com", service.RoleUser)
weeklyLimit := 50.0
group := s.mustCreateGroupWithLimits("g-weeklylimit", nil, &weeklyLimit, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 增加 45.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 45.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 10.0,会超过 50.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
s.Require().Error(err, "should fail when weekly limit exceeded")
s.Require().ErrorIs(err, service.ErrWeeklyLimitExceeded)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_MonthlyLimitExceeded() {
user := s.mustCreateUser("monthlylimit@test.com", service.RoleUser)
monthlyLimit := 100.0
group := s.mustCreateGroupWithLimits("g-monthlylimit", nil, nil, &monthlyLimit)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 增加 90.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 90.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 20.0,会超过 100.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 20.0)
s.Require().Error(err, "should fail when monthly limit exceeded")
s.Require().ErrorIs(err, service.ErrMonthlyLimitExceeded)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NoLimits() {
user := s.mustCreateUser("nolimits@test.com", service.RoleUser)
group := s.mustCreateGroupWithLimits("g-nolimits", nil, nil, nil) // 无限额
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 应该可以增加任意金额
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1000000.0)
s.Require().NoError(err, "should succeed without limits")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(1000000.0, got.DailyUsageUSD, 1e-6)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_AtExactLimit() {
user := s.mustCreateUser("exactlimit@test.com", service.RoleUser)
dailyLimit := 10.0
group := s.mustCreateGroupWithLimits("g-exactlimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 正好达到限额应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
s.Require().NoError(err, "should succeed at exact limit")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(10.0, got.DailyUsageUSD, 1e-6)
}
// --- 软删除过滤测试 ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() {
user := s.mustCreateUser("softdeleted@test.com", service.RoleUser)
@@ -779,7 +674,7 @@ func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() {
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
user := s.mustCreateUser("concurrent@test.com", service.RoleUser)
group := s.mustCreateGroupWithLimits("g-concurrent", nil, nil, nil) // 无限额
group := s.mustCreateGroup("g-concurrent")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
const numGoroutines = 10
@@ -808,34 +703,6 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated")
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_ConcurrentWithLimit() {
user := s.mustCreateUser("concurrentlimit@test.com", service.RoleUser)
dailyLimit := 5.0
group := s.mustCreateGroupWithLimits("g-concurrentlimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 注意:事务内的操作是串行的,所以这里改为顺序执行以验证限额逻辑
// 尝试增加 10 次,每次 1.0,但限额只有 5.0
const numAttempts = 10
const incrementPerAttempt = 1.0
successCount := 0
for i := 0; i < numAttempts; i++ {
err := s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerAttempt)
if err == nil {
successCount++
}
}
// 验证:应该有 5 次成功不超过限额5 次失败(超出限额)
s.Require().Equal(5, successCount, "exactly 5 increments should succeed (limit=5, increment=1)")
// 验证最终用量等于限额
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(dailyLimit, got.DailyUsageUSD, 1e-6, "daily usage should equal limit")
}
func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() {
baseClient := testEntClient(s.T())
tx, err := baseClient.Tx(context.Background())

View File

@@ -1,6 +1,11 @@
package repository
import (
"database/sql"
"errors"
entsql "entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire"
@@ -10,7 +15,14 @@ import (
// ProvideConcurrencyCache 创建并发控制缓存,从配置读取 TTL 参数
// 性能优化TTL 可配置,支持长时间运行的 LLM 请求场景
func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.ConcurrencyCache {
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes)
waitTTLSeconds := int(cfg.Gateway.Scheduling.StickySessionWaitTimeout.Seconds())
if cfg.Gateway.Scheduling.FallbackWaitTimeout > cfg.Gateway.Scheduling.StickySessionWaitTimeout {
waitTTLSeconds = int(cfg.Gateway.Scheduling.FallbackWaitTimeout.Seconds())
}
if waitTTLSeconds <= 0 {
waitTTLSeconds = cfg.Gateway.ConcurrencySlotTTLMinutes * 60
}
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
}
// ProviderSet is the Wire provider set for all repositories
@@ -24,6 +36,8 @@ var ProviderSet = wire.NewSet(
NewUsageLogRepository,
NewSettingRepository,
NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository,
NewUserAttributeValueRepository,
// Cache implementations
NewGatewayCache,
@@ -47,4 +61,58 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthClient,
NewGeminiOAuthClient,
NewGeminiCliCodeAssistClient,
ProvideEnt,
ProvideSQLDB,
ProvideRedis,
)
// ProvideEnt 为依赖注入提供 Ent 客户端。
//
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖config.Config
// 提供:*ent.Client
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
client, _, err := InitEnt(cfg)
return client, err
}
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL如复杂的批量更新、聚合查询
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
if client == nil {
return nil, errors.New("nil ent client")
}
// 从 Ent 客户端获取底层驱动
drv, ok := client.Driver().(*entsql.Driver)
if !ok {
return nil, errors.New("ent driver does not expose *sql.DB")
}
// 返回驱动持有的 sql.DB 实例
return drv.DB(), nil
}
// ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖config.Config
// 提供:*redis.Client
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}