Files
sub2api/backend/internal/repository/affiliate_repo.go
shaw 4e1bb2b445 feat(affiliate): add feature toggle and per-user custom invite settings
- 在系统设置「功能开关」中新增邀请返利总开关,默认关闭;
  关闭态:菜单隐藏、注册忽略 aff、新充值不返利,但已有 quota 仍可转余额
- 支持管理员为指定用户设置专属邀请码(覆盖随机码,全局唯一)
- 支持管理员为指定用户设置专属返利比例(覆盖全局比例,可单条/批量调整)
- 在系统设置邀请返利卡片内嵌入专属用户管理表格(搜索/编辑/批量/删除),
  删除采用项目通用 ConfirmDialog,会同时清除专属比例并把邀请码重置为系统随机码
- /affiliate 用户页新增「我的返利比例」卡片与动态使用说明,让用户直观看到
  分享后能拿到多少(同源 resolveRebateRatePercent 计算,与实际充值一致)
- 新增数据库迁移 132 添加 aff_rebate_rate_percent 与 aff_code_custom 列
- 新增 admin 路由组 /api/v1/admin/affiliates/users/* 共 5 个端点
- AffiliateService 改为只依赖 *SettingService,去除冗余的 SettingRepository
- 邀请码格式校验放宽到 [A-Z0-9_-]{4,32},兼容旧 12 位系统码与新自定义码
- 补充单元测试与集成测试覆盖新方法、冲突路径与边界值
2026-04-25 20:22:07 +08:00

665 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"crypto/rand"
"database/sql"
"errors"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
// Tunables for system-generated affiliate codes.
const (
// affiliateCodeLength is the number of characters generateAffiliateCode
// produces for a random code.
affiliateCodeLength = 12
// affiliateCodeMaxAttempts bounds how many freshly generated codes are
// tried when writing a code to user_affiliates.
affiliateCodeMaxAttempts = 12
)
// affiliateCodeCharset is the alphabet random affiliate codes are drawn from:
// upper-case letters without I and O plus the digits 2-9 (32 symbols).
// Because 32 divides 256 evenly, mapping a random byte with a modulo in
// generateAffiliateCode selects every symbol with equal probability.
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
// affiliateQueryExecer is the raw-SQL surface the helpers in this file need.
// *dbent.Client values are passed wherever this interface is expected, so the
// same helpers work with whichever client the caller hands in.
type affiliateQueryExecer interface {
// QueryContext runs a statement that returns rows.
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
// ExecContext runs a statement that does not return rows.
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// affiliateRepository is the ent-backed implementation that
// NewAffiliateRepository returns as a service.AffiliateRepository.
type affiliateRepository struct {
// client is the default ent client; it is used to open transactions and as
// the fallback argument to clientFromContext.
client *dbent.Client
}
// NewAffiliateRepository builds the ent-backed service.AffiliateRepository.
// The *sql.DB parameter is accepted but not used by this implementation.
func NewAffiliateRepository(client *dbent.Client, _ *sql.DB) service.AffiliateRepository {
	repo := new(affiliateRepository)
	repo.client = client
	return repo
}
// EnsureUserAffiliate returns the affiliate profile of userID, creating it
// when it does not exist yet. A non-positive userID yields
// service.ErrUserNotFound without touching the database.
func (r *affiliateRepository) EnsureUserAffiliate(ctx context.Context, userID int64) (*service.AffiliateSummary, error) {
	if userID > 0 {
		// NOTE(review): clientFromContext presumably prefers a transaction-bound
		// client carried by ctx and falls back to r.client — confirm.
		return ensureUserAffiliateWithClient(ctx, clientFromContext(ctx, r.client), userID)
	}
	return nil, service.ErrUserNotFound
}
// GetAffiliateByCode looks up the affiliate profile that owns code. The code
// is trimmed and upper-cased by queryAffiliateByCode before matching; a miss
// is reported as service.ErrAffiliateProfileNotFound.
func (r *affiliateRepository) GetAffiliateByCode(ctx context.Context, code string) (*service.AffiliateSummary, error) {
	return queryAffiliateByCode(ctx, clientFromContext(ctx, r.client), code)
}
// BindInviter records inviterID as the inviter of userID and bumps the
// inviter's aff_count, all inside one transaction.
//
// Both affiliate profiles are created first if they are missing. The binding
// only happens when the user has no inviter yet (inviter_id IS NULL); in that
// case the method returns true. When the user already has an inviter nothing
// is changed and the method returns false with a nil error.
func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) {
	var bound bool
	err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
		if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
			return err
		}
		if _, err := ensureUserAffiliateWithClient(txCtx, txClient, inviterID); err != nil {
			return err
		}
		res, err := txClient.ExecContext(txCtx,
			"UPDATE user_affiliates SET inviter_id = $1, updated_at = NOW() WHERE user_id = $2 AND inviter_id IS NULL",
			inviterID, userID,
		)
		if err != nil {
			return fmt.Errorf("bind inviter: %w", err)
		}
		// Do not treat an unknown row count as "0 rows": that would commit the
		// inviter_id change without the matching aff_count increment.
		affected, err := res.RowsAffected()
		if err != nil {
			return fmt.Errorf("bind inviter: rows affected: %w", err)
		}
		if affected == 0 {
			// The user already has an inviter; leave everything untouched.
			return nil
		}
		if _, err = txClient.ExecContext(txCtx,
			"UPDATE user_affiliates SET aff_count = aff_count + 1, updated_at = NOW() WHERE user_id = $1",
			inviterID,
		); err != nil {
			return fmt.Errorf("increment inviter aff_count: %w", err)
		}
		bound = true
		return nil
	})
	if err != nil {
		return false, err
	}
	return bound, nil
}
// AccrueQuota credits amount to the inviter's aff_quota and aff_history_quota
// and writes an 'accrue' ledger row that names inviteeUserID as the source.
// Both writes happen in one transaction.
//
// It returns false with a nil error when amount is not positive or when the
// inviter has no user_affiliates row; it returns true once the quota and the
// ledger entry have both been written.
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) {
	if amount <= 0 {
		return false, nil
	}
	var applied bool
	err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
		res, err := txClient.ExecContext(txCtx,
			"UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2",
			amount, inviterID,
		)
		if err != nil {
			return fmt.Errorf("accrue affiliate quota: %w", err)
		}
		// An unknown row count must abort the transaction; treating it as zero
		// would commit the quota increase without its ledger entry.
		affected, err := res.RowsAffected()
		if err != nil {
			return fmt.Errorf("accrue affiliate quota: rows affected: %w", err)
		}
		if affected == 0 {
			// No affiliate profile for the inviter: nothing was credited.
			return nil
		}
		if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
			return fmt.Errorf("insert affiliate accrue ledger: %w", err)
		}
		applied = true
		return nil
	})
	if err != nil {
		return false, err
	}
	return applied, nil
}
// TransferQuotaToBalance moves the user's entire pending aff_quota into the
// account balance within one transaction.
//
// It returns the amount moved and the balance read back afterwards. When the
// user has no positive quota it fails with service.ErrAffiliateQuotaEmpty;
// when the users row cannot be updated or read it fails with
// service.ErrUserNotFound. On any error both numeric results are zero.
func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
	var (
		transferred float64
		newBalance  float64
	)
	transfer := func(txCtx context.Context, txClient *dbent.Client) error {
		_, ensureErr := ensureUserAffiliateWithClient(txCtx, txClient, userID)
		if ensureErr != nil {
			return ensureErr
		}

		// Lock the row, zero the quota and hand back the claimed amount in a
		// single statement.
		claimRows, claimErr := txClient.QueryContext(txCtx, `
WITH claimed AS (
SELECT aff_quota::double precision AS amount
FROM user_affiliates
WHERE user_id = $1
AND aff_quota > 0
FOR UPDATE
),
cleared AS (
UPDATE user_affiliates ua
SET aff_quota = 0,
updated_at = NOW()
FROM claimed c
WHERE ua.user_id = $1
RETURNING c.amount
)
SELECT amount
FROM cleared`, userID)
		if claimErr != nil {
			return fmt.Errorf("claim affiliate quota: %w", claimErr)
		}

		hasClaim := claimRows.Next()
		if !hasClaim {
			_ = claimRows.Close()
			iterErr := claimRows.Err()
			if iterErr == nil {
				return service.ErrAffiliateQuotaEmpty
			}
			return iterErr
		}
		if scanErr := claimRows.Scan(&transferred); scanErr != nil {
			_ = claimRows.Close()
			return scanErr
		}
		// The result set is closed explicitly (not deferred) before the next
		// statement is issued on the same transaction.
		if closeErr := claimRows.Close(); closeErr != nil {
			return closeErr
		}
		if transferred <= 0 {
			return service.ErrAffiliateQuotaEmpty
		}

		credit := txClient.User.Update().Where(user.IDEQ(userID))
		updatedUsers, creditErr := credit.
			AddBalance(transferred).
			AddTotalRecharged(transferred).
			Save(txCtx)
		if creditErr != nil {
			return fmt.Errorf("credit user balance by affiliate quota: %w", creditErr)
		}
		if updatedUsers == 0 {
			return service.ErrUserNotFound
		}

		balance, balanceErr := queryUserBalance(txCtx, txClient, userID)
		if balanceErr != nil {
			return balanceErr
		}
		newBalance = balance

		_, ledgerErr := txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred)
		if ledgerErr != nil {
			return fmt.Errorf("insert affiliate transfer ledger: %w", ledgerErr)
		}
		return nil
	}

	if err := r.withTx(ctx, transfer); err != nil {
		return 0, 0, err
	}
	return transferred, newBalance, nil
}
// ListInvitees returns the users whose inviter is inviterID, newest affiliate
// profile first. A non-positive limit falls back to 100. The result is always
// a non-nil slice, even when there are no invitees.
func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]service.AffiliateInvitee, error) {
	if limit < 1 {
		limit = 100
	}

	rows, err := clientFromContext(ctx, r.client).QueryContext(ctx, `
SELECT ua.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
ua.created_at
FROM user_affiliates ua
LEFT JOIN users u ON u.id = ua.user_id
WHERE ua.inviter_id = $1
ORDER BY ua.created_at DESC
LIMIT $2`, inviterID, limit)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	result := []service.AffiliateInvitee{}
	for rows.Next() {
		var (
			invitee  service.AffiliateInvitee
			joinedAt time.Time
		)
		scanErr := rows.Scan(&invitee.UserID, &invitee.Email, &invitee.Username, &joinedAt)
		if scanErr != nil {
			return nil, scanErr
		}
		// joinedAt is declared per iteration, so every entry gets its own pointer.
		invitee.CreatedAt = &joinedAt
		result = append(result, invitee)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return result, nil
}
// withTx runs fn inside a database transaction.
//
// If ctx already carries an ent transaction, fn joins it and the caller that
// opened it remains responsible for commit/rollback. Otherwise a new
// transaction is started, committed when fn returns nil, and rolled back when
// fn fails.
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
	if ambient := dbent.TxFromContext(ctx); ambient != nil {
		return fn(ctx, ambient.Client())
	}

	ownTx, beginErr := r.client.Tx(ctx)
	if beginErr != nil {
		return fmt.Errorf("begin affiliate transaction: %w", beginErr)
	}
	// The rollback result is deliberately ignored: this deferred call also
	// runs after a successful Commit.
	defer func() { _ = ownTx.Rollback() }()

	if fnErr := fn(dbent.NewTxContext(ctx, ownTx), ownTx.Client()); fnErr != nil {
		return fnErr
	}
	if commitErr := ownTx.Commit(); commitErr != nil {
		return fmt.Errorf("commit affiliate transaction: %w", commitErr)
	}
	return nil
}
// ensureUserAffiliateWithClient returns the affiliate profile of userID,
// inserting one with a freshly generated random code when it is missing.
//
// The INSERT uses ON CONFLICT DO NOTHING without a conflict target, so a clash
// on either user_id (profile created concurrently) or aff_code (random code
// already in use) is absorbed by PostgreSQL instead of being raised as an
// error. This matters because the function is frequently called inside a
// transaction: a raised unique violation would abort that transaction and
// every retry would then fail. A code collision is detected by the profile
// still being absent after the INSERT, and is retried with a new code up to
// affiliateCodeMaxAttempts times.
//
// When all attempts are used up the returned error wraps
// service.ErrAffiliateProfileNotFound.
func ensureUserAffiliateWithClient(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
	summary, err := queryAffiliateByUserID(ctx, client, userID)
	if err == nil {
		return summary, nil
	}
	if !errors.Is(err, service.ErrAffiliateProfileNotFound) {
		return nil, err
	}
	for i := 0; i < affiliateCodeMaxAttempts; i++ {
		code, codeErr := generateAffiliateCode()
		if codeErr != nil {
			return nil, codeErr
		}
		if _, insertErr := client.ExecContext(ctx, `
INSERT INTO user_affiliates (user_id, aff_code, created_at, updated_at)
VALUES ($1, $2, NOW(), NOW())
ON CONFLICT DO NOTHING`, userID, code); insertErr != nil {
			// Not expected with ON CONFLICT DO NOTHING; kept as a defensive
			// retry for callers running outside a transaction.
			if isAffiliateUniqueViolation(insertErr) {
				continue
			}
			return nil, insertErr
		}
		summary, err = queryAffiliateByUserID(ctx, client, userID)
		if err == nil {
			return summary, nil
		}
		if !errors.Is(err, service.ErrAffiliateProfileNotFound) {
			return nil, err
		}
		// Row still missing: the generated code collided with an existing one.
	}
	return nil, fmt.Errorf("ensure user affiliate: exhausted %d code attempts: %w",
		affiliateCodeMaxAttempts, service.ErrAffiliateProfileNotFound)
}
// queryAffiliateByUserID loads the affiliate profile stored for userID.
// It returns service.ErrAffiliateProfileNotFound when no row exists.
func queryAffiliateByUserID(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
	return queryAffiliateSummaryRow(ctx, client, `
SELECT user_id,
aff_code,
aff_code_custom,
aff_rebate_rate_percent,
inviter_id,
aff_count,
aff_quota::double precision,
aff_history_quota::double precision,
created_at,
updated_at
FROM user_affiliates
WHERE user_id = $1`, userID)
}

// queryAffiliateByCode loads the affiliate profile that owns code. The code is
// trimmed and upper-cased before matching. It returns
// service.ErrAffiliateProfileNotFound when no row matches.
func queryAffiliateByCode(ctx context.Context, client affiliateQueryExecer, code string) (*service.AffiliateSummary, error) {
	return queryAffiliateSummaryRow(ctx, client, `
SELECT user_id,
aff_code,
aff_code_custom,
aff_rebate_rate_percent,
inviter_id,
aff_count,
aff_quota::double precision,
aff_history_quota::double precision,
created_at,
updated_at
FROM user_affiliates
WHERE aff_code = $1
LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
}

// queryAffiliateSummaryRow runs query and decodes its first row into a
// service.AffiliateSummary. The query must select, in this order: user_id,
// aff_code, aff_code_custom, aff_rebate_rate_percent, inviter_id, aff_count,
// aff_quota, aff_history_quota, created_at, updated_at.
//
// NULL inviter_id and aff_rebate_rate_percent are mapped to nil pointers. An
// empty result set is reported as service.ErrAffiliateProfileNotFound.
func queryAffiliateSummaryRow(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (*service.AffiliateSummary, error) {
	rows, err := client.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()
	if !rows.Next() {
		// Distinguish an iteration failure from a genuinely empty result.
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, service.ErrAffiliateProfileNotFound
	}
	var out service.AffiliateSummary
	var inviterID sql.NullInt64
	var rebateRate sql.NullFloat64
	if err := rows.Scan(
		&out.UserID,
		&out.AffCode,
		&out.AffCodeCustom,
		&rebateRate,
		&inviterID,
		&out.AffCount,
		&out.AffQuota,
		&out.AffHistoryQuota,
		&out.CreatedAt,
		&out.UpdatedAt,
	); err != nil {
		return nil, err
	}
	if inviterID.Valid {
		out.InviterID = &inviterID.Int64
	}
	if rebateRate.Valid {
		v := rebateRate.Float64
		out.AffRebateRatePercent = &v
	}
	return &out, nil
}
// queryUserBalance reads the current balance of the users row with the given
// id. It returns service.ErrUserNotFound when no such row exists.
func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID int64) (float64, error) {
	const balanceQuery = "SELECT balance::double precision FROM users WHERE id = $1 LIMIT 1"

	rows, err := client.QueryContext(ctx, balanceQuery, userID)
	if err != nil {
		return 0, err
	}
	defer func() { _ = rows.Close() }()

	if rows.Next() {
		var current float64
		if scanErr := rows.Scan(&current); scanErr != nil {
			return 0, scanErr
		}
		return current, nil
	}
	// No row: either iteration failed or the user does not exist.
	if iterErr := rows.Err(); iterErr != nil {
		return 0, iterErr
	}
	return 0, service.ErrUserNotFound
}
// generateAffiliateCode returns a random code of affiliateCodeLength
// characters drawn from affiliateCodeCharset, using crypto/rand as the
// entropy source.
func generateAffiliateCode() (string, error) {
	entropy := make([]byte, affiliateCodeLength)
	if _, err := rand.Read(entropy); err != nil {
		return "", fmt.Errorf("generate affiliate code: %w", err)
	}

	alphabetSize := len(affiliateCodeCharset)
	code := make([]byte, 0, len(entropy))
	for _, b := range entropy {
		// Map each random byte onto one symbol of the alphabet.
		code = append(code, affiliateCodeCharset[int(b)%alphabetSize])
	}
	return string(code), nil
}
// isAffiliateUniqueViolation reports whether err is, or wraps, a *pq.Error
// carrying SQLSTATE 23505 (unique_violation).
func isAffiliateUniqueViolation(err error) bool {
	var driverErr *pq.Error
	if !errors.As(err, &driverErr) {
		return false
	}
	return string(driverErr.Code) == "23505"
}
// UpdateUserAffCode replaces the user's invite code with a custom one and
// marks the profile as using a custom code (aff_code_custom = true).
//
// newCode is trimmed and upper-cased before it is stored. Errors:
//   - service.ErrUserNotFound when userID is not positive or no profile row
//     was updated;
//   - service.ErrAffiliateCodeInvalid when the normalized code is empty;
//   - service.ErrAffiliateCodeTaken when the code violates a unique
//     constraint, i.e. it is already in use.
func (r *affiliateRepository) UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error {
	if userID <= 0 {
		return service.ErrUserNotFound
	}
	code := strings.ToUpper(strings.TrimSpace(newCode))
	if code == "" {
		return service.ErrAffiliateCodeInvalid
	}
	return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
		if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
			return err
		}
		res, err := txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_code = $1,
aff_code_custom = true,
updated_at = NOW()
WHERE user_id = $2`, code, userID)
		if err != nil {
			if isAffiliateUniqueViolation(err) {
				return service.ErrAffiliateCodeTaken
			}
			return fmt.Errorf("update aff_code: %w", err)
		}
		// Surface a failure to read the row count instead of misreporting it
		// as "user not found".
		affected, err := res.RowsAffected()
		if err != nil {
			return fmt.Errorf("update aff_code: rows affected: %w", err)
		}
		if affected == 0 {
			return service.ErrUserNotFound
		}
		return nil
	})
}
// ResetUserAffCode replaces the user's invite code with a newly generated
// random one, clears the aff_code_custom flag, and returns the new code.
//
// Up to affiliateCodeMaxAttempts random codes are tried. Each attempt that
// hits a unique violation is undone with ROLLBACK TO SAVEPOINT: in PostgreSQL
// a failed statement aborts the surrounding transaction, so without the
// savepoint every subsequent retry would fail as well.
//
// It returns service.ErrUserNotFound when userID is not positive or no
// profile row was updated.
func (r *affiliateRepository) ResetUserAffCode(ctx context.Context, userID int64) (string, error) {
	if userID <= 0 {
		return "", service.ErrUserNotFound
	}
	var newCode string
	err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
		if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
			return err
		}
		// One savepoint is enough: ROLLBACK TO SAVEPOINT keeps it usable for
		// the following attempts.
		if _, err := txClient.ExecContext(txCtx, "SAVEPOINT affiliate_reset_aff_code"); err != nil {
			return fmt.Errorf("reset aff_code: create savepoint: %w", err)
		}
		for i := 0; i < affiliateCodeMaxAttempts; i++ {
			candidate, codeErr := generateAffiliateCode()
			if codeErr != nil {
				return codeErr
			}
			res, err := txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_code = $1,
aff_code_custom = false,
updated_at = NOW()
WHERE user_id = $2`, candidate, userID)
			if err != nil {
				if !isAffiliateUniqueViolation(err) {
					return fmt.Errorf("reset aff_code: %w", err)
				}
				// Code collision: recover the transaction, then try another code.
				if _, rbErr := txClient.ExecContext(txCtx, "ROLLBACK TO SAVEPOINT affiliate_reset_aff_code"); rbErr != nil {
					return fmt.Errorf("reset aff_code: rollback to savepoint: %w", rbErr)
				}
				continue
			}
			if _, relErr := txClient.ExecContext(txCtx, "RELEASE SAVEPOINT affiliate_reset_aff_code"); relErr != nil {
				return fmt.Errorf("reset aff_code: release savepoint: %w", relErr)
			}
			affected, err := res.RowsAffected()
			if err != nil {
				return fmt.Errorf("reset aff_code: rows affected: %w", err)
			}
			if affected == 0 {
				return service.ErrUserNotFound
			}
			newCode = candidate
			return nil
		}
		return fmt.Errorf("reset aff_code: exhausted attempts")
	})
	if err != nil {
		return "", err
	}
	return newCode, nil
}
// SetUserRebateRate sets or clears the user's dedicated rebate percentage.
// A nil ratePercent stores SQL NULL, which clears the dedicated value.
// NOTE(review): a NULL rate presumably makes the global rate apply — that
// fallback is not implemented in this file; confirm in the service layer.
//
// The affiliate profile is created first when it is missing. The method
// returns service.ErrUserNotFound when userID is not positive or no profile
// row was updated.
func (r *affiliateRepository) SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
	if userID <= 0 {
		return service.ErrUserNotFound
	}
	return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
		if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
			return err
		}
		// nullableArg lets us use a single UPDATE for both "set value" and
		// "clear" cases — database/sql converts a nil argument to SQL NULL.
		res, err := txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_rebate_rate_percent = $1,
updated_at = NOW()
WHERE user_id = $2`, nullableArg(ratePercent), userID)
		if err != nil {
			return fmt.Errorf("set aff_rebate_rate_percent: %w", err)
		}
		// Surface a failure to read the row count instead of misreporting it
		// as "user not found".
		affected, err := res.RowsAffected()
		if err != nil {
			return fmt.Errorf("set aff_rebate_rate_percent: rows affected: %w", err)
		}
		if affected == 0 {
			return service.ErrUserNotFound
		}
		return nil
	})
}
// BatchSetUserRebateRate sets the dedicated rebate percentage for several
// users at once; a nil ratePercent clears it (stores SQL NULL).
//
// Non-positive and duplicate IDs are dropped up front, so each remaining user
// is ensured exactly once and the UPDATE receives the same ID set that was
// ensured. When no valid ID remains the call is a no-op and no transaction is
// opened. All profiles are created (if missing) and updated in a single
// transaction.
func (r *affiliateRepository) BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
	ids := make([]int64, 0, len(userIDs))
	seen := make(map[int64]struct{}, len(userIDs))
	for _, uid := range userIDs {
		if uid <= 0 {
			continue
		}
		if _, dup := seen[uid]; dup {
			continue
		}
		seen[uid] = struct{}{}
		ids = append(ids, uid)
	}
	if len(ids) == 0 {
		return nil
	}
	return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
		for _, uid := range ids {
			if _, err := ensureUserAffiliateWithClient(txCtx, txClient, uid); err != nil {
				return err
			}
		}
		_, err := txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_rebate_rate_percent = $1,
updated_at = NOW()
WHERE user_id = ANY($2)`, nullableArg(ratePercent), pq.Array(ids))
		if err != nil {
			return fmt.Errorf("batch set aff_rebate_rate_percent: %w", err)
		}
		return nil
	})
}
// nullableArg converts an optional float into a value usable as a SQL
// parameter: a non-nil pointer yields the float it points to, a nil pointer
// yields an untyped nil (bound as SQL NULL).
func nullableArg(v *float64) any {
	if v != nil {
		return *v
	}
	return nil
}
// ListUsersWithCustomSettings lists users that have a dedicated configuration
// (a custom invite code or a dedicated rebate rate), together with the total
// number of matching rows for pagination.
//
// One query shape serves both the "no search" and the "search by
// email/username" cases: an empty search produces the pattern "%%", which
// matches any non-NULL value, while a non-empty search is matched as a
// case-insensitive substring via ILIKE. The search text is escaped first so
// that %, _ and \ typed by the admin are matched literally instead of acting
// as LIKE wildcards.
//
// Page values below 1 are treated as 1; a PageSize outside 1..200 falls back
// to 20. The returned slice is never nil.
func (r *affiliateRepository) ListUsersWithCustomSettings(ctx context.Context, filter service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) {
	page := filter.Page
	if page < 1 {
		page = 1
	}
	pageSize := filter.PageSize
	if pageSize <= 0 || pageSize > 200 {
		pageSize = 20
	}
	offset := (page - 1) * pageSize
	// Backslash is the default LIKE/ILIKE escape character in PostgreSQL.
	escapedSearch := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).
		Replace(strings.TrimSpace(filter.Search))
	likePattern := "%" + escapedSearch + "%"
	const baseFrom = `
FROM user_affiliates ua
JOIN users u ON u.id = ua.user_id
WHERE (ua.aff_code_custom = true OR ua.aff_rebate_rate_percent IS NOT NULL)
AND (u.email ILIKE $1 OR u.username ILIKE $1)`
	client := clientFromContext(ctx, r.client)
	total, err := scanInt64(ctx, client, "SELECT COUNT(*)"+baseFrom, likePattern)
	if err != nil {
		return nil, 0, fmt.Errorf("count affiliate admin entries: %w", err)
	}
	listQuery := `
SELECT ua.user_id,
COALESCE(u.email, ''),
COALESCE(u.username, ''),
ua.aff_code,
ua.aff_code_custom,
ua.aff_rebate_rate_percent,
ua.aff_count` + baseFrom + `
ORDER BY ua.updated_at DESC
LIMIT $2 OFFSET $3`
	rows, err := client.QueryContext(ctx, listQuery, likePattern, pageSize, offset)
	if err != nil {
		return nil, 0, fmt.Errorf("list affiliate admin entries: %w", err)
	}
	defer func() { _ = rows.Close() }()
	entries := make([]service.AffiliateAdminEntry, 0)
	for rows.Next() {
		var e service.AffiliateAdminEntry
		var rebate sql.NullFloat64
		if err := rows.Scan(&e.UserID, &e.Email, &e.Username, &e.AffCode,
			&e.AffCodeCustom, &rebate, &e.AffCount); err != nil {
			return nil, 0, err
		}
		// A NULL rate means no dedicated percentage; leave the pointer nil.
		if rebate.Valid {
			v := rebate.Float64
			e.AffRebateRatePercent = &v
		}
		entries = append(entries, e)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}
	return entries, total, nil
}
// scanInt64 executes query and returns the first column of its first row as
// an int64 (typically a COUNT). An empty result set yields 0 with a nil error.
func scanInt64(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
	rows, err := client.QueryContext(ctx, query, args...)
	if err != nil {
		return 0, err
	}
	defer func() { _ = rows.Close() }()

	if rows.Next() {
		var value int64
		if scanErr := rows.Scan(&value); scanErr != nil {
			return 0, scanErr
		}
		return value, nil
	}
	// No row: report an iteration error if there was one, otherwise zero.
	if iterErr := rows.Err(); iterErr != nil {
		return 0, iterErr
	}
	return 0, nil
}