//go:build unit package service import ( "context" "database/sql" "testing" "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/stretchr/testify/require" "entgo.io/ent/dialect" entsql "entgo.io/ent/dialect/sql" _ "modernc.org/sqlite" ) type paymentOrderLifecycleQueryProvider struct { lastQueryTradeNo string queryCalls int responses []*payment.QueryOrderResponse resp *payment.QueryOrderResponse } type paymentOrderLifecycleRedeemRepo struct { codesByCode map[string]*RedeemCode useCalls []struct { id int64 userID int64 } } func (p *paymentOrderLifecycleQueryProvider) Name() string { return "payment-order-lifecycle-query-provider" } func (p *paymentOrderLifecycleQueryProvider) ProviderKey() string { return payment.TypeAlipay } func (p *paymentOrderLifecycleQueryProvider) SupportedTypes() []payment.PaymentType { return []payment.PaymentType{payment.TypeAlipay} } func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { panic("unexpected call") } func (p *paymentOrderLifecycleQueryProvider) QueryOrder(_ context.Context, tradeNo string) (*payment.QueryOrderResponse, error) { p.lastQueryTradeNo = tradeNo p.queryCalls++ if len(p.responses) > 0 { resp := p.responses[0] if len(p.responses) > 1 { p.responses = p.responses[1:] } return resp, nil } return p.resp, nil } func (p *paymentOrderLifecycleQueryProvider) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) { panic("unexpected call") } func (p *paymentOrderLifecycleQueryProvider) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) { panic("unexpected call") } func (r *paymentOrderLifecycleRedeemRepo) Create(context.Context, *RedeemCode) error { panic("unexpected call") } func (r *paymentOrderLifecycleRedeemRepo) CreateBatch(context.Context, []RedeemCode) error { panic("unexpected call") } func (r *paymentOrderLifecycleRedeemRepo) GetByID(_ context.Context, id int64) (*RedeemCode, error) { for _, code := range r.codesByCode { if code.ID != id { continue } cloned := *code return &cloned, nil } return nil, ErrRedeemCodeNotFound } func (r *paymentOrderLifecycleRedeemRepo) GetByCode(_ context.Context, code string) (*RedeemCode, error) { redeemCode, ok := r.codesByCode[code] if !ok { return nil, ErrRedeemCodeNotFound } cloned := *redeemCode return &cloned, nil } func (r *paymentOrderLifecycleRedeemRepo) Update(context.Context, *RedeemCode) error { panic("unexpected call") } func (r *paymentOrderLifecycleRedeemRepo) Delete(context.Context, int64) error { panic("unexpected call") } func (r *paymentOrderLifecycleRedeemRepo) Use(_ context.Context, id, userID int64) error { for code, redeemCode := range r.codesByCode { if redeemCode.ID != id { continue } now := time.Now().UTC() redeemCode.Status = StatusUsed redeemCode.UsedBy = &userID redeemCode.UsedAt = &now r.codesByCode[code] = redeemCode r.useCalls = append(r.useCalls, struct { id int64 userID int64 }{id: id, userID: userID}) return nil } return ErrRedeemCodeNotFound } func (r *paymentOrderLifecycleRedeemRepo) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) { panic("unexpected call") } func (r *paymentOrderLifecycleRedeemRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) { panic("unexpected call") } func (r *paymentOrderLifecycleRedeemRepo) ListByUser(context.Context, int64, int) ([]RedeemCode, error) { panic("unexpected call") } func (r *paymentOrderLifecycleRedeemRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) { panic("unexpected call") } func (r *paymentOrderLifecycleRedeemRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) { panic("unexpected call") } func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) { ctx := context.Background() client := newPaymentOrderLifecycleTestClient(t) user, err := client.User.Create(). SetEmail("checkpaid@example.com"). SetPasswordHash("hash"). SetUsername("checkpaid-user"). Save(ctx) require.NoError(t, err) order, err := client.PaymentOrder.Create(). SetUserID(user.ID). SetUserEmail(user.Email). SetUserName(user.Username). SetAmount(88). SetPayAmount(88). SetFeeRate(0). SetRechargeCode("CHECKPAID-UPSTREAM-TRADE-NO"). SetOutTradeNo("sub2_checkpaid_trade_no_missing"). SetPaymentType(payment.TypeAlipay). SetPaymentTradeNo(""). SetOrderType(payment.OrderTypeBalance). SetStatus(OrderStatusPending). SetExpiresAt(time.Now().Add(time.Hour)). SetClientIP("127.0.0.1"). SetSrcHost("api.example.com"). Save(ctx) require.NoError(t, err) userRepo := &mockUserRepo{ getByIDUser: &User{ ID: user.ID, Email: user.Email, Username: user.Username, Balance: 0, }, } userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error { require.Equal(t, user.ID, id) if userRepo.getByIDUser != nil { userRepo.getByIDUser.Balance += amount } return nil } redeemRepo := &paymentOrderLifecycleRedeemRepo{ codesByCode: map[string]*RedeemCode{ order.RechargeCode: { ID: 1, Code: order.RechargeCode, Type: RedeemTypeBalance, Value: order.Amount, Status: StatusUnused, }, }, } redeemService := NewRedeemService( redeemRepo, userRepo, nil, nil, nil, client, nil, ) registry := payment.NewRegistry() provider := &paymentOrderLifecycleQueryProvider{ resp: &payment.QueryOrderResponse{ TradeNo: "upstream-trade-123", Status: payment.ProviderStatusPaid, Amount: 88, }, } registry.Register(provider) svc := &PaymentService{ entClient: client, registry: registry, redeemService: redeemService, userRepo: userRepo, providersLoaded: true, } got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID) require.NoError(t, err) require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo) require.Equal(t, OrderStatusCompleted, got.Status) require.Equal(t, "upstream-trade-123", got.PaymentTradeNo) reloaded, err := client.PaymentOrder.Get(ctx, order.ID) require.NoError(t, err) require.Equal(t, OrderStatusCompleted, reloaded.Status) require.Equal(t, "upstream-trade-123", reloaded.PaymentTradeNo) require.Equal(t, 88.0, userRepo.getByIDUser.Balance) require.Len(t, redeemRepo.useCalls, 1) require.Equal(t, int64(1), redeemRepo.useCalls[0].id) require.Equal(t, user.ID, redeemRepo.useCalls[0].userID) } func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) { ctx := context.Background() client := newPaymentOrderLifecycleTestClient(t) user, err := client.User.Create(). SetEmail("checkpaid-retry@example.com"). SetPasswordHash("hash"). SetUsername("checkpaid-retry-user"). Save(ctx) require.NoError(t, err) order, err := client.PaymentOrder.Create(). SetUserID(user.ID). SetUserEmail(user.Email). SetUserName(user.Username). SetAmount(88). SetPayAmount(88). SetFeeRate(0). SetRechargeCode("CHECKPAID-UPSTREAM-RETRY"). SetOutTradeNo("sub2_checkpaid_retry_zero_amount"). SetPaymentType(payment.TypeAlipay). SetPaymentTradeNo(""). SetOrderType(payment.OrderTypeBalance). SetStatus(OrderStatusPending). SetExpiresAt(time.Now().Add(time.Hour)). SetClientIP("127.0.0.1"). SetSrcHost("api.example.com"). Save(ctx) require.NoError(t, err) userRepo := &mockUserRepo{ getByIDUser: &User{ ID: user.ID, Email: user.Email, Username: user.Username, Balance: 0, }, } userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error { require.Equal(t, user.ID, id) if userRepo.getByIDUser != nil { userRepo.getByIDUser.Balance += amount } return nil } redeemRepo := &paymentOrderLifecycleRedeemRepo{ codesByCode: map[string]*RedeemCode{ order.RechargeCode: { ID: 1, Code: order.RechargeCode, Type: RedeemTypeBalance, Value: order.Amount, Status: StatusUnused, }, }, } redeemService := NewRedeemService( redeemRepo, userRepo, nil, nil, nil, client, nil, ) registry := payment.NewRegistry() provider := &paymentOrderLifecycleQueryProvider{ responses: []*payment.QueryOrderResponse{ { TradeNo: "upstream-trade-zero", Status: payment.ProviderStatusPaid, Amount: 0, }, { TradeNo: "upstream-trade-retry", Status: payment.ProviderStatusPaid, Amount: 88, }, }, } registry.Register(provider) svc := &PaymentService{ entClient: client, registry: registry, redeemService: redeemService, userRepo: userRepo, providersLoaded: true, } got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID) require.NoError(t, err) require.Equal(t, 2, provider.queryCalls) require.Equal(t, OrderStatusCompleted, got.Status) require.Equal(t, "upstream-trade-retry", got.PaymentTradeNo) } func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) { ctx := context.Background() client := newPaymentOrderLifecycleTestClient(t) user, err := client.User.Create(). SetEmail("checkpaid-zero-amount@example.com"). SetPasswordHash("hash"). SetUsername("checkpaid-zero-amount-user"). Save(ctx) require.NoError(t, err) order, err := client.PaymentOrder.Create(). SetUserID(user.ID). SetUserEmail(user.Email). SetUserName(user.Username). SetAmount(88). SetPayAmount(88). SetFeeRate(0). SetRechargeCode("CHECKPAID-ZERO-AMOUNT"). SetOutTradeNo("sub2_checkpaid_zero_amount"). SetPaymentType(payment.TypeAlipay). SetPaymentTradeNo(""). SetOrderType(payment.OrderTypeBalance). SetStatus(OrderStatusPending). SetExpiresAt(time.Now().Add(time.Hour)). SetClientIP("127.0.0.1"). SetSrcHost("api.example.com"). Save(ctx) require.NoError(t, err) userRepo := &mockUserRepo{ getByIDUser: &User{ ID: user.ID, Email: user.Email, Username: user.Username, Balance: 0, }, } redeemRepo := &paymentOrderLifecycleRedeemRepo{ codesByCode: map[string]*RedeemCode{ order.RechargeCode: { ID: 1, Code: order.RechargeCode, Type: RedeemTypeBalance, Value: order.Amount, Status: StatusUnused, }, }, } redeemService := NewRedeemService( redeemRepo, userRepo, nil, nil, nil, client, nil, ) registry := payment.NewRegistry() provider := &paymentOrderLifecycleQueryProvider{ resp: &payment.QueryOrderResponse{ TradeNo: "upstream-trade-zero", Status: payment.ProviderStatusPaid, Amount: 0, }, } registry.Register(provider) svc := &PaymentService{ entClient: client, registry: registry, redeemService: redeemService, userRepo: userRepo, providersLoaded: true, } got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID) require.NoError(t, err) require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo) require.Equal(t, OrderStatusPending, got.Status) require.Empty(t, got.PaymentTradeNo) reloaded, err := client.PaymentOrder.Get(ctx, order.ID) require.NoError(t, err) require.Equal(t, OrderStatusPending, reloaded.Status) require.Empty(t, reloaded.PaymentTradeNo) require.Equal(t, 0.0, userRepo.getByIDUser.Balance) require.Empty(t, redeemRepo.useCalls) } func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) { ctx := context.Background() client := newPaymentOrderLifecycleTestClient(t) user, err := client.User.Create(). SetEmail("checkpaid-existing-trade@example.com"). SetPasswordHash("hash"). SetUsername("checkpaid-existing-trade-user"). Save(ctx) require.NoError(t, err) order, err := client.PaymentOrder.Create(). SetUserID(user.ID). SetUserEmail(user.Email). SetUserName(user.Username). SetAmount(88). SetPayAmount(88). SetFeeRate(0). SetRechargeCode("CHECKPAID-EXISTING-TRADE-NO"). SetOutTradeNo("sub2_checkpaid_use_out_trade_no"). SetPaymentType(payment.TypeAlipay). SetPaymentTradeNo("upstream-trade-existing"). SetOrderType(payment.OrderTypeBalance). SetStatus(OrderStatusPending). SetExpiresAt(time.Now().Add(time.Hour)). SetClientIP("127.0.0.1"). SetSrcHost("api.example.com"). Save(ctx) require.NoError(t, err) userRepo := &mockUserRepo{ getByIDUser: &User{ ID: user.ID, Email: user.Email, Username: user.Username, Balance: 0, }, } userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error { require.Equal(t, user.ID, id) if userRepo.getByIDUser != nil { userRepo.getByIDUser.Balance += amount } return nil } redeemRepo := &paymentOrderLifecycleRedeemRepo{ codesByCode: map[string]*RedeemCode{ order.RechargeCode: { ID: 1, Code: order.RechargeCode, Type: RedeemTypeBalance, Value: order.Amount, Status: StatusUnused, }, }, } redeemService := NewRedeemService( redeemRepo, userRepo, nil, nil, nil, client, nil, ) registry := payment.NewRegistry() provider := &paymentOrderLifecycleQueryProvider{ resp: &payment.QueryOrderResponse{ TradeNo: "upstream-trade-existing", Status: payment.ProviderStatusPaid, Amount: 88, }, } registry.Register(provider) svc := &PaymentService{ entClient: client, registry: registry, redeemService: redeemService, userRepo: userRepo, providersLoaded: true, } got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID) require.NoError(t, err) require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo) require.Equal(t, "upstream-trade-existing", got.PaymentTradeNo) } func TestPaymentOrderAllowsRegistryFallbackOnlyForLegacyOrdersWithoutPinnedProviderState(t *testing.T) { t.Parallel() require.True(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{ PaymentType: payment.TypeAlipay, })) instanceID := "12" require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{ PaymentType: payment.TypeAlipay, ProviderInstanceID: &instanceID, })) require.False(t, paymentOrderAllowsRegistryFallback(&dbent.PaymentOrder{ PaymentType: payment.TypeAlipay, ProviderSnapshot: map[string]any{ "schema_version": 2, "provider_instance_id": "12", }, })) } func TestPaymentOrderQueryReferenceUsesOutTradeNoForOfficialProviders(t *testing.T) { t.Parallel() order := &dbent.PaymentOrder{ PaymentType: payment.TypeWxpay, OutTradeNo: "sub2_out_trade_no", PaymentTradeNo: "wx-transaction-id", } require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, &paymentOrderLifecycleQueryProvider{})) require.Equal(t, "sub2_out_trade_no", paymentOrderQueryReference(order, paymentFulfillmentTestProvider{ key: payment.TypeWxpay, })) } func newPaymentOrderLifecycleTestClient(t *testing.T) *dbent.Client { t.Helper() db, err := sql.Open("sqlite", "file:payment_order_lifecycle?mode=memory&cache=shared&_fk=1") require.NoError(t, err) t.Cleanup(func() { _ = db.Close() }) _, err = db.Exec("PRAGMA foreign_keys = ON") require.NoError(t, err) drv := entsql.OpenDB(dialect.SQLite, db) client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) t.Cleanup(func() { _ = client.Close() }) return client }