421 lines
10 KiB
Go
421 lines
10 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
|
"github.com/Wei-Shaw/sub2api/ent/user"
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
"github.com/lib/pq"
|
|
)
|
|
|
|
// Tunables for affiliate code generation.
const (
	// affiliateCodeMaxAttempts caps how many freshly generated codes are tried
	// when creating a user_affiliates row before giving up.
	affiliateCodeMaxAttempts = 12
	// affiliateCodeLength is the number of characters in every generated code.
	affiliateCodeLength = 12
)

// affiliateCodeCharset is the 32-symbol alphabet codes are drawn from:
// upper-case letters without I and O, and the digits 2-9 (presumably chosen to
// avoid the look-alikes I/1 and O/0 — confirm with product requirements).
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")
|
|
|
|
// affiliateQueryExecer is the minimal raw-SQL surface the affiliate helpers
// need: run a statement, or run a query and iterate its rows. Both the ent
// client and the ent transaction client are passed where this is expected.
type affiliateQueryExecer interface {
	// ExecContext runs a statement that returns no rows.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	// QueryContext runs a statement and returns its result set; the caller
	// must close the returned rows.
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
|
|
|
|
type affiliateRepository struct {
|
|
client *dbent.Client
|
|
}
|
|
|
|
func NewAffiliateRepository(client *dbent.Client, _ *sql.DB) service.AffiliateRepository {
|
|
return &affiliateRepository{client: client}
|
|
}
|
|
|
|
func (r *affiliateRepository) EnsureUserAffiliate(ctx context.Context, userID int64) (*service.AffiliateSummary, error) {
|
|
if userID <= 0 {
|
|
return nil, service.ErrUserNotFound
|
|
}
|
|
client := clientFromContext(ctx, r.client)
|
|
return ensureUserAffiliateWithClient(ctx, client, userID)
|
|
}
|
|
|
|
func (r *affiliateRepository) GetAffiliateByCode(ctx context.Context, code string) (*service.AffiliateSummary, error) {
|
|
client := clientFromContext(ctx, r.client)
|
|
return queryAffiliateByCode(ctx, client, code)
|
|
}
|
|
|
|
func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) {
|
|
var bound bool
|
|
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
|
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
|
|
return err
|
|
}
|
|
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, inviterID); err != nil {
|
|
return err
|
|
}
|
|
|
|
res, err := txClient.ExecContext(txCtx,
|
|
"UPDATE user_affiliates SET inviter_id = $1, updated_at = NOW() WHERE user_id = $2 AND inviter_id IS NULL",
|
|
inviterID, userID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("bind inviter: %w", err)
|
|
}
|
|
affected, _ := res.RowsAffected()
|
|
if affected == 0 {
|
|
bound = false
|
|
return nil
|
|
}
|
|
|
|
if _, err = txClient.ExecContext(txCtx,
|
|
"UPDATE user_affiliates SET aff_count = aff_count + 1, updated_at = NOW() WHERE user_id = $1",
|
|
inviterID,
|
|
); err != nil {
|
|
return fmt.Errorf("increment inviter aff_count: %w", err)
|
|
}
|
|
bound = true
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return bound, nil
|
|
}
|
|
|
|
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64) (bool, error) {
|
|
if amount <= 0 {
|
|
return false, nil
|
|
}
|
|
|
|
var applied bool
|
|
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
|
res, err := txClient.ExecContext(txCtx,
|
|
"UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2",
|
|
amount, inviterID,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, _ := res.RowsAffected()
|
|
if affected == 0 {
|
|
applied = false
|
|
return nil
|
|
}
|
|
|
|
if _, err = txClient.ExecContext(txCtx, `
|
|
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
|
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
|
|
return fmt.Errorf("insert affiliate accrue ledger: %w", err)
|
|
}
|
|
|
|
applied = true
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return applied, nil
|
|
}
|
|
|
|
func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
|
|
var transferred float64
|
|
var newBalance float64
|
|
|
|
err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
|
|
if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
|
|
return err
|
|
}
|
|
|
|
rows, err := txClient.QueryContext(txCtx, `
|
|
WITH claimed AS (
|
|
SELECT aff_quota::double precision AS amount
|
|
FROM user_affiliates
|
|
WHERE user_id = $1
|
|
AND aff_quota > 0
|
|
FOR UPDATE
|
|
),
|
|
cleared AS (
|
|
UPDATE user_affiliates ua
|
|
SET aff_quota = 0,
|
|
updated_at = NOW()
|
|
FROM claimed c
|
|
WHERE ua.user_id = $1
|
|
RETURNING c.amount
|
|
)
|
|
SELECT amount
|
|
FROM cleared`, userID)
|
|
if err != nil {
|
|
return fmt.Errorf("claim affiliate quota: %w", err)
|
|
}
|
|
|
|
if !rows.Next() {
|
|
_ = rows.Close()
|
|
if err := rows.Err(); err != nil {
|
|
return err
|
|
}
|
|
return service.ErrAffiliateQuotaEmpty
|
|
}
|
|
if err := rows.Scan(&transferred); err != nil {
|
|
_ = rows.Close()
|
|
return err
|
|
}
|
|
if err := rows.Close(); err != nil {
|
|
return err
|
|
}
|
|
if transferred <= 0 {
|
|
return service.ErrAffiliateQuotaEmpty
|
|
}
|
|
|
|
affected, err := txClient.User.Update().
|
|
Where(user.IDEQ(userID)).
|
|
AddBalance(transferred).
|
|
AddTotalRecharged(transferred).
|
|
Save(txCtx)
|
|
if err != nil {
|
|
return fmt.Errorf("credit user balance by affiliate quota: %w", err)
|
|
}
|
|
if affected == 0 {
|
|
return service.ErrUserNotFound
|
|
}
|
|
|
|
newBalance, err = queryUserBalance(txCtx, txClient, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err = txClient.ExecContext(txCtx, `
|
|
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
|
|
VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil {
|
|
return fmt.Errorf("insert affiliate transfer ledger: %w", err)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return 0, 0, err
|
|
}
|
|
|
|
return transferred, newBalance, nil
|
|
}
|
|
|
|
func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]service.AffiliateInvitee, error) {
|
|
if limit <= 0 {
|
|
limit = 100
|
|
}
|
|
client := clientFromContext(ctx, r.client)
|
|
rows, err := client.QueryContext(ctx, `
|
|
SELECT ua.user_id,
|
|
COALESCE(u.email, ''),
|
|
COALESCE(u.username, ''),
|
|
ua.created_at
|
|
FROM user_affiliates ua
|
|
LEFT JOIN users u ON u.id = ua.user_id
|
|
WHERE ua.inviter_id = $1
|
|
ORDER BY ua.created_at DESC
|
|
LIMIT $2`, inviterID, limit)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
invitees := make([]service.AffiliateInvitee, 0)
|
|
for rows.Next() {
|
|
var item service.AffiliateInvitee
|
|
var createdAt time.Time
|
|
if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt); err != nil {
|
|
return nil, err
|
|
}
|
|
item.CreatedAt = &createdAt
|
|
invitees = append(invitees, item)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return invitees, nil
|
|
}
|
|
|
|
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
|
|
if tx := dbent.TxFromContext(ctx); tx != nil {
|
|
return fn(ctx, tx.Client())
|
|
}
|
|
|
|
tx, err := r.client.Tx(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("begin affiliate transaction: %w", err)
|
|
}
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
txCtx := dbent.NewTxContext(ctx, tx)
|
|
if err := fn(txCtx, tx.Client()); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("commit affiliate transaction: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ensureUserAffiliateWithClient(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
|
|
summary, err := queryAffiliateByUserID(ctx, client, userID)
|
|
if err == nil {
|
|
return summary, nil
|
|
}
|
|
if !errors.Is(err, service.ErrAffiliateProfileNotFound) {
|
|
return nil, err
|
|
}
|
|
|
|
for i := 0; i < affiliateCodeMaxAttempts; i++ {
|
|
code, codeErr := generateAffiliateCode()
|
|
if codeErr != nil {
|
|
return nil, codeErr
|
|
}
|
|
_, insertErr := client.ExecContext(ctx, `
|
|
INSERT INTO user_affiliates (user_id, aff_code, created_at, updated_at)
|
|
VALUES ($1, $2, NOW(), NOW())
|
|
ON CONFLICT (user_id) DO NOTHING`, userID, code)
|
|
if insertErr == nil {
|
|
break
|
|
}
|
|
if isAffiliateUniqueViolation(insertErr) {
|
|
continue
|
|
}
|
|
return nil, insertErr
|
|
}
|
|
|
|
return queryAffiliateByUserID(ctx, client, userID)
|
|
}
|
|
|
|
func queryAffiliateByUserID(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
|
|
rows, err := client.QueryContext(ctx, `
|
|
SELECT user_id,
|
|
aff_code,
|
|
inviter_id,
|
|
aff_count,
|
|
aff_quota::double precision,
|
|
aff_history_quota::double precision,
|
|
created_at,
|
|
updated_at
|
|
FROM user_affiliates
|
|
WHERE user_id = $1`, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
if !rows.Next() {
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, service.ErrAffiliateProfileNotFound
|
|
}
|
|
|
|
var out service.AffiliateSummary
|
|
var inviterID sql.NullInt64
|
|
if err := rows.Scan(
|
|
&out.UserID,
|
|
&out.AffCode,
|
|
&inviterID,
|
|
&out.AffCount,
|
|
&out.AffQuota,
|
|
&out.AffHistoryQuota,
|
|
&out.CreatedAt,
|
|
&out.UpdatedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
if inviterID.Valid {
|
|
out.InviterID = &inviterID.Int64
|
|
}
|
|
return &out, nil
|
|
}
|
|
|
|
func queryAffiliateByCode(ctx context.Context, client affiliateQueryExecer, code string) (*service.AffiliateSummary, error) {
|
|
rows, err := client.QueryContext(ctx, `
|
|
SELECT user_id,
|
|
aff_code,
|
|
inviter_id,
|
|
aff_count,
|
|
aff_quota::double precision,
|
|
aff_history_quota::double precision,
|
|
created_at,
|
|
updated_at
|
|
FROM user_affiliates
|
|
WHERE aff_code = $1
|
|
LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
if !rows.Next() {
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, service.ErrAffiliateProfileNotFound
|
|
}
|
|
|
|
var out service.AffiliateSummary
|
|
var inviterID sql.NullInt64
|
|
if err := rows.Scan(
|
|
&out.UserID,
|
|
&out.AffCode,
|
|
&inviterID,
|
|
&out.AffCount,
|
|
&out.AffQuota,
|
|
&out.AffHistoryQuota,
|
|
&out.CreatedAt,
|
|
&out.UpdatedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
if inviterID.Valid {
|
|
out.InviterID = &inviterID.Int64
|
|
}
|
|
return &out, nil
|
|
}
|
|
|
|
func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID int64) (float64, error) {
|
|
rows, err := client.QueryContext(ctx,
|
|
"SELECT balance::double precision FROM users WHERE id = $1 LIMIT 1",
|
|
userID,
|
|
)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
if !rows.Next() {
|
|
if err := rows.Err(); err != nil {
|
|
return 0, err
|
|
}
|
|
return 0, service.ErrUserNotFound
|
|
}
|
|
var balance float64
|
|
if err := rows.Scan(&balance); err != nil {
|
|
return 0, err
|
|
}
|
|
return balance, nil
|
|
}
|
|
|
|
func generateAffiliateCode() (string, error) {
|
|
buf := make([]byte, affiliateCodeLength)
|
|
if _, err := rand.Read(buf); err != nil {
|
|
return "", fmt.Errorf("generate affiliate code: %w", err)
|
|
}
|
|
for i := range buf {
|
|
buf[i] = affiliateCodeCharset[int(buf[i])%len(affiliateCodeCharset)]
|
|
}
|
|
return string(buf), nil
|
|
}
|
|
|
|
func isAffiliateUniqueViolation(err error) bool {
|
|
var pqErr *pq.Error
|
|
if errors.As(err, &pqErr) {
|
|
return string(pqErr.Code) == "23505"
|
|
}
|
|
return false
|
|
}
|