refactor(affiliate): tighten DI and harden inviter code validation

- Drop SetAffiliateService setters and ProvideAuthService /
  ProvidePaymentService / ProvideUserHandler wrappers in favor of direct
  Wire constructor injection. AffiliateService has no back-edge to
  Auth/Payment/User, so the indirection was never required.
- Change RegisterWithVerification's variadic affiliateCode to a fixed
  parameter; adjust all call sites.
- Validate aff_code length and charset in BindInviterByCode before any
  DB lookup, eliminating timing-side-channel and useless DB roundtrips
  on malformed input.
- Make affiliate cache invalidation synchronous; surface Redis errors
  via the project logger instead of swallowing them in a detached
  goroutine.
- Add an integration test guarding cross-layer tx propagation in
  AccrueQuota and a unit test pinning the aff_code format rules.
This commit is contained in:
shaw
2026-04-25 08:44:18 +08:00
parent 5b5db88550
commit aa8ee33b0a
22 changed files with 188 additions and 157 deletions

View File

@@ -9,6 +9,7 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
var (
@@ -20,8 +21,32 @@ var (
const (
affiliateInviteesLimit = 100
// affiliateCodeFormatLength must stay in sync with repository.affiliateCodeLength.
affiliateCodeFormatLength = 12
)
// affiliateCodeValidChar is a 256-entry lookup table mirroring the charset used
// by the repository's generateAffiliateCode (A-Z minus I/O, digits 2-9).
var affiliateCodeValidChar = func() [256]bool {
var tbl [256]bool
for _, c := range []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") {
tbl[c] = true
}
return tbl
}()
func isValidAffiliateCodeFormat(code string) bool {
if len(code) != affiliateCodeFormatLength {
return false
}
for i := 0; i < len(code); i++ {
if !affiliateCodeValidChar[code[i]] {
return false
}
}
return true
}
type AffiliateSummary struct {
UserID int64 `json:"user_id"`
AffCode string `json:"aff_code"`
@@ -110,6 +135,9 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64,
if code == "" {
return nil
}
if !isValidAffiliateCodeFormat(code) {
return ErrAffiliateCodeInvalid
}
if s == nil || s.repo == nil {
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
}
@@ -279,10 +307,8 @@ func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}()
if err := s.billingCacheService.InvalidateUserBalance(ctx, userID); err != nil {
logger.LegacyPrintf("service.affiliate", "[Affiliate] Failed to invalidate billing cache for user %d: %v", userID, err)
}
}
}

View File

@@ -57,3 +57,35 @@ func TestMaskEmail(t *testing.T) {
require.Equal(t, "x***@d***", maskEmail("x@domain"))
require.Equal(t, "", maskEmail(""))
}
func TestIsValidAffiliateCodeFormat(t *testing.T) {
t.Parallel()
cases := []struct {
name string
in string
want bool
}{
{"valid canonical", "ABCDEFGHJKLM", true},
{"valid all digits 2-9", "234567892345", true},
{"valid mixed", "A2B3C4D5E6F7", true},
{"too short", "ABCDEFGHJKL", false},
{"too long", "ABCDEFGHJKLMN", false},
{"contains excluded letter I", "IBCDEFGHJKLM", false},
{"contains excluded letter O", "OBCDEFGHJKLM", false},
{"contains excluded digit 0", "0BCDEFGHJKLM", false},
{"contains excluded digit 1", "1BCDEFGHJKLM", false},
{"lowercase rejected (caller must ToUpper first)", "abcdefghjklm", false},
{"empty", "", false},
{"12-byte utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // 6×2 bytes = 12 bytes, bytes out of charset
{"ascii punctuation", "ABCDEFGHJK.M", false},
{"whitespace", "ABCDEFGHJK M", false},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.want, isValidAffiliateCodeFormat(tc.in))
})
}
}

View File

@@ -137,6 +137,7 @@ func newOAuthEmailFlowAuthService(
nil,
nil,
nil,
nil,
)
}

View File

@@ -99,6 +99,7 @@ func NewAuthService(
emailQueueService *EmailQueueService,
promoService *PromoService,
defaultSubAssigner DefaultSubscriptionAssigner,
affiliateService *AffiliateService,
) *AuthService {
return &AuthService{
entClient: entClient,
@@ -111,6 +112,7 @@ func NewAuthService(
turnstileService: turnstileService,
emailQueueService: emailQueueService,
promoService: promoService,
affiliateService: affiliateService,
defaultSubAssigner: defaultSubAssigner,
}
}
@@ -122,26 +124,13 @@ func (s *AuthService) EntClient() *dbent.Client {
return s.entClient
}
func (s *AuthService) SetAffiliateService(affiliateService *AffiliateService) {
if s == nil {
return
}
s.affiliateService = affiliateService
}
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "")
return s.RegisterWithVerification(ctx, email, password, "", "", "", "")
}
// RegisterWithVerification 用户注册支持邮件验证、优惠码、邀请码和邀请返利码返回token和用户。
// affiliateCode 使用可选参数以兼容旧调用方。
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string, affiliateCode ...string) (string, *User, error) {
affiliateCodeRaw := ""
if len(affiliateCode) > 0 {
affiliateCodeRaw = affiliateCode[0]
}
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode, affiliateCode string) (string, *User, error) {
// 检查是否开放注册默认关闭settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
@@ -241,7 +230,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
}
if code := strings.TrimSpace(affiliateCodeRaw); code != "" {
if code := strings.TrimSpace(affiliateCode); code != "" {
if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil {
// 邀请返利码绑定失败不影响注册,只记录日志
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err)

View File

@@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
emailSvc = service.NewEmailService(settingRepo, emailCache)
}
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil)
return svc, repo, client
}
@@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t
},
}
emailService := service.NewEmailService(nil, cache)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil)
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
ID: 41,

View File

@@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
values: settings,
}, cfg)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil)
return svc, repo, client
}

View File

@@ -212,6 +212,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
nil,
nil, // promoService
nil, // defaultSubAssigner
nil, // affiliateService
)
}
@@ -243,7 +244,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil)
// 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
@@ -255,7 +256,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
}
@@ -269,7 +270,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code")
}

View File

@@ -54,6 +54,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
nil, // emailQueueService
nil, // promoService
nil, // defaultSubAssigner
nil, // affiliateService
)
}

View File

@@ -184,19 +184,12 @@ type PaymentService struct {
affiliateService *AffiliateService
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService {
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo, affiliateService: affiliateService}
svc.resumeService = psNewPaymentResumeService(configService)
return svc
}
func (s *PaymentService) SetAffiliateService(affiliateService *AffiliateService) {
if s == nil {
return
}
s.affiliateService = affiliateService
}
// --- Provider Registry ---
// EnsureProviders lazily initializes the provider registry on first call.

View File

@@ -391,53 +391,6 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit
return svc
}
func ProvideAuthService(
entClient *dbent.Client,
userRepo UserRepository,
redeemRepo RedeemCodeRepository,
refreshTokenCache RefreshTokenCache,
cfg *config.Config,
settingService *SettingService,
emailService *EmailService,
turnstileService *TurnstileService,
emailQueueService *EmailQueueService,
promoService *PromoService,
defaultSubAssigner DefaultSubscriptionAssigner,
affiliateService *AffiliateService,
) *AuthService {
svc := NewAuthService(
entClient,
userRepo,
redeemRepo,
refreshTokenCache,
cfg,
settingService,
emailService,
turnstileService,
emailQueueService,
promoService,
defaultSubAssigner,
)
svc.SetAffiliateService(affiliateService)
return svc
}
func ProvidePaymentService(
entClient *dbent.Client,
registry *payment.Registry,
loadBalancer payment.LoadBalancer,
redeemService *RedeemService,
subscriptionSvc *SubscriptionService,
configService *PaymentConfigService,
userRepo UserRepository,
groupRepo GroupRepository,
affiliateService *AffiliateService,
) *PaymentService {
svc := NewPaymentService(entClient, registry, loadBalancer, redeemService, subscriptionSvc, configService, userRepo, groupRepo)
svc.SetAffiliateService(affiliateService)
return svc
}
// ProvideBillingCacheService wires BillingCacheService with its RPM dependencies.
func ProvideBillingCacheService(
cache BillingCache,
@@ -454,7 +407,7 @@ func ProvideBillingCacheService(
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
ProvideAuthService,
NewAuthService,
NewUserService,
NewAPIKeyService,
ProvideAPIKeyAuthCacheInvalidator,
@@ -535,7 +488,7 @@ var ProviderSet = wire.NewSet(
NewModelPricingResolver,
NewAffiliateService,
ProvidePaymentConfigService,
ProvidePaymentService,
NewPaymentService,
ProvidePaymentOrderExpiryService,
ProvideBalanceNotifyService,
ProvideChannelMonitorService,