diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index 7eabde62..9386678d 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index f8e0dcf4..d0b1d3af 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -71,7 +71,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) affiliateRepository := repository.NewAffiliateRepository(client, db) affiliateService := service.NewAffiliateService(affiliateRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCacheService) - authService := service.ProvideAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService) + authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService) userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) @@ -82,7 +82,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { totpCache := repository.NewTotpCache(redisClient) totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService) - userHandler := handler.ProvideUserHandler(userService, authService, emailService, emailCache, affiliateService) + userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) @@ -197,7 +197,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) registry := payment.ProvideRegistry() defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) - paymentService := service.ProvidePaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService) + paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index a4b7a297..ffe9ff5f 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -2210,6 +2210,7 @@ CREATE TABLE IF NOT EXISTS user_avatars ( nil, nil, options.defaultSubAssigner, + nil, ) userSvc := service.NewUserService(userRepo, nil, nil, nil) var totpSvc *service.TotpService diff --git a/backend/internal/handler/auth_session_revocation_test.go b/backend/internal/handler/auth_session_revocation_test.go index 1924cb81..f1c6d87d 100644 --- a/backend/internal/handler/auth_session_revocation_test.go +++ b/backend/internal/handler/auth_session_revocation_test.go @@ -35,7 +35,7 @@ func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) { ExpireHour: 1, }, } - authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil) handler := &AuthHandler{authService: authService} recorder := httptest.NewRecorder() diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 7cf114c1..b3c7786d 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -1399,6 +1399,7 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, nil, nil, nil, + nil, ) return &AuthHandler{ diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go index a7bc4ba3..377f432e 100644 --- a/backend/internal/handler/payment_handler_resume_test.go +++ b/backend/internal/handler/payment_handler_resume_test.go @@ -117,7 +117,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) { Save(context.Background()) require.NoError(t, err) - paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil) h := NewPaymentHandler(paymentSvc, nil, nil) recorder := httptest.NewRecorder() @@ -215,7 +215,7 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing require.NoError(t, err) configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef")) - paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil) h := NewPaymentHandler(paymentSvc, nil, nil) recorder := httptest.NewRecorder() @@ -302,7 +302,7 @@ func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *t require.NoError(t, err) configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef")) - paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil) h := NewPaymentHandler(paymentSvc, nil, nil) recorder := httptest.NewRecorder() @@ -342,7 +342,7 @@ func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) { client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) t.Cleanup(func() { _ = client.Close() }) - paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil) h := NewPaymentHandler(paymentSvc, nil, nil) recorder := httptest.NewRecorder() diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index c386792c..3f6ed8c2 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -5,7 +5,6 @@ import ( "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" - infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -28,22 +27,17 @@ func NewUserHandler( authService *service.AuthService, emailService *service.EmailService, emailCache service.EmailCache, + affiliateService *service.AffiliateService, ) *UserHandler { return &UserHandler{ - userService: userService, - authService: authService, - emailService: emailService, - emailCache: emailCache, + userService: userService, + authService: authService, + emailService: emailService, + emailCache: emailCache, + affiliateService: affiliateService, } } -func (h *UserHandler) SetAffiliateService(affiliateService *service.AffiliateService) { - if h == nil { - return - } - h.affiliateService = affiliateService -} - // ChangePasswordRequest represents the change password request payload type ChangePasswordRequest struct { OldPassword string `json:"old_password" binding:"required"` @@ -168,13 +162,6 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { response.Success(c, profileResp) } -func (h *UserHandler) affiliateServiceOrErr() (*service.AffiliateService, error) { - if h == nil || h.affiliateService == nil { - return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") - } - return h.affiliateService, nil -} - // GetAffiliate returns the current user's affiliate details. // GET /api/v1/user/aff func (h *UserHandler) GetAffiliate(c *gin.Context) { @@ -184,13 +171,7 @@ func (h *UserHandler) GetAffiliate(c *gin.Context) { return } - affiliateSvc, err := h.affiliateServiceOrErr() - if err != nil { - response.ErrorFrom(c, err) - return - } - - detail, err := affiliateSvc.GetAffiliateDetail(c.Request.Context(), subject.UserID) + detail, err := h.affiliateService.GetAffiliateDetail(c.Request.Context(), subject.UserID) if err != nil { response.ErrorFrom(c, err) return @@ -207,13 +188,7 @@ func (h *UserHandler) TransferAffiliateQuota(c *gin.Context) { return } - affiliateSvc, err := h.affiliateServiceOrErr() - if err != nil { - response.ErrorFrom(c, err) - return - } - - transferred, balance, err := affiliateSvc.TransferAffiliateQuota(c.Request.Context(), subject.UserID) + transferred, balance, err := h.affiliateService.TransferAffiliateQuota(c.Request.Context(), subject.UserID) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index a655b81c..8a864b51 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -142,7 +142,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) { Status: service.StatusActive, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`) recorder := httptest.NewRecorder() @@ -200,7 +200,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) { }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -283,7 +283,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -362,7 +362,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -511,8 +511,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) { }, } emailService := service.NewEmailService(nil, emailCache) - authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`) recorder := httptest.NewRecorder() @@ -566,7 +566,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) { }, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -625,8 +625,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure ExpireHour: 1, }, } - authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -668,8 +668,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t * ExpireHour: 1, }, } - authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -712,8 +712,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t }, } emailService := service.NewEmailService(nil, emailCache) - authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil) - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil) + authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil) body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`) recorder := httptest.NewRecorder() @@ -750,7 +750,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { Status: service.StatusActive, }, } - handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil) + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil) body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`) recorder := httptest.NewRecorder() diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index d4b34fd2..6d175488 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -80,18 +80,6 @@ func ProvideSettingHandler(settingService *service.SettingService, buildInfo Bui return NewSettingHandler(settingService, buildInfo.Version) } -func ProvideUserHandler( - userService *service.UserService, - authService *service.AuthService, - emailService *service.EmailService, - emailCache service.EmailCache, - affiliateService *service.AffiliateService, -) *UserHandler { - handler := NewUserHandler(userService, authService, emailService, emailCache) - handler.SetAffiliateService(affiliateService) - return handler -} - // ProvideHandlers creates the Handlers struct func ProvideHandlers( authHandler *AuthHandler, @@ -137,7 +125,7 @@ func ProvideHandlers( var ProviderSet = wire.NewSet( // Top-level handlers NewAuthHandler, - ProvideUserHandler, + NewUserHandler, NewAPIKeyHandler, NewUsageHandler, NewRedeemHandler, diff --git a/backend/internal/repository/affiliate_repo_integration_test.go b/backend/internal/repository/affiliate_repo_integration_test.go index 3ab5c0fb..3fa84426 100644 --- a/backend/internal/repository/affiliate_repo_integration_test.go +++ b/backend/internal/repository/affiliate_repo_integration_test.go @@ -80,6 +80,76 @@ VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34) require.Equal(t, 1, ledgerCount) } +// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the +// cross-layer tx propagation invariant: when AccrueQuota is called with a ctx +// that already carries a transaction (via dbent.NewTxContext), repo.withTx +// must reuse that tx rather than opening a nested one. If this invariant +// breaks, AccrueQuota would commit independently and survive a rollback of +// the outer tx, which would violate payment_fulfillment's all-or-nothing +// semantics. +func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) { + ctx := context.Background() + + outerTx, err := integrationEntClient.Tx(ctx) + require.NoError(t, err, "begin outer tx") + // Defensive cleanup: if any require.* below fires before the explicit + // Rollback, this prevents the tx from leaking until container teardown. + // Rollback is idempotent at the driver level (extra rollback returns an + // error we ignore). + t.Cleanup(func() { _ = outerTx.Rollback() }) + client := outerTx.Client() + txCtx := dbent.NewTxContext(ctx, outerTx) + + inviter := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-inviter-%d@example.com", time.Now().UnixNano()), + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + Concurrency: 5, + }) + invitee := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("affiliate-invitee-%d@example.com", time.Now().UnixNano()+1), + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + Concurrency: 5, + }) + + repo := NewAffiliateRepository(client, integrationDB) + _, err = repo.EnsureUserAffiliate(txCtx, inviter.ID) + require.NoError(t, err) + _, err = repo.EnsureUserAffiliate(txCtx, invitee.ID) + require.NoError(t, err) + + bound, err := repo.BindInviter(txCtx, invitee.ID, inviter.ID) + require.NoError(t, err) + require.True(t, bound, "invitee must bind to inviter") + + applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5) + require.NoError(t, err) + require.True(t, applied, "AccrueQuota must report applied=true") + + // Visible inside the outer tx. + innerQuota := querySingleFloat(t, txCtx, client, + "SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", inviter.ID) + require.InDelta(t, 3.5, innerQuota, 1e-9) + + // Roll back the outer tx; if AccrueQuota had opened its own inner tx and + // committed it, the rows would still be visible to the global client. + require.NoError(t, outerTx.Rollback()) + + rows, err := integrationEntClient.QueryContext(ctx, + "SELECT COUNT(*) FROM user_affiliates WHERE user_id IN ($1, $2)", + inviter.ID, invitee.ID) + require.NoError(t, err) + defer func() { _ = rows.Close() }() + require.True(t, rows.Next()) + var postRollbackCount int + require.NoError(t, rows.Scan(&postRollbackCount)) + require.Equal(t, 0, postRollbackCount, + "AccrueQuota must propagate the outer tx — found persisted rows after rollback") +} + func TestAffiliateRepository_TransferQuotaToBalance_EmptyQuota(t *testing.T) { ctx := context.Background() tx := testEntTx(t) diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index 06e3355e..dde92dfd 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} - authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) admin := &service.User{ ID: 1, diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index 84fd6967..a643d3bc 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: users} - authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil, nil) mw := NewJWTAuthMiddleware(authSvc, userSvc) @@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) { cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}} - authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil, nil) toucher := &recordingActivityToucher{} diff --git a/backend/internal/service/affiliate_service.go b/backend/internal/service/affiliate_service.go index 6fa5b423..fa8e2018 100644 --- a/backend/internal/service/affiliate_service.go +++ b/backend/internal/service/affiliate_service.go @@ -9,6 +9,7 @@ import ( "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) var ( @@ -20,8 +21,32 @@ var ( const ( affiliateInviteesLimit = 100 + // affiliateCodeFormatLength must stay in sync with repository.affiliateCodeLength. + affiliateCodeFormatLength = 12 ) +// affiliateCodeValidChar is a 256-entry lookup table mirroring the charset used +// by the repository's generateAffiliateCode (A-Z minus I/O, digits 2-9). +var affiliateCodeValidChar = func() [256]bool { + var tbl [256]bool + for _, c := range []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") { + tbl[c] = true + } + return tbl +}() + +func isValidAffiliateCodeFormat(code string) bool { + if len(code) != affiliateCodeFormatLength { + return false + } + for i := 0; i < len(code); i++ { + if !affiliateCodeValidChar[code[i]] { + return false + } + } + return true +} + type AffiliateSummary struct { UserID int64 `json:"user_id"` AffCode string `json:"aff_code"` @@ -110,6 +135,9 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, if code == "" { return nil } + if !isValidAffiliateCodeFormat(code) { + return ErrAffiliateCodeInvalid + } if s == nil || s.repo == nil { return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable") } @@ -279,10 +307,8 @@ func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) } if s.billingCacheService != nil { - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) - }() + if err := s.billingCacheService.InvalidateUserBalance(ctx, userID); err != nil { + logger.LegacyPrintf("service.affiliate", "[Affiliate] Failed to invalidate billing cache for user %d: %v", userID, err) + } } } diff --git a/backend/internal/service/affiliate_service_test.go b/backend/internal/service/affiliate_service_test.go index 6adf879d..605fe00f 100644 --- a/backend/internal/service/affiliate_service_test.go +++ b/backend/internal/service/affiliate_service_test.go @@ -57,3 +57,35 @@ func TestMaskEmail(t *testing.T) { require.Equal(t, "x***@d***", maskEmail("x@domain")) require.Equal(t, "", maskEmail("")) } + +func TestIsValidAffiliateCodeFormat(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + in string + want bool + }{ + {"valid canonical", "ABCDEFGHJKLM", true}, + {"valid all digits 2-9", "234567892345", true}, + {"valid mixed", "A2B3C4D5E6F7", true}, + {"too short", "ABCDEFGHJKL", false}, + {"too long", "ABCDEFGHJKLMN", false}, + {"contains excluded letter I", "IBCDEFGHJKLM", false}, + {"contains excluded letter O", "OBCDEFGHJKLM", false}, + {"contains excluded digit 0", "0BCDEFGHJKLM", false}, + {"contains excluded digit 1", "1BCDEFGHJKLM", false}, + {"lowercase rejected (caller must ToUpper first)", "abcdefghjklm", false}, + {"empty", "", false}, + {"12-byte utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // 6×2 bytes = 12 bytes, bytes out of charset + {"ascii punctuation", "ABCDEFGHJK.M", false}, + {"whitespace", "ABCDEFGHJK M", false}, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tc.want, isValidAffiliateCodeFormat(tc.in)) + }) + } +} diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go index e3fb2f85..21d9d6e9 100644 --- a/backend/internal/service/auth_oauth_email_flow_test.go +++ b/backend/internal/service/auth_oauth_email_flow_test.go @@ -137,6 +137,7 @@ func newOAuthEmailFlowAuthService( nil, nil, nil, + nil, ) } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index fe0c32f5..08b0f4b7 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -99,6 +99,7 @@ func NewAuthService( emailQueueService *EmailQueueService, promoService *PromoService, defaultSubAssigner DefaultSubscriptionAssigner, + affiliateService *AffiliateService, ) *AuthService { return &AuthService{ entClient: entClient, @@ -111,6 +112,7 @@ func NewAuthService( turnstileService: turnstileService, emailQueueService: emailQueueService, promoService: promoService, + affiliateService: affiliateService, defaultSubAssigner: defaultSubAssigner, } } @@ -122,26 +124,13 @@ func (s *AuthService) EntClient() *dbent.Client { return s.entClient } -func (s *AuthService) SetAffiliateService(affiliateService *AffiliateService) { - if s == nil { - return - } - s.affiliateService = affiliateService -} - // Register 用户注册,返回token和用户 func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { - return s.RegisterWithVerification(ctx, email, password, "", "", "") + return s.RegisterWithVerification(ctx, email, password, "", "", "", "") } // RegisterWithVerification 用户注册(支持邮件验证、优惠码、邀请码和邀请返利码),返回token和用户。 -// affiliateCode 使用可选参数以兼容旧调用方。 -func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string, affiliateCode ...string) (string, *User, error) { - affiliateCodeRaw := "" - if len(affiliateCode) > 0 { - affiliateCodeRaw = affiliateCode[0] - } - +func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode, affiliateCode string) (string, *User, error) { // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册) if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return "", nil, ErrRegDisabled @@ -241,7 +230,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err) } - if code := strings.TrimSpace(affiliateCodeRaw); code != "" { + if code := strings.TrimSpace(affiliateCode); code != "" { if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil { // 邀请返利码绑定失败不影响注册,只记录日志 logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err) diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go index cced842a..ea2308f7 100644 --- a/backend/internal/service/auth_service_email_bind_test.go +++ b/backend/internal/service/auth_service_email_bind_test.go @@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants ( emailSvc = service.NewEmailService(settingRepo, emailCache) } - svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner) + svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil) return svc, repo, client } @@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t }, } emailService := service.NewEmailService(nil, cache) - svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil) + svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil) oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{ ID: 41, diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go index 2233e427..53048b92 100644 --- a/backend/internal/service/auth_service_identity_sync_test.go +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants ( values: settings, }, cfg) - svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner) + svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil) return svc, repo, client } diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index dbd18a20..c1ad6240 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -212,6 +212,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E nil, nil, // promoService nil, // defaultSubAssigner + nil, // affiliateService ) } @@ -243,7 +244,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi }, nil) // 应返回服务不可用错误,而不是允许绕过验证 - _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "") + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "") require.ErrorIs(t, err, ErrServiceUnavailable) } @@ -255,7 +256,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { SettingKeyEmailVerifyEnabled: "true", }, cache) - _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "") + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "") require.ErrorIs(t, err, ErrEmailVerifyRequired) } @@ -269,7 +270,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) { SettingKeyEmailVerifyEnabled: "true", }, cache) - _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "") + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "") require.ErrorIs(t, err, ErrInvalidVerifyCode) require.ErrorContains(t, err, "verify code") } diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go index 477ba1b2..3512822f 100644 --- a/backend/internal/service/auth_service_turnstile_register_test.go +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -54,6 +54,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier nil, // emailQueueService nil, // promoService nil, // defaultSubAssigner + nil, // affiliateService ) } diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index 15f6feeb..aa121e41 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -184,19 +184,12 @@ type PaymentService struct { affiliateService *AffiliateService } -func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService { - svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} +func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService { + svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo, affiliateService: affiliateService} svc.resumeService = psNewPaymentResumeService(configService) return svc } -func (s *PaymentService) SetAffiliateService(affiliateService *AffiliateService) { - if s == nil { - return - } - s.affiliateService = affiliateService -} - // --- Provider Registry --- // EnsureProviders lazily initializes the provider registry on first call. diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index d8a6a332..b1d9aaba 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -391,53 +391,6 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit return svc } -func ProvideAuthService( - entClient *dbent.Client, - userRepo UserRepository, - redeemRepo RedeemCodeRepository, - refreshTokenCache RefreshTokenCache, - cfg *config.Config, - settingService *SettingService, - emailService *EmailService, - turnstileService *TurnstileService, - emailQueueService *EmailQueueService, - promoService *PromoService, - defaultSubAssigner DefaultSubscriptionAssigner, - affiliateService *AffiliateService, -) *AuthService { - svc := NewAuthService( - entClient, - userRepo, - redeemRepo, - refreshTokenCache, - cfg, - settingService, - emailService, - turnstileService, - emailQueueService, - promoService, - defaultSubAssigner, - ) - svc.SetAffiliateService(affiliateService) - return svc -} - -func ProvidePaymentService( - entClient *dbent.Client, - registry *payment.Registry, - loadBalancer payment.LoadBalancer, - redeemService *RedeemService, - subscriptionSvc *SubscriptionService, - configService *PaymentConfigService, - userRepo UserRepository, - groupRepo GroupRepository, - affiliateService *AffiliateService, -) *PaymentService { - svc := NewPaymentService(entClient, registry, loadBalancer, redeemService, subscriptionSvc, configService, userRepo, groupRepo) - svc.SetAffiliateService(affiliateService) - return svc -} - // ProvideBillingCacheService wires BillingCacheService with its RPM dependencies. func ProvideBillingCacheService( cache BillingCache, @@ -454,7 +407,7 @@ func ProvideBillingCacheService( // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services - ProvideAuthService, + NewAuthService, NewUserService, NewAPIKeyService, ProvideAPIKeyAuthCacheInvalidator, @@ -535,7 +488,7 @@ var ProviderSet = wire.NewSet( NewModelPricingResolver, NewAffiliateService, ProvidePaymentConfigService, - ProvidePaymentService, + NewPaymentService, ProvidePaymentOrderExpiryService, ProvideBalanceNotifyService, ProvideChannelMonitorService,