feat: add affiliate invite rebate flow and admin rebate-rate setting
This commit is contained in:
420
backend/internal/repository/affiliate_repo.go
Normal file
420
backend/internal/repository/affiliate_repo.go
Normal file
@@ -0,0 +1,420 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const (
|
||||
affiliateCodeLength = 12
|
||||
affiliateCodeMaxAttempts = 12
|
||||
)
|
||||
|
||||
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
|
||||
|
||||
type affiliateQueryExecer interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
type affiliateRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewAffiliateRepository(client *dbent.Client, _ *sql.DB) service.AffiliateRepository {
|
||||
return &affiliateRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) EnsureUserAffiliate(ctx context.Context, userID int64) (*service.AffiliateSummary, error) {
|
||||
if userID <= 0 {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
client := clientFromContext(ctx, r.client)
|
||||
return ensureUserAffiliateWithClient(ctx, client, userID)
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) GetAffiliateByCode(ctx context.Context, code string) (*service.AffiliateSummary, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
return queryAffiliateByCode(ctx, client, code)
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) {
|
||||
var bound bool
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, inviterID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := txClient.ExecContext(txCtx,
|
||||
"UPDATE user_affiliates SET inviter_id = $1, updated_at = NOW() WHERE user_id = $2 AND inviter_id IS NULL",
|
||||
inviterID, userID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bind inviter: %w", err)
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
bound = false
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err = txClient.ExecContext(txCtx,
|
||||
"UPDATE user_affiliates SET aff_count = aff_count + 1, updated_at = NOW() WHERE user_id = $1",
|
||||
inviterID,
|
||||
); err != nil {
|
||||
return fmt.Errorf("increment inviter aff_count: %w", err)
|
||||
}
|
||||
bound = true
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return bound, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) {
|
||||
if amount <= 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var applied bool
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
res, err := txClient.ExecContext(txCtx,
|
||||
"UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2",
|
||||
amount, inviterID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if affected == 0 {
|
||||
applied = false
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
||||
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
|
||||
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
||||
}
|
||||
|
||||
applied = true
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
|
||||
var transferred float64
|
||||
var newBalance float64
|
||||
|
||||
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
||||
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := txClient.QueryContext(txCtx, `
|
||||
WITH claimed AS (
|
||||
SELECT aff_quota::double precision AS amount
|
||||
FROM user_affiliates
|
||||
WHERE user_id = $1
|
||||
AND aff_quota > 0
|
||||
FOR UPDATE
|
||||
),
|
||||
cleared AS (
|
||||
UPDATE user_affiliates ua
|
||||
SET aff_quota = 0,
|
||||
updated_at = NOW()
|
||||
FROM claimed c
|
||||
WHERE ua.user_id = $1
|
||||
RETURNING c.amount
|
||||
)
|
||||
SELECT amount
|
||||
FROM cleared`, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("claim affiliate quota: %w", err)
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
_ = rows.Close()
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return service.ErrAffiliateQuotaEmpty
|
||||
}
|
||||
if err := rows.Scan(&transferred); err != nil {
|
||||
_ = rows.Close()
|
||||
return err
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if transferred <= 0 {
|
||||
return service.ErrAffiliateQuotaEmpty
|
||||
}
|
||||
|
||||
affected, err := txClient.User.Update().
|
||||
Where(user.IDEQ(userID)).
|
||||
AddBalance(transferred).
|
||||
AddTotalRecharged(transferred).
|
||||
Save(txCtx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("credit user balance by affiliate quota: %w", err)
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
|
||||
newBalance, err = queryUserBalance(txCtx, txClient, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = txClient.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
||||
VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil {
|
||||
return fmt.Errorf("insert affiliate transfer ledger: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return transferred, newBalance, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]service.AffiliateInvitee, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
client := clientFromContext(ctx, r.client)
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT ua.user_id,
|
||||
COALESCE(u.email, ''),
|
||||
COALESCE(u.username, ''),
|
||||
ua.created_at
|
||||
FROM user_affiliates ua
|
||||
LEFT JOIN users u ON u.id = ua.user_id
|
||||
WHERE ua.inviter_id = $1
|
||||
ORDER BY ua.created_at DESC
|
||||
LIMIT $2`, inviterID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
invitees := make([]service.AffiliateInvitee, 0)
|
||||
for rows.Next() {
|
||||
var item service.AffiliateInvitee
|
||||
var createdAt time.Time
|
||||
if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
item.CreatedAt = &createdAt
|
||||
invitees = append(invitees, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return invitees, nil
|
||||
}
|
||||
|
||||
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
return fn(ctx, tx.Client())
|
||||
}
|
||||
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin affiliate transaction: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
if err := fn(txCtx, tx.Client()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit affiliate transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureUserAffiliateWithClient(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
|
||||
summary, err := queryAffiliateByUserID(ctx, client, userID)
|
||||
if err == nil {
|
||||
return summary, nil
|
||||
}
|
||||
if !errors.Is(err, service.ErrAffiliateProfileNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := 0; i < affiliateCodeMaxAttempts; i++ {
|
||||
code, codeErr := generateAffiliateCode()
|
||||
if codeErr != nil {
|
||||
return nil, codeErr
|
||||
}
|
||||
_, insertErr := client.ExecContext(ctx, `
|
||||
INSERT INTO user_affiliates (user_id, aff_code, created_at, updated_at)
|
||||
VALUES ($1, $2, NOW(), NOW())
|
||||
ON CONFLICT (user_id) DO NOTHING`, userID, code)
|
||||
if insertErr == nil {
|
||||
break
|
||||
}
|
||||
if isAffiliateUniqueViolation(insertErr) {
|
||||
continue
|
||||
}
|
||||
return nil, insertErr
|
||||
}
|
||||
|
||||
return queryAffiliateByUserID(ctx, client, userID)
|
||||
}
|
||||
|
||||
func queryAffiliateByUserID(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT user_id,
|
||||
aff_code,
|
||||
inviter_id,
|
||||
aff_count,
|
||||
aff_quota::double precision,
|
||||
aff_history_quota::double precision,
|
||||
created_at,
|
||||
updated_at
|
||||
FROM user_affiliates
|
||||
WHERE user_id = $1`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, service.ErrAffiliateProfileNotFound
|
||||
}
|
||||
|
||||
var out service.AffiliateSummary
|
||||
var inviterID sql.NullInt64
|
||||
if err := rows.Scan(
|
||||
&out.UserID,
|
||||
&out.AffCode,
|
||||
&inviterID,
|
||||
&out.AffCount,
|
||||
&out.AffQuota,
|
||||
&out.AffHistoryQuota,
|
||||
&out.CreatedAt,
|
||||
&out.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if inviterID.Valid {
|
||||
out.InviterID = &inviterID.Int64
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func queryAffiliateByCode(ctx context.Context, client affiliateQueryExecer, code string) (*service.AffiliateSummary, error) {
|
||||
rows, err := client.QueryContext(ctx, `
|
||||
SELECT user_id,
|
||||
aff_code,
|
||||
inviter_id,
|
||||
aff_count,
|
||||
aff_quota::double precision,
|
||||
aff_history_quota::double precision,
|
||||
created_at,
|
||||
updated_at
|
||||
FROM user_affiliates
|
||||
WHERE aff_code = $1
|
||||
LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, service.ErrAffiliateProfileNotFound
|
||||
}
|
||||
|
||||
var out service.AffiliateSummary
|
||||
var inviterID sql.NullInt64
|
||||
if err := rows.Scan(
|
||||
&out.UserID,
|
||||
&out.AffCode,
|
||||
&inviterID,
|
||||
&out.AffCount,
|
||||
&out.AffQuota,
|
||||
&out.AffHistoryQuota,
|
||||
&out.CreatedAt,
|
||||
&out.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if inviterID.Valid {
|
||||
out.InviterID = &inviterID.Int64
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID int64) (float64, error) {
|
||||
rows, err := client.QueryContext(ctx,
|
||||
"SELECT balance::double precision FROM users WHERE id = $1 LIMIT 1",
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return 0, service.ErrUserNotFound
|
||||
}
|
||||
var balance float64
|
||||
if err := rows.Scan(&balance); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func generateAffiliateCode() (string, error) {
|
||||
buf := make([]byte, affiliateCodeLength)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", fmt.Errorf("generate affiliate code: %w", err)
|
||||
}
|
||||
for i := range buf {
|
||||
buf[i] = affiliateCodeCharset[int(buf[i])%len(affiliateCodeCharset)]
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
func isAffiliateUniqueViolation(err error) bool {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) {
|
||||
return string(pqErr.Code) == "23505"
|
||||
}
|
||||
return false
|
||||
}
|
||||
114
backend/internal/repository/affiliate_repo_integration_test.go
Normal file
114
backend/internal/repository/affiliate_repo_integration_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func querySingleFloat(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) float64 {
|
||||
t.Helper()
|
||||
rows, err := client.QueryContext(ctx, query, args...)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
require.True(t, rows.Next(), "expected one row")
|
||||
var value float64
|
||||
require.NoError(t, rows.Scan(&value))
|
||||
require.NoError(t, rows.Err())
|
||||
return value
|
||||
}
|
||||
|
||||
func querySingleInt(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) int {
|
||||
t.Helper()
|
||||
rows, err := client.QueryContext(ctx, query, args...)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
require.True(t, rows.Next(), "expected one row")
|
||||
var value int
|
||||
require.NoError(t, rows.Scan(&value))
|
||||
require.NoError(t, rows.Err())
|
||||
return value
|
||||
}
|
||||
|
||||
func TestAffiliateRepository_TransferQuotaToBalance_UsesClaimedQuotaBeforeClear(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
repo := NewAffiliateRepository(client, integrationDB)
|
||||
|
||||
u := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-transfer-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 5.5,
|
||||
Concurrency: 5,
|
||||
})
|
||||
|
||||
affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
|
||||
_, err := client.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
|
||||
require.NoError(t, err)
|
||||
|
||||
transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 12.34, transferred, 1e-9)
|
||||
require.InDelta(t, 17.84, balance, 1e-9)
|
||||
|
||||
affQuota := querySingleFloat(t, txCtx, client,
|
||||
"SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", u.ID)
|
||||
require.InDelta(t, 0.0, affQuota, 1e-9)
|
||||
|
||||
persistedBalance := querySingleFloat(t, txCtx, client,
|
||||
"SELECT balance::double precision FROM users WHERE id = $1", u.ID)
|
||||
require.InDelta(t, 17.84, persistedBalance, 1e-9)
|
||||
|
||||
ledgerCount := querySingleInt(t, txCtx, client,
|
||||
"SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
|
||||
require.Equal(t, 1, ledgerCount)
|
||||
}
|
||||
|
||||
func TestAffiliateRepository_TransferQuotaToBalance_EmptyQuota(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
client := tx.Client()
|
||||
|
||||
repo := NewAffiliateRepository(client, integrationDB)
|
||||
|
||||
u := mustCreateUser(t, client, &service.User{
|
||||
Email: fmt.Sprintf("affiliate-empty-%d@example.com", time.Now().UnixNano()),
|
||||
PasswordHash: "hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 3.21,
|
||||
Concurrency: 5,
|
||||
})
|
||||
|
||||
affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
|
||||
_, err := client.ExecContext(txCtx, `
|
||||
INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
|
||||
VALUES ($1, $2, 0, 0, NOW(), NOW())`, u.ID, affCode)
|
||||
require.NoError(t, err)
|
||||
|
||||
transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
|
||||
require.ErrorIs(t, err, service.ErrAffiliateQuotaEmpty)
|
||||
require.InDelta(t, 0.0, transferred, 1e-9)
|
||||
require.InDelta(t, 0.0, balance, 1e-9)
|
||||
|
||||
persistedBalance := querySingleFloat(t, txCtx, client,
|
||||
"SELECT balance::double precision FROM users WHERE id = $1", u.ID)
|
||||
require.InDelta(t, 3.21, persistedBalance, 1e-9)
|
||||
}
|
||||
@@ -91,6 +91,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewChannelRepository,
|
||||
NewChannelMonitorRepository,
|
||||
NewChannelMonitorRequestTemplateRepository,
|
||||
NewAffiliateRepository,
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
|
||||
Reference in New Issue
Block a user