178 lines
5.0 KiB
Go
178 lines
5.0 KiB
Go
package repository
|
||
|
||
import (
	"context"
	"database/sql"
	"errors"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/lib/pq"
)
|
||
|
||
type userGroupRateRepository struct {
|
||
sql sqlExecutor
|
||
}
|
||
|
||
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
|
||
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
|
||
return &userGroupRateRepository{sql: sqlDB}
|
||
}
|
||
|
||
// GetByUserID 获取用户的所有专属分组倍率
|
||
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
|
||
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
|
||
rows, err := r.sql.QueryContext(ctx, query, userID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() { _ = rows.Close() }()
|
||
|
||
result := make(map[int64]float64)
|
||
for rows.Next() {
|
||
var groupID int64
|
||
var rate float64
|
||
if err := rows.Scan(&groupID, &rate); err != nil {
|
||
return nil, err
|
||
}
|
||
result[groupID] = rate
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// GetByUserIDs 批量获取多个用户的专属分组倍率。
|
||
// 返回结构:map[userID]map[groupID]rate
|
||
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
|
||
result := make(map[int64]map[int64]float64, len(userIDs))
|
||
if len(userIDs) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
uniqueIDs := make([]int64, 0, len(userIDs))
|
||
seen := make(map[int64]struct{}, len(userIDs))
|
||
for _, userID := range userIDs {
|
||
if userID <= 0 {
|
||
continue
|
||
}
|
||
if _, exists := seen[userID]; exists {
|
||
continue
|
||
}
|
||
seen[userID] = struct{}{}
|
||
uniqueIDs = append(uniqueIDs, userID)
|
||
result[userID] = make(map[int64]float64)
|
||
}
|
||
if len(uniqueIDs) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
rows, err := r.sql.QueryContext(ctx, `
|
||
SELECT user_id, group_id, rate_multiplier
|
||
FROM user_group_rate_multipliers
|
||
WHERE user_id = ANY($1)
|
||
`, pq.Array(uniqueIDs))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer func() { _ = rows.Close() }()
|
||
|
||
for rows.Next() {
|
||
var userID int64
|
||
var groupID int64
|
||
var rate float64
|
||
if err := rows.Scan(&userID, &groupID, &rate); err != nil {
|
||
return nil, err
|
||
}
|
||
if _, ok := result[userID]; !ok {
|
||
result[userID] = make(map[int64]float64)
|
||
}
|
||
result[userID][groupID] = rate
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, err
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||
var rate float64
|
||
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
|
||
if err == sql.ErrNoRows {
|
||
return nil, nil
|
||
}
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &rate, nil
|
||
}
|
||
|
||
// SyncUserGroupRates 同步用户的分组专属倍率
|
||
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
|
||
if len(rates) == 0 {
|
||
// 如果传入空 map,删除该用户的所有专属倍率
|
||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||
return err
|
||
}
|
||
|
||
// 分离需要删除和需要 upsert 的记录
|
||
var toDelete []int64
|
||
upsertGroupIDs := make([]int64, 0, len(rates))
|
||
upsertRates := make([]float64, 0, len(rates))
|
||
for groupID, rate := range rates {
|
||
if rate == nil {
|
||
toDelete = append(toDelete, groupID)
|
||
} else {
|
||
upsertGroupIDs = append(upsertGroupIDs, groupID)
|
||
upsertRates = append(upsertRates, *rate)
|
||
}
|
||
}
|
||
|
||
// 删除指定的记录
|
||
if len(toDelete) > 0 {
|
||
if _, err := r.sql.ExecContext(ctx,
|
||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
|
||
userID, pq.Array(toDelete)); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// Upsert 记录
|
||
now := time.Now()
|
||
if len(upsertGroupIDs) > 0 {
|
||
_, err := r.sql.ExecContext(ctx, `
|
||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
||
SELECT
|
||
$1::bigint,
|
||
data.group_id,
|
||
data.rate_multiplier,
|
||
$2::timestamptz,
|
||
$2::timestamptz
|
||
FROM unnest($3::bigint[], $4::double precision[]) AS data(group_id, rate_multiplier)
|
||
ON CONFLICT (user_id, group_id)
|
||
DO UPDATE SET
|
||
rate_multiplier = EXCLUDED.rate_multiplier,
|
||
updated_at = EXCLUDED.updated_at
|
||
`, userID, now, pq.Array(upsertGroupIDs), pq.Array(upsertRates))
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
||
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
||
return err
|
||
}
|
||
|
||
// DeleteByUserID 删除指定用户的所有专属倍率
|
||
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||
return err
|
||
}
|