refactor(affiliate): tighten DI and harden inviter code validation
- Drop SetAffiliateService setters and ProvideAuthService / ProvidePaymentService / ProvideUserHandler wrappers in favor of direct Wire constructor injection. AffiliateService has no back-edge to Auth/Payment/User, so the indirection was never required. - Change RegisterWithVerification's variadic affiliateCode to a fixed parameter; adjust all call sites. - Validate aff_code length and charset in BindInviterByCode before any DB lookup, eliminating timing-side-channel and useless DB roundtrips on malformed input. - Make affiliate cache invalidation synchronous; surface Redis errors via the project logger instead of swallowing them in a detached goroutine. - Add an integration test guarding cross-layer tx propagation in AccrueQuota and a unit test pinning the aff_code format rules.
This commit is contained in:
@@ -33,7 +33,7 @@ func main() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
userRepo := repository.NewUserRepository(client, sqlDB)
|
userRepo := repository.NewUserRepository(client, sqlDB)
|
||||||
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||||
affiliateRepository := repository.NewAffiliateRepository(client, db)
|
affiliateRepository := repository.NewAffiliateRepository(client, db)
|
||||||
affiliateService := service.NewAffiliateService(affiliateRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCacheService)
|
affiliateService := service.NewAffiliateService(affiliateRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCacheService)
|
||||||
authService := service.ProvideAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
|
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
|
||||||
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
|
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||||
redeemCache := repository.NewRedeemCache(redisClient)
|
redeemCache := repository.NewRedeemCache(redisClient)
|
||||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
@@ -82,7 +82,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
totpCache := repository.NewTotpCache(redisClient)
|
totpCache := repository.NewTotpCache(redisClient)
|
||||||
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
|
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
|
||||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
|
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
|
||||||
userHandler := handler.ProvideUserHandler(userService, authService, emailService, emailCache, affiliateService)
|
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||||
@@ -197,7 +197,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
|
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
|
||||||
registry := payment.ProvideRegistry()
|
registry := payment.ProvideRegistry()
|
||||||
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
||||||
paymentService := service.ProvidePaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
|
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
|
||||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
||||||
opsHandler := admin.NewOpsHandler(opsService)
|
opsHandler := admin.NewOpsHandler(opsService)
|
||||||
updateCache := repository.NewUpdateCache(redisClient)
|
updateCache := repository.NewUpdateCache(redisClient)
|
||||||
|
|||||||
@@ -2210,6 +2210,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
options.defaultSubAssigner,
|
options.defaultSubAssigner,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||||
var totpSvc *service.TotpService
|
var totpSvc *service.TotpService
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
|
|||||||
ExpireHour: 1,
|
ExpireHour: 1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
|
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||||
handler := &AuthHandler{authService: authService}
|
handler := &AuthHandler{authService: authService}
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|||||||
@@ -1399,6 +1399,7 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
return &AuthHandler{
|
return &AuthHandler{
|
||||||
|
|||||||
@@ -117,7 +117,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
|
|||||||
Save(context.Background())
|
Save(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
|
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -215,7 +215,7 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
|
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
|
||||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
|
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
|
||||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -302,7 +302,7 @@ func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *t
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
|
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
|
||||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
|
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
|
||||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -342,7 +342,7 @@ func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
|
|||||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||||
t.Cleanup(func() { _ = client.Close() })
|
t.Cleanup(func() { _ = client.Close() })
|
||||||
|
|
||||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
|
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
|
||||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@@ -28,22 +27,17 @@ func NewUserHandler(
|
|||||||
authService *service.AuthService,
|
authService *service.AuthService,
|
||||||
emailService *service.EmailService,
|
emailService *service.EmailService,
|
||||||
emailCache service.EmailCache,
|
emailCache service.EmailCache,
|
||||||
|
affiliateService *service.AffiliateService,
|
||||||
) *UserHandler {
|
) *UserHandler {
|
||||||
return &UserHandler{
|
return &UserHandler{
|
||||||
userService: userService,
|
userService: userService,
|
||||||
authService: authService,
|
authService: authService,
|
||||||
emailService: emailService,
|
emailService: emailService,
|
||||||
emailCache: emailCache,
|
emailCache: emailCache,
|
||||||
|
affiliateService: affiliateService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *UserHandler) SetAffiliateService(affiliateService *service.AffiliateService) {
|
|
||||||
if h == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.affiliateService = affiliateService
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChangePasswordRequest represents the change password request payload
|
// ChangePasswordRequest represents the change password request payload
|
||||||
type ChangePasswordRequest struct {
|
type ChangePasswordRequest struct {
|
||||||
OldPassword string `json:"old_password" binding:"required"`
|
OldPassword string `json:"old_password" binding:"required"`
|
||||||
@@ -168,13 +162,6 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
|||||||
response.Success(c, profileResp)
|
response.Success(c, profileResp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *UserHandler) affiliateServiceOrErr() (*service.AffiliateService, error) {
|
|
||||||
if h == nil || h.affiliateService == nil {
|
|
||||||
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
|
||||||
}
|
|
||||||
return h.affiliateService, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAffiliate returns the current user's affiliate details.
|
// GetAffiliate returns the current user's affiliate details.
|
||||||
// GET /api/v1/user/aff
|
// GET /api/v1/user/aff
|
||||||
func (h *UserHandler) GetAffiliate(c *gin.Context) {
|
func (h *UserHandler) GetAffiliate(c *gin.Context) {
|
||||||
@@ -184,13 +171,7 @@ func (h *UserHandler) GetAffiliate(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
affiliateSvc, err := h.affiliateServiceOrErr()
|
detail, err := h.affiliateService.GetAffiliateDetail(c.Request.Context(), subject.UserID)
|
||||||
if err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
detail, err := affiliateSvc.GetAffiliateDetail(c.Request.Context(), subject.UserID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
@@ -207,13 +188,7 @@ func (h *UserHandler) TransferAffiliateQuota(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
affiliateSvc, err := h.affiliateServiceOrErr()
|
transferred, balance, err := h.affiliateService.TransferAffiliateQuota(c.Request.Context(), subject.UserID)
|
||||||
if err != nil {
|
|
||||||
response.ErrorFrom(c, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
transferred, balance, err := affiliateSvc.TransferAffiliateQuota(c.Request.Context(), subject.UserID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
|
|||||||
Status: service.StatusActive,
|
Status: service.StatusActive,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||||
|
|
||||||
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
|
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -200,7 +200,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
@@ -283,7 +283,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
@@ -362,7 +362,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
@@ -511,8 +511,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
emailService := service.NewEmailService(nil, emailCache)
|
emailService := service.NewEmailService(nil, emailCache)
|
||||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
|
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
|
||||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
|
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||||
|
|
||||||
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
|
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -566,7 +566,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
@@ -625,8 +625,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure
|
|||||||
ExpireHour: 1,
|
ExpireHour: 1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
|
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
|
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
@@ -668,8 +668,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *
|
|||||||
ExpireHour: 1,
|
ExpireHour: 1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
|
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
|
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(recorder)
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
@@ -712,8 +712,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
emailService := service.NewEmailService(nil, emailCache)
|
emailService := service.NewEmailService(nil, emailCache)
|
||||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
|
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
|
||||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
|
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||||
|
|
||||||
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
|
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -750,7 +750,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
|
|||||||
Status: service.StatusActive,
|
Status: service.StatusActive,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||||
|
|
||||||
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
|
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|||||||
@@ -80,18 +80,6 @@ func ProvideSettingHandler(settingService *service.SettingService, buildInfo Bui
|
|||||||
return NewSettingHandler(settingService, buildInfo.Version)
|
return NewSettingHandler(settingService, buildInfo.Version)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProvideUserHandler(
|
|
||||||
userService *service.UserService,
|
|
||||||
authService *service.AuthService,
|
|
||||||
emailService *service.EmailService,
|
|
||||||
emailCache service.EmailCache,
|
|
||||||
affiliateService *service.AffiliateService,
|
|
||||||
) *UserHandler {
|
|
||||||
handler := NewUserHandler(userService, authService, emailService, emailCache)
|
|
||||||
handler.SetAffiliateService(affiliateService)
|
|
||||||
return handler
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProvideHandlers creates the Handlers struct
|
// ProvideHandlers creates the Handlers struct
|
||||||
func ProvideHandlers(
|
func ProvideHandlers(
|
||||||
authHandler *AuthHandler,
|
authHandler *AuthHandler,
|
||||||
@@ -137,7 +125,7 @@ func ProvideHandlers(
|
|||||||
var ProviderSet = wire.NewSet(
|
var ProviderSet = wire.NewSet(
|
||||||
// Top-level handlers
|
// Top-level handlers
|
||||||
NewAuthHandler,
|
NewAuthHandler,
|
||||||
ProvideUserHandler,
|
NewUserHandler,
|
||||||
NewAPIKeyHandler,
|
NewAPIKeyHandler,
|
||||||
NewUsageHandler,
|
NewUsageHandler,
|
||||||
NewRedeemHandler,
|
NewRedeemHandler,
|
||||||
|
|||||||
@@ -80,6 +80,76 @@ VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
|
|||||||
require.Equal(t, 1, ledgerCount)
|
require.Equal(t, 1, ledgerCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
|
||||||
|
// cross-layer tx propagation invariant: when AccrueQuota is called with a ctx
|
||||||
|
// that already carries a transaction (via dbent.NewTxContext), repo.withTx
|
||||||
|
// must reuse that tx rather than opening a nested one. If this invariant
|
||||||
|
// breaks, AccrueQuota would commit independently and survive a rollback of
|
||||||
|
// the outer tx, which would violate payment_fulfillment's all-or-nothing
|
||||||
|
// semantics.
|
||||||
|
func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
outerTx, err := integrationEntClient.Tx(ctx)
|
||||||
|
require.NoError(t, err, "begin outer tx")
|
||||||
|
// Defensive cleanup: if any require.* below fires before the explicit
|
||||||
|
// Rollback, this prevents the tx from leaking until container teardown.
|
||||||
|
// Rollback is idempotent at the driver level (extra rollback returns an
|
||||||
|
// error we ignore).
|
||||||
|
t.Cleanup(func() { _ = outerTx.Rollback() })
|
||||||
|
client := outerTx.Client()
|
||||||
|
txCtx := dbent.NewTxContext(ctx, outerTx)
|
||||||
|
|
||||||
|
inviter := mustCreateUser(t, client, &service.User{
|
||||||
|
Email: fmt.Sprintf("affiliate-inviter-%d@example.com", time.Now().UnixNano()),
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Role: service.RoleUser,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Concurrency: 5,
|
||||||
|
})
|
||||||
|
invitee := mustCreateUser(t, client, &service.User{
|
||||||
|
Email: fmt.Sprintf("affiliate-invitee-%d@example.com", time.Now().UnixNano()+1),
|
||||||
|
PasswordHash: "hash",
|
||||||
|
Role: service.RoleUser,
|
||||||
|
Status: service.StatusActive,
|
||||||
|
Concurrency: 5,
|
||||||
|
})
|
||||||
|
|
||||||
|
repo := NewAffiliateRepository(client, integrationDB)
|
||||||
|
_, err = repo.EnsureUserAffiliate(txCtx, inviter.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = repo.EnsureUserAffiliate(txCtx, invitee.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
bound, err := repo.BindInviter(txCtx, invitee.ID, inviter.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, bound, "invitee must bind to inviter")
|
||||||
|
|
||||||
|
applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, applied, "AccrueQuota must report applied=true")
|
||||||
|
|
||||||
|
// Visible inside the outer tx.
|
||||||
|
innerQuota := querySingleFloat(t, txCtx, client,
|
||||||
|
"SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", inviter.ID)
|
||||||
|
require.InDelta(t, 3.5, innerQuota, 1e-9)
|
||||||
|
|
||||||
|
// Roll back the outer tx; if AccrueQuota had opened its own inner tx and
|
||||||
|
// committed it, the rows would still be visible to the global client.
|
||||||
|
require.NoError(t, outerTx.Rollback())
|
||||||
|
|
||||||
|
rows, err := integrationEntClient.QueryContext(ctx,
|
||||||
|
"SELECT COUNT(*) FROM user_affiliates WHERE user_id IN ($1, $2)",
|
||||||
|
inviter.ID, invitee.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
require.True(t, rows.Next())
|
||||||
|
var postRollbackCount int
|
||||||
|
require.NoError(t, rows.Scan(&postRollbackCount))
|
||||||
|
require.Equal(t, 0, postRollbackCount,
|
||||||
|
"AccrueQuota must propagate the outer tx — found persisted rows after rollback")
|
||||||
|
}
|
||||||
|
|
||||||
func TestAffiliateRepository_TransferQuotaToBalance_EmptyQuota(t *testing.T) {
|
func TestAffiliateRepository_TransferQuotaToBalance_EmptyQuota(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
tx := testEntTx(t)
|
tx := testEntTx(t)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
|||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||||
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
admin := &service.User{
|
admin := &service.User{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
|
|||||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||||
|
|
||||||
userRepo := &stubJWTUserRepo{users: users}
|
userRepo := &stubJWTUserRepo{users: users}
|
||||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||||
|
|
||||||
@@ -143,7 +143,7 @@ func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
|
|||||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||||
|
|
||||||
userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
|
userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
|
||||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||||
toucher := &recordingActivityToucher{}
|
toucher := &recordingActivityToucher{}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -20,8 +21,32 @@ var (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
affiliateInviteesLimit = 100
|
affiliateInviteesLimit = 100
|
||||||
|
// affiliateCodeFormatLength must stay in sync with repository.affiliateCodeLength.
|
||||||
|
affiliateCodeFormatLength = 12
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// affiliateCodeValidChar is a 256-entry lookup table mirroring the charset used
|
||||||
|
// by the repository's generateAffiliateCode (A-Z minus I/O, digits 2-9).
|
||||||
|
var affiliateCodeValidChar = func() [256]bool {
|
||||||
|
var tbl [256]bool
|
||||||
|
for _, c := range []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789") {
|
||||||
|
tbl[c] = true
|
||||||
|
}
|
||||||
|
return tbl
|
||||||
|
}()
|
||||||
|
|
||||||
|
func isValidAffiliateCodeFormat(code string) bool {
|
||||||
|
if len(code) != affiliateCodeFormatLength {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := 0; i < len(code); i++ {
|
||||||
|
if !affiliateCodeValidChar[code[i]] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
type AffiliateSummary struct {
|
type AffiliateSummary struct {
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
AffCode string `json:"aff_code"`
|
AffCode string `json:"aff_code"`
|
||||||
@@ -110,6 +135,9 @@ func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64,
|
|||||||
if code == "" {
|
if code == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if !isValidAffiliateCodeFormat(code) {
|
||||||
|
return ErrAffiliateCodeInvalid
|
||||||
|
}
|
||||||
if s == nil || s.repo == nil {
|
if s == nil || s.repo == nil {
|
||||||
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||||
}
|
}
|
||||||
@@ -279,10 +307,8 @@ func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID
|
|||||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
}
|
}
|
||||||
if s.billingCacheService != nil {
|
if s.billingCacheService != nil {
|
||||||
go func() {
|
if err := s.billingCacheService.InvalidateUserBalance(ctx, userID); err != nil {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
logger.LegacyPrintf("service.affiliate", "[Affiliate] Failed to invalidate billing cache for user %d: %v", userID, err)
|
||||||
defer cancel()
|
}
|
||||||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,3 +57,35 @@ func TestMaskEmail(t *testing.T) {
|
|||||||
require.Equal(t, "x***@d***", maskEmail("x@domain"))
|
require.Equal(t, "x***@d***", maskEmail("x@domain"))
|
||||||
require.Equal(t, "", maskEmail(""))
|
require.Equal(t, "", maskEmail(""))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsValidAffiliateCodeFormat(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"valid canonical", "ABCDEFGHJKLM", true},
|
||||||
|
{"valid all digits 2-9", "234567892345", true},
|
||||||
|
{"valid mixed", "A2B3C4D5E6F7", true},
|
||||||
|
{"too short", "ABCDEFGHJKL", false},
|
||||||
|
{"too long", "ABCDEFGHJKLMN", false},
|
||||||
|
{"contains excluded letter I", "IBCDEFGHJKLM", false},
|
||||||
|
{"contains excluded letter O", "OBCDEFGHJKLM", false},
|
||||||
|
{"contains excluded digit 0", "0BCDEFGHJKLM", false},
|
||||||
|
{"contains excluded digit 1", "1BCDEFGHJKLM", false},
|
||||||
|
{"lowercase rejected (caller must ToUpper first)", "abcdefghjklm", false},
|
||||||
|
{"empty", "", false},
|
||||||
|
{"12-byte utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // 6×2 bytes = 12 bytes, bytes out of charset
|
||||||
|
{"ascii punctuation", "ABCDEFGHJK.M", false},
|
||||||
|
{"whitespace", "ABCDEFGHJK M", false},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
require.Equal(t, tc.want, isValidAffiliateCodeFormat(tc.in))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -137,6 +137,7 @@ func newOAuthEmailFlowAuthService(
|
|||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ func NewAuthService(
|
|||||||
emailQueueService *EmailQueueService,
|
emailQueueService *EmailQueueService,
|
||||||
promoService *PromoService,
|
promoService *PromoService,
|
||||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||||
|
affiliateService *AffiliateService,
|
||||||
) *AuthService {
|
) *AuthService {
|
||||||
return &AuthService{
|
return &AuthService{
|
||||||
entClient: entClient,
|
entClient: entClient,
|
||||||
@@ -111,6 +112,7 @@ func NewAuthService(
|
|||||||
turnstileService: turnstileService,
|
turnstileService: turnstileService,
|
||||||
emailQueueService: emailQueueService,
|
emailQueueService: emailQueueService,
|
||||||
promoService: promoService,
|
promoService: promoService,
|
||||||
|
affiliateService: affiliateService,
|
||||||
defaultSubAssigner: defaultSubAssigner,
|
defaultSubAssigner: defaultSubAssigner,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -122,26 +124,13 @@ func (s *AuthService) EntClient() *dbent.Client {
|
|||||||
return s.entClient
|
return s.entClient
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthService) SetAffiliateService(affiliateService *AffiliateService) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.affiliateService = affiliateService
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register 用户注册,返回token和用户
|
// Register 用户注册,返回token和用户
|
||||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||||
return s.RegisterWithVerification(ctx, email, password, "", "", "")
|
return s.RegisterWithVerification(ctx, email, password, "", "", "", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterWithVerification 用户注册(支持邮件验证、优惠码、邀请码和邀请返利码),返回token和用户。
|
// RegisterWithVerification 用户注册(支持邮件验证、优惠码、邀请码和邀请返利码),返回token和用户。
|
||||||
// affiliateCode 使用可选参数以兼容旧调用方。
|
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode, affiliateCode string) (string, *User, error) {
|
||||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string, affiliateCode ...string) (string, *User, error) {
|
|
||||||
affiliateCodeRaw := ""
|
|
||||||
if len(affiliateCode) > 0 {
|
|
||||||
affiliateCodeRaw = affiliateCode[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||||
return "", nil, ErrRegDisabled
|
return "", nil, ErrRegDisabled
|
||||||
@@ -241,7 +230,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
|||||||
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
|
if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
|
||||||
}
|
}
|
||||||
if code := strings.TrimSpace(affiliateCodeRaw); code != "" {
|
if code := strings.TrimSpace(affiliateCode); code != "" {
|
||||||
if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil {
|
if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil {
|
||||||
// 邀请返利码绑定失败不影响注册,只记录日志
|
// 邀请返利码绑定失败不影响注册,只记录日志
|
||||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err)
|
logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err)
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
|
|||||||
emailSvc = service.NewEmailService(settingRepo, emailCache)
|
emailSvc = service.NewEmailService(settingRepo, emailCache)
|
||||||
}
|
}
|
||||||
|
|
||||||
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
|
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil)
|
||||||
return svc, repo, client
|
return svc, repo, client
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -467,7 +467,7 @@ func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *t
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
emailService := service.NewEmailService(nil, cache)
|
emailService := service.NewEmailService(nil, cache)
|
||||||
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil)
|
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
|
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
|
||||||
ID: 41,
|
ID: 41,
|
||||||
|
|||||||
@@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
|
|||||||
values: settings,
|
values: settings,
|
||||||
}, cfg)
|
}, cfg)
|
||||||
|
|
||||||
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner)
|
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil)
|
||||||
return svc, repo, client
|
return svc, repo, client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -212,6 +212,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
|||||||
nil,
|
nil,
|
||||||
nil, // promoService
|
nil, // promoService
|
||||||
nil, // defaultSubAssigner
|
nil, // defaultSubAssigner
|
||||||
|
nil, // affiliateService
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,7 +244,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
|
|||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
// 应返回服务不可用错误,而不是允许绕过验证
|
// 应返回服务不可用错误,而不是允许绕过验证
|
||||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
|
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "")
|
||||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -255,7 +256,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
|||||||
SettingKeyEmailVerifyEnabled: "true",
|
SettingKeyEmailVerifyEnabled: "true",
|
||||||
}, cache)
|
}, cache)
|
||||||
|
|
||||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
|
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "")
|
||||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -269,7 +270,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
|
|||||||
SettingKeyEmailVerifyEnabled: "true",
|
SettingKeyEmailVerifyEnabled: "true",
|
||||||
}, cache)
|
}, cache)
|
||||||
|
|
||||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
|
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "")
|
||||||
require.ErrorIs(t, err, ErrInvalidVerifyCode)
|
require.ErrorIs(t, err, ErrInvalidVerifyCode)
|
||||||
require.ErrorContains(t, err, "verify code")
|
require.ErrorContains(t, err, "verify code")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
|
|||||||
nil, // emailQueueService
|
nil, // emailQueueService
|
||||||
nil, // promoService
|
nil, // promoService
|
||||||
nil, // defaultSubAssigner
|
nil, // defaultSubAssigner
|
||||||
|
nil, // affiliateService
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -184,19 +184,12 @@ type PaymentService struct {
|
|||||||
affiliateService *AffiliateService
|
affiliateService *AffiliateService
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
|
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository, affiliateService *AffiliateService) *PaymentService {
|
||||||
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
|
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo, affiliateService: affiliateService}
|
||||||
svc.resumeService = psNewPaymentResumeService(configService)
|
svc.resumeService = psNewPaymentResumeService(configService)
|
||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PaymentService) SetAffiliateService(affiliateService *AffiliateService) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.affiliateService = affiliateService
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Provider Registry ---
|
// --- Provider Registry ---
|
||||||
|
|
||||||
// EnsureProviders lazily initializes the provider registry on first call.
|
// EnsureProviders lazily initializes the provider registry on first call.
|
||||||
|
|||||||
@@ -391,53 +391,6 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit
|
|||||||
return svc
|
return svc
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProvideAuthService(
|
|
||||||
entClient *dbent.Client,
|
|
||||||
userRepo UserRepository,
|
|
||||||
redeemRepo RedeemCodeRepository,
|
|
||||||
refreshTokenCache RefreshTokenCache,
|
|
||||||
cfg *config.Config,
|
|
||||||
settingService *SettingService,
|
|
||||||
emailService *EmailService,
|
|
||||||
turnstileService *TurnstileService,
|
|
||||||
emailQueueService *EmailQueueService,
|
|
||||||
promoService *PromoService,
|
|
||||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
|
||||||
affiliateService *AffiliateService,
|
|
||||||
) *AuthService {
|
|
||||||
svc := NewAuthService(
|
|
||||||
entClient,
|
|
||||||
userRepo,
|
|
||||||
redeemRepo,
|
|
||||||
refreshTokenCache,
|
|
||||||
cfg,
|
|
||||||
settingService,
|
|
||||||
emailService,
|
|
||||||
turnstileService,
|
|
||||||
emailQueueService,
|
|
||||||
promoService,
|
|
||||||
defaultSubAssigner,
|
|
||||||
)
|
|
||||||
svc.SetAffiliateService(affiliateService)
|
|
||||||
return svc
|
|
||||||
}
|
|
||||||
|
|
||||||
func ProvidePaymentService(
|
|
||||||
entClient *dbent.Client,
|
|
||||||
registry *payment.Registry,
|
|
||||||
loadBalancer payment.LoadBalancer,
|
|
||||||
redeemService *RedeemService,
|
|
||||||
subscriptionSvc *SubscriptionService,
|
|
||||||
configService *PaymentConfigService,
|
|
||||||
userRepo UserRepository,
|
|
||||||
groupRepo GroupRepository,
|
|
||||||
affiliateService *AffiliateService,
|
|
||||||
) *PaymentService {
|
|
||||||
svc := NewPaymentService(entClient, registry, loadBalancer, redeemService, subscriptionSvc, configService, userRepo, groupRepo)
|
|
||||||
svc.SetAffiliateService(affiliateService)
|
|
||||||
return svc
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProvideBillingCacheService wires BillingCacheService with its RPM dependencies.
|
// ProvideBillingCacheService wires BillingCacheService with its RPM dependencies.
|
||||||
func ProvideBillingCacheService(
|
func ProvideBillingCacheService(
|
||||||
cache BillingCache,
|
cache BillingCache,
|
||||||
@@ -454,7 +407,7 @@ func ProvideBillingCacheService(
|
|||||||
// ProviderSet is the Wire provider set for all services
|
// ProviderSet is the Wire provider set for all services
|
||||||
var ProviderSet = wire.NewSet(
|
var ProviderSet = wire.NewSet(
|
||||||
// Core services
|
// Core services
|
||||||
ProvideAuthService,
|
NewAuthService,
|
||||||
NewUserService,
|
NewUserService,
|
||||||
NewAPIKeyService,
|
NewAPIKeyService,
|
||||||
ProvideAPIKeyAuthCacheInvalidator,
|
ProvideAPIKeyAuthCacheInvalidator,
|
||||||
@@ -535,7 +488,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewModelPricingResolver,
|
NewModelPricingResolver,
|
||||||
NewAffiliateService,
|
NewAffiliateService,
|
||||||
ProvidePaymentConfigService,
|
ProvidePaymentConfigService,
|
||||||
ProvidePaymentService,
|
NewPaymentService,
|
||||||
ProvidePaymentOrderExpiryService,
|
ProvidePaymentOrderExpiryService,
|
||||||
ProvideBalanceNotifyService,
|
ProvideBalanceNotifyService,
|
||||||
ProvideChannelMonitorService,
|
ProvideChannelMonitorService,
|
||||||
|
|||||||
Reference in New Issue
Block a user