refactor(affiliate): tighten DI and harden inviter code validation
- Drop SetAffiliateService setters and ProvideAuthService / ProvidePaymentService / ProvideUserHandler wrappers in favor of direct Wire constructor injection. AffiliateService has no back-edge to Auth/Payment/User, so the indirection was never required. - Change RegisterWithVerification's variadic affiliateCode to a fixed parameter; adjust all call sites. - Validate aff_code length and charset in BindInviterByCode before any DB lookup, eliminating timing-side-channel and useless DB roundtrips on malformed input. - Make affiliate cache invalidation synchronous; surface Redis errors via the project logger instead of swallowing them in a detached goroutine. - Add an integration test guarding cross-layer tx propagation in AccrueQuota and a unit test pinning the aff_code format rules.
This commit is contained in:
@@ -2210,6 +2210,7 @@ CREATE TABLE IF NOT EXISTS user_avatars (
|
||||
nil,
|
||||
nil,
|
||||
options.defaultSubAssigner,
|
||||
nil,
|
||||
)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||
var totpSvc *service.TotpService
|
||||
|
||||
@@ -35,7 +35,7 @@ func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := &AuthHandler{authService: authService}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@@ -1399,6 +1399,7 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
return &AuthHandler{
|
||||
|
||||
@@ -117,7 +117,7 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -215,7 +215,7 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing
|
||||
require.NoError(t, err)
|
||||
|
||||
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
|
||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -302,7 +302,7 @@ func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *t
|
||||
require.NoError(t, err)
|
||||
|
||||
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
|
||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -342,7 +342,7 @@ func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
|
||||
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
|
||||
h := NewPaymentHandler(paymentSvc, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -28,22 +27,17 @@ func NewUserHandler(
|
||||
authService *service.AuthService,
|
||||
emailService *service.EmailService,
|
||||
emailCache service.EmailCache,
|
||||
affiliateService *service.AffiliateService,
|
||||
) *UserHandler {
|
||||
return &UserHandler{
|
||||
userService: userService,
|
||||
authService: authService,
|
||||
emailService: emailService,
|
||||
emailCache: emailCache,
|
||||
userService: userService,
|
||||
authService: authService,
|
||||
emailService: emailService,
|
||||
emailCache: emailCache,
|
||||
affiliateService: affiliateService,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *UserHandler) SetAffiliateService(affiliateService *service.AffiliateService) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
h.affiliateService = affiliateService
|
||||
}
|
||||
|
||||
// ChangePasswordRequest represents the change password request payload
|
||||
type ChangePasswordRequest struct {
|
||||
OldPassword string `json:"old_password" binding:"required"`
|
||||
@@ -168,13 +162,6 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
response.Success(c, profileResp)
|
||||
}
|
||||
|
||||
func (h *UserHandler) affiliateServiceOrErr() (*service.AffiliateService, error) {
|
||||
if h == nil || h.affiliateService == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
|
||||
}
|
||||
return h.affiliateService, nil
|
||||
}
|
||||
|
||||
// GetAffiliate returns the current user's affiliate details.
|
||||
// GET /api/v1/user/aff
|
||||
func (h *UserHandler) GetAffiliate(c *gin.Context) {
|
||||
@@ -184,13 +171,7 @@ func (h *UserHandler) GetAffiliate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
affiliateSvc, err := h.affiliateServiceOrErr()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
detail, err := affiliateSvc.GetAffiliateDetail(c.Request.Context(), subject.UserID)
|
||||
detail, err := h.affiliateService.GetAffiliateDetail(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -207,13 +188,7 @@ func (h *UserHandler) TransferAffiliateQuota(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
affiliateSvc, err := h.affiliateServiceOrErr()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
transferred, balance, err := affiliateSvc.TransferAffiliateQuota(c.Request.Context(), subject.UserID)
|
||||
transferred, balance, err := h.affiliateService.TransferAffiliateQuota(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -142,7 +142,7 @@ func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -200,7 +200,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -283,7 +283,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -362,7 +362,7 @@ func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIde
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -511,8 +511,8 @@ func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
|
||||
},
|
||||
}
|
||||
emailService := service.NewEmailService(nil, emailCache)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -566,7 +566,7 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -625,8 +625,8 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -668,8 +668,8 @@ func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
@@ -712,8 +712,8 @@ func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t
|
||||
},
|
||||
}
|
||||
emailService := service.NewEmailService(nil, emailCache)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
|
||||
authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -750,7 +750,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
|
||||
|
||||
body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@@ -80,18 +80,6 @@ func ProvideSettingHandler(settingService *service.SettingService, buildInfo Bui
|
||||
return NewSettingHandler(settingService, buildInfo.Version)
|
||||
}
|
||||
|
||||
func ProvideUserHandler(
|
||||
userService *service.UserService,
|
||||
authService *service.AuthService,
|
||||
emailService *service.EmailService,
|
||||
emailCache service.EmailCache,
|
||||
affiliateService *service.AffiliateService,
|
||||
) *UserHandler {
|
||||
handler := NewUserHandler(userService, authService, emailService, emailCache)
|
||||
handler.SetAffiliateService(affiliateService)
|
||||
return handler
|
||||
}
|
||||
|
||||
// ProvideHandlers creates the Handlers struct
|
||||
func ProvideHandlers(
|
||||
authHandler *AuthHandler,
|
||||
@@ -137,7 +125,7 @@ func ProvideHandlers(
|
||||
var ProviderSet = wire.NewSet(
|
||||
// Top-level handlers
|
||||
NewAuthHandler,
|
||||
ProvideUserHandler,
|
||||
NewUserHandler,
|
||||
NewAPIKeyHandler,
|
||||
NewUsageHandler,
|
||||
NewRedeemHandler,
|
||||
|
||||
Reference in New Issue
Block a user