Files
sub2api/backend/internal/service/payment_resume_service_test.go

382 lines
12 KiB
Go

//go:build unit
package service
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"net/url"
"strconv"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
// TestNormalizeVisibleMethods feeds NormalizeVisibleMethods a list that mixes
// "_direct" variants, a whitespace-padded entry and plain method names, and
// expects exactly ["alipay", "wxpay", "stripe"] back, in that order.
func TestNormalizeVisibleMethods(t *testing.T) {
	t.Parallel()

	input := []string{"alipay_direct", "alipay", " wxpay_direct ", "wxpay", "stripe"}
	expected := []string{"alipay", "wxpay", "stripe"}

	actual := NormalizeVisibleMethods(input)

	// Check the length first so the element comparison below cannot index out of range.
	if len(actual) != len(expected) {
		t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(actual), len(expected), actual)
	}
	for idx, wantMethod := range expected {
		if gotMethod := actual[idx]; gotMethod != wantMethod {
			t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", idx, gotMethod, wantMethod, actual)
		}
	}
}
// TestNormalizePaymentSource is a table-driven check of NormalizePaymentSource:
// an empty input is expected to yield PaymentSourceHostedRedirect, the
// "wechat_in_app" alias is expected to yield PaymentSourceWechatInAppResume,
// and the canonical value is expected to come back unchanged.
func TestNormalizePaymentSource(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name   string
		input  string
		expect string
	}{
		{name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect},
		{name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume},
		{name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume},
	}
	for i := range cases {
		// Copy the row into a per-iteration variable. The subtest calls
		// t.Parallel, so its body runs after the loop has moved on; on
		// toolchains older than Go 1.22 a captured range variable is shared
		// across iterations and every subtest would see the last row.
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			if got := NormalizePaymentSource(tc.input); got != tc.expect {
				t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tc.input, got, tc.expect)
			}
		})
	}
}
// TestCanonicalizeReturnURL expects CanonicalizeReturnURL to accept an absolute
// https URL and return it with the query string kept and the "#a" fragment removed.
func TestCanonicalizeReturnURL(t *testing.T) {
	t.Parallel()

	const (
		rawURL = "https://example.com/pay/result?b=2#a"
		want   = "https://example.com/pay/result?b=2"
	)

	got, err := CanonicalizeReturnURL(rawURL)
	if err != nil {
		t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
	}
	if got != want {
		t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, want)
	}
}
// TestCanonicalizeReturnURLRejectsRelativeURL expects CanonicalizeReturnURL to
// return an error for a path-only URL ("/payment/result").
func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
	t.Parallel()

	_, err := CanonicalizeReturnURL("/payment/result")
	if err != nil {
		// Rejected, as required.
		return
	}
	t.Fatal("CanonicalizeReturnURL should reject relative URLs")
}
// TestBuildPaymentReturnURL builds a return URL from a base that already has a
// query parameter and a fragment, then expects the result to parse as a URL
// with no fragment, the original "from" parameter intact, and "order_id",
// "resume_token" and "status" set to "42", "resume-token" and "success".
func TestBuildPaymentReturnURL(t *testing.T) {
	t.Parallel()

	rawURL, err := buildPaymentReturnURL("https://example.com/payment/result?from=checkout#fragment", 42, "resume-token")
	if err != nil {
		t.Fatalf("buildPaymentReturnURL returned error: %v", err)
	}
	u, err := url.Parse(rawURL)
	if err != nil {
		t.Fatalf("url.Parse returned error: %v", err)
	}
	if fragment := u.Fragment; fragment != "" {
		t.Fatalf("buildPaymentReturnURL should strip fragments, got %q", fragment)
	}

	// Decode the query string once and look parameters up through the bound getter.
	param := u.Query().Get
	if from := param("from"); from != "checkout" {
		t.Fatalf("expected original query to be preserved, got %q", from)
	}
	if orderID := param("order_id"); orderID != strconv.FormatInt(42, 10) {
		t.Fatalf("order_id = %q", orderID)
	}
	if resumeToken := param("resume_token"); resumeToken != "resume-token" {
		t.Fatalf("resume_token = %q", resumeToken)
	}
	if status := param("status"); status != "success" {
		t.Fatalf("status = %q", status)
	}
}
// TestBuildPaymentReturnURLEmptyBase expects buildPaymentReturnURL to return an
// empty string and no error when the base URL is empty.
func TestBuildPaymentReturnURLEmptyBase(t *testing.T) {
	t.Parallel()

	result, err := buildPaymentReturnURL("", 42, "resume-token")
	switch {
	case err != nil:
		t.Fatalf("buildPaymentReturnURL returned error: %v", err)
	case result != "":
		t.Fatalf("buildPaymentReturnURL = %q, want empty string", result)
	}
}
func TestPaymentResumeTokenRoundTrip(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := svc.CreateToken(ResumeTokenClaims{
OrderID: 42,
UserID: 7,
ProviderInstanceID: "19",
ProviderKey: "easypay",
PaymentType: "wxpay",
CanonicalReturnURL: "https://example.com/payment/result",
IssuedAt: 1234567890,
})
if err != nil {
t.Fatalf("CreateToken returned error: %v", err)
}
claims, err := svc.ParseToken(token)
if err != nil {
t.Fatalf("ParseToken returned error: %v", err)
}
if claims.OrderID != 42 || claims.UserID != 7 {
t.Fatalf("claims mismatch: %+v", claims)
}
if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" {
t.Fatalf("claims provider snapshot mismatch: %+v", claims)
}
if claims.CanonicalReturnURL != "https://example.com/payment/result" {
t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL)
}
}
// TestCreateTokenRejectsMissingSigningKey expects CreateToken to fail when the
// service was constructed with a nil signing key.
func TestCreateTokenRejectsMissingSigningKey(t *testing.T) {
	t.Parallel()

	keyless := NewPaymentResumeService(nil)
	if _, err := keyless.CreateToken(ResumeTokenClaims{OrderID: 42}); err == nil {
		t.Fatal("CreateToken should reject missing signing key")
	}
}
// TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing signs a token
// with paymentResumeFallbackSigningKey (via mustCreateFallbackSignedToken) and
// expects a service constructed with a nil signing key to refuse to parse it.
func TestParseTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
	t.Parallel()

	forged := mustCreateFallbackSignedToken(t, ResumeTokenClaims{OrderID: 42, UserID: 7})
	keyless := NewPaymentResumeService(nil)

	if _, err := keyless.ParseToken(forged); err == nil {
		t.Fatal("ParseToken should reject tokens when signing key is missing")
	}
}
func TestWeChatPaymentResumeTokenRoundTrip(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := svc.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
Amount: "12.50",
OrderType: payment.OrderTypeSubscription,
PlanID: 7,
RedirectTo: "/purchase?from=wechat",
Scope: "snsapi_base",
IssuedAt: 1234567890,
})
if err != nil {
t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err)
}
claims, err := svc.ParseWeChatPaymentResumeToken(token)
if err != nil {
t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err)
}
if claims.OpenID != "openid-123" || claims.PaymentType != payment.TypeWxpay {
t.Fatalf("claims mismatch: %+v", claims)
}
if claims.Amount != "12.50" || claims.OrderType != payment.OrderTypeSubscription || claims.PlanID != 7 {
t.Fatalf("claims payment context mismatch: %+v", claims)
}
if claims.RedirectTo != "/purchase?from=wechat" || claims.Scope != "snsapi_base" {
t.Fatalf("claims redirect/scope mismatch: %+v", claims)
}
}
// TestCreateWeChatPaymentResumeTokenRejectsMissingSigningKey expects
// CreateWeChatPaymentResumeToken to fail when the service was constructed with
// a nil signing key.
func TestCreateWeChatPaymentResumeTokenRejectsMissingSigningKey(t *testing.T) {
	t.Parallel()

	keyless := NewPaymentResumeService(nil)
	if _, err := keyless.CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{OpenID: "openid-123"}); err == nil {
		t.Fatal("CreateWeChatPaymentResumeToken should reject missing signing key")
	}
}
func TestParseWeChatPaymentResumeTokenRejectsFallbackSignedTokenWhenSigningKeyMissing(t *testing.T) {
t.Parallel()
token := mustCreateFallbackSignedToken(t, WeChatPaymentResumeClaims{
TokenType: wechatPaymentResumeTokenType,
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
})
svc := NewPaymentResumeService(nil)
_, err := svc.ParseWeChatPaymentResumeToken(token)
if err == nil {
t.Fatal("ParseWeChatPaymentResumeToken should reject tokens when signing key is missing")
}
}
// TestNormalizeVisibleMethodSource is a table-driven check of
// NormalizeVisibleMethodSource: for each (method, input) pair it expects the
// listed VisibleMethodSource* constant, and an empty string for a source the
// method does not support.
func TestNormalizeVisibleMethodSource(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name   string
		method string
		input  string
		want   string
	}{
		{name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay},
		{name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay},
		{name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat},
		{name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat},
		{name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""},
	}
	for i := range cases {
		// Copy the row into a per-iteration variable. The subtest calls
		// t.Parallel, so its body runs after the loop has moved on; on
		// toolchains older than Go 1.22 a captured range variable is shared
		// across iterations and every subtest would see the last row.
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			if got := NormalizeVisibleMethodSource(tc.method, tc.input); got != tc.want {
				t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tc.method, tc.input, got, tc.want)
			}
		})
	}
}
// TestVisibleMethodProviderKeyForSource is a table-driven check of
// VisibleMethodProviderKeyForSource: for each (method, source) pair it expects
// the listed provider key and ok flag, including ("", false) when the source
// does not belong to the method.
func TestVisibleMethodProviderKeyForSource(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name   string
		method string
		source string
		want   string
		ok     bool
	}{
		{name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true},
		{name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true},
		{name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true},
		{name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true},
		{name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false},
	}
	for i := range cases {
		// Copy the row into a per-iteration variable. The subtest calls
		// t.Parallel, so its body runs after the loop has moved on; on
		// toolchains older than Go 1.22 a captured range variable is shared
		// across iterations and every subtest would see the last row.
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			got, ok := VisibleMethodProviderKeyForSource(tc.method, tc.source)
			if got != tc.want || ok != tc.ok {
				t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tc.method, tc.source, got, ok, tc.want, tc.ok)
			}
		})
	}
}
func TestVisibleMethodLoadBalancerUsesConfiguredSource(t *testing.T) {
t.Parallel()
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
settingRepo: &paymentSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
_, err := lb.SelectInstance(context.Background(), "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5)
if err != nil {
t.Fatalf("SelectInstance returned error: %v", err)
}
if inner.lastProviderKey != payment.TypeAlipay {
t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay)
}
}
func TestVisibleMethodLoadBalancerRejectsDisabledVisibleMethod(t *testing.T) {
t.Parallel()
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
settingRepo: &paymentSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodWxpayEnabled: "false",
SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil {
t.Fatal("SelectInstance should reject disabled visible method")
}
}
// paymentSettingRepoStub is an in-memory settings repository for tests. Reads
// are served from the values map; every write method is a no-op that reports
// success.
type paymentSettingRepoStub struct {
	values map[string]string
}

// Get always returns (nil, nil); the stub does not model Setting records.
func (s *paymentSettingRepoStub) Get(context.Context, string) (*Setting, error) {
	return nil, nil
}

// GetValue returns the stored value for key, or "" when the key is absent.
func (s *paymentSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
	value := s.values[key]
	return value, nil
}

// Set discards its arguments and returns nil.
func (s *paymentSettingRepoStub) Set(context.Context, string, string) error {
	return nil
}

// GetMultiple returns a new map holding one entry per requested key; keys that
// are absent from the stub map to the empty string.
func (s *paymentSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
	result := make(map[string]string, len(keys))
	for i := range keys {
		k := keys[i]
		result[k] = s.values[k]
	}
	return result, nil
}

// SetMultiple discards its arguments and returns nil.
func (s *paymentSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
	return nil
}

// GetAll returns the stub's backing map itself, not a copy.
func (s *paymentSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
	return s.values, nil
}

// Delete discards its arguments and returns nil.
func (s *paymentSettingRepoStub) Delete(context.Context, string) error {
	return nil
}
// captureLoadBalancer is a test double for the inner load balancer: it records
// the provider key and payment type of the most recent SelectInstance call.
type captureLoadBalancer struct {
	lastProviderKey string
	lastPaymentType string
}

// GetInstanceConfig ignores its arguments and returns an empty, non-nil map.
func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) {
	return make(map[string]string), nil
}

// SelectInstance stores providerKey and paymentType on the receiver and returns
// a selection that echoes both back; it never returns an error.
func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) {
	c.lastProviderKey, c.lastPaymentType = providerKey, paymentType

	selection := &payment.InstanceSelection{
		ProviderKey:    providerKey,
		SupportedTypes: paymentType,
	}
	return selection, nil
}
// mustCreateFallbackSignedToken builds a token of the form
// "<payload>.<signature>", where payload is the unpadded base64url encoding of
// the JSON-marshalled claims and signature is the unpadded base64url encoding
// of the HMAC-SHA256 of that encoded payload, keyed with
// paymentResumeFallbackSigningKey. It fails the test if claims cannot be
// marshalled.
func mustCreateFallbackSignedToken(t *testing.T, claims any) string {
	t.Helper()

	raw, err := json.Marshal(claims)
	if err != nil {
		t.Fatalf("marshal claims: %v", err)
	}

	enc := base64.RawURLEncoding
	body := enc.EncodeToString(raw)

	// The MAC covers the encoded payload text, not the raw JSON bytes.
	signer := hmac.New(sha256.New, []byte(paymentResumeFallbackSigningKey))
	_, _ = signer.Write([]byte(body))

	return body + "." + enc.EncodeToString(signer.Sum(nil))
}