//go:build unit package provider import ( "context" "net/url" "strings" "testing" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/wechatpay-apiv3/wechatpay-go/core" "github.com/wechatpay-apiv3/wechatpay-go/services/payments" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/native" ) func TestMapWxState(t *testing.T) { t.Parallel() tests := []struct { name string input string want string }{ { name: "SUCCESS maps to paid", input: wxpayTradeStateSuccess, want: payment.ProviderStatusPaid, }, { name: "REFUND maps to refunded", input: wxpayTradeStateRefund, want: payment.ProviderStatusRefunded, }, { name: "CLOSED maps to failed", input: wxpayTradeStateClosed, want: payment.ProviderStatusFailed, }, { name: "PAYERROR maps to failed", input: wxpayTradeStatePayError, want: payment.ProviderStatusFailed, }, { name: "unknown state maps to pending", input: "NOTPAY", want: payment.ProviderStatusPending, }, { name: "empty string maps to pending", input: "", want: payment.ProviderStatusPending, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() got := mapWxState(tt.input) if got != tt.want { t.Errorf("mapWxState(%q) = %q, want %q", tt.input, got, tt.want) } }) } } func TestWxSV(t *testing.T) { t.Parallel() tests := []struct { name string input *string want string }{ { name: "nil pointer returns empty string", input: nil, want: "", }, { name: "non-nil pointer returns value", input: strPtr("hello"), want: "hello", }, { name: "pointer to empty string returns empty string", input: strPtr(""), want: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() got := wxSV(tt.input) if got != tt.want { t.Errorf("wxSV() = %q, want %q", got, tt.want) } }) } } func TestBuildWxpayTransactionMetadata(t *testing.T) { t.Parallel() tx := &payments.Transaction{ Appid: strPtr("wx-app-id"), Mchid: strPtr("mch-id"), TradeState: strPtr(wxpayTradeStateSuccess), Amount: &payments.Amount{ Currency: strPtr(wxpayCurrency), }, } metadata := buildWxpayTransactionMetadata(tx) if metadata[wxpayMetadataAppID] != "wx-app-id" { t.Fatalf("appid = %q", metadata[wxpayMetadataAppID]) } if metadata[wxpayMetadataMerchantID] != "mch-id" { t.Fatalf("mchid = %q", metadata[wxpayMetadataMerchantID]) } if metadata[wxpayMetadataCurrency] != wxpayCurrency { t.Fatalf("currency = %q", metadata[wxpayMetadataCurrency]) } if metadata[wxpayMetadataTradeState] != wxpayTradeStateSuccess { t.Fatalf("trade_state = %q", metadata[wxpayMetadataTradeState]) } } func strPtr(s string) *string { return &s } func TestFormatPEM(t *testing.T) { t.Parallel() tests := []struct { name string key string keyType string want string }{ { name: "raw key gets wrapped with headers", key: "MIIBIjANBgkqhki...", keyType: "PUBLIC KEY", want: "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhki...\n-----END PUBLIC KEY-----", }, { name: "already formatted key is returned as-is", key: "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBg...\n-----END PRIVATE KEY-----", keyType: "PRIVATE KEY", want: "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBg...\n-----END PRIVATE KEY-----", }, { name: "key with leading/trailing whitespace is trimmed before check", key: " \n MIIBIjANBgkqhki... \n ", keyType: "PUBLIC KEY", want: "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhki...\n-----END PUBLIC KEY-----", }, { name: "already formatted key with whitespace is trimmed and returned", key: " -----BEGIN RSA PRIVATE KEY-----\ndata\n-----END RSA PRIVATE KEY----- ", keyType: "RSA PRIVATE KEY", want: "-----BEGIN RSA PRIVATE KEY-----\ndata\n-----END RSA PRIVATE KEY-----", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() got := formatPEM(tt.key, tt.keyType) if got != tt.want { t.Errorf("formatPEM(%q, %q) =\n%s\nwant:\n%s", tt.key, tt.keyType, got, tt.want) } }) } } func TestNewWxpay(t *testing.T) { t.Parallel() validConfig := map[string]string{ "appId": "wx1234567890", "mchId": "1234567890", "privateKey": "fake-private-key", "apiV3Key": "12345678901234567890123456789012", // exactly 32 bytes "publicKey": "fake-public-key", "publicKeyId": "key-id-001", "certSerial": "SERIAL001", } // helper to clone and override config fields withOverride := func(overrides map[string]string) map[string]string { cfg := make(map[string]string, len(validConfig)) for k, v := range validConfig { cfg[k] = v } for k, v := range overrides { cfg[k] = v } return cfg } tests := []struct { name string config map[string]string wantErr bool errSubstr string }{ { name: "valid config succeeds", config: validConfig, wantErr: false, }, { name: "missing appId", config: withOverride(map[string]string{"appId": ""}), wantErr: true, errSubstr: "appId", }, { name: "missing mchId", config: withOverride(map[string]string{"mchId": ""}), wantErr: true, errSubstr: "mchId", }, { name: "missing privateKey", config: withOverride(map[string]string{"privateKey": ""}), wantErr: true, errSubstr: "privateKey", }, { name: "missing apiV3Key", config: withOverride(map[string]string{"apiV3Key": ""}), wantErr: true, errSubstr: "apiV3Key", }, { name: "missing publicKey", config: withOverride(map[string]string{"publicKey": ""}), wantErr: true, errSubstr: "publicKey", }, { name: "missing publicKeyId", config: withOverride(map[string]string{"publicKeyId": ""}), wantErr: true, errSubstr: "publicKeyId", }, { name: "apiV3Key too short", config: withOverride(map[string]string{"apiV3Key": "short"}), wantErr: true, errSubstr: "exactly 32 bytes", }, { name: "apiV3Key too long", config: withOverride(map[string]string{"apiV3Key": "123456789012345678901234567890123"}), // 33 bytes wantErr: true, errSubstr: "exactly 32 bytes", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() got, err := NewWxpay("test-instance", tt.config) if tt.wantErr { if err == nil { t.Fatal("expected error, got nil") } if tt.errSubstr != "" && !strings.Contains(err.Error(), tt.errSubstr) { t.Errorf("error %q should contain %q", err.Error(), tt.errSubstr) } return } if err != nil { t.Fatalf("unexpected error: %v", err) } if got == nil { t.Fatal("expected non-nil Wxpay instance") } if got.instanceID != "test-instance" { t.Errorf("instanceID = %q, want %q", got.instanceID, "test-instance") } }) } } func TestBuildWxpayResultURLPreservesResumeToken(t *testing.T) { t.Parallel() resultURL, err := buildWxpayResultURL("https://app.example.com/payment/result?order_id=42&resume_token=resume-42&status=success", payment.CreatePaymentRequest{ OrderID: "sub2_42", PaymentType: payment.TypeWxpay, }) if err != nil { t.Fatalf("buildWxpayResultURL returned error: %v", err) } parsed, err := url.Parse(resultURL) if err != nil { t.Fatalf("url.Parse returned error: %v", err) } query := parsed.Query() if parsed.Path != wxpayResultPath { t.Fatalf("path = %q, want %q", parsed.Path, wxpayResultPath) } if query.Get("resume_token") != "resume-42" { t.Fatalf("resume_token = %q, want %q", query.Get("resume_token"), "resume-42") } if query.Get("order_id") != "42" { t.Fatalf("order_id = %q, want %q", query.Get("order_id"), "42") } if query.Get("out_trade_no") != "sub2_42" { t.Fatalf("out_trade_no = %q, want %q", query.Get("out_trade_no"), "sub2_42") } } func TestResolveWxpayJSAPIAppID(t *testing.T) { t.Parallel() tests := []struct { name string config map[string]string want string }{ { name: "prefers dedicated mp app id", config: map[string]string{ "mpAppId": "wx-mp-app", "appId": "wx-merchant-app", }, want: "wx-mp-app", }, { name: "falls back to merchant app id", config: map[string]string{ "appId": "wx-merchant-app", }, want: "wx-merchant-app", }, { name: "missing app ids returns empty", config: map[string]string{}, want: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() if got := ResolveWxpayJSAPIAppID(tt.config); got != tt.want { t.Fatalf("ResolveWxpayJSAPIAppID() = %q, want %q", got, tt.want) } }) } } func TestResolveWxpayCreateMode(t *testing.T) { t.Parallel() tests := []struct { name string req payment.CreatePaymentRequest wantMode string wantErr string }{ { name: "desktop uses native", req: payment.CreatePaymentRequest{}, wantMode: wxpayModeNative, }, { name: "mobile uses h5 when client ip is present", req: payment.CreatePaymentRequest{ IsMobile: true, ClientIP: "203.0.113.10", }, wantMode: wxpayModeH5, }, { name: "mobile without client ip returns clear error", req: payment.CreatePaymentRequest{ IsMobile: true, }, wantErr: "requires client IP", }, { name: "openid uses jsapi mode", req: payment.CreatePaymentRequest{ OpenID: "openid-123", }, wantMode: wxpayModeJSAPI, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() got, err := resolveWxpayCreateMode(tt.req) if tt.wantErr != "" { if err == nil { t.Fatal("expected error, got nil") } if !strings.Contains(err.Error(), tt.wantErr) { t.Fatalf("error %q should contain %q", err.Error(), tt.wantErr) } return } if err != nil { t.Fatalf("unexpected error: %v", err) } if got != tt.wantMode { t.Fatalf("resolveWxpayCreateMode() = %q, want %q", got, tt.wantMode) } }) } } func TestCreatePaymentWithOpenIDReturnsJSAPIResult(t *testing.T) { origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment origNativePrepay := wxpayNativePrepay origH5Prepay := wxpayH5Prepay t.Cleanup(func() { wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay wxpayNativePrepay = origNativePrepay wxpayH5Prepay = origH5Prepay }) jsapiCalls := 0 nativeCalls := 0 h5Calls := 0 wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) { jsapiCalls++ if got := wxSV(req.Payer.Openid); got != "openid-123" { t.Fatalf("openid = %q, want %q", got, "openid-123") } if req.SceneInfo == nil || wxSV(req.SceneInfo.PayerClientIp) != "203.0.113.10" { t.Fatalf("scene_info payer_client_ip = %q, want %q", wxSV(req.SceneInfo.PayerClientIp), "203.0.113.10") } return &jsapi.PrepayWithRequestPaymentResponse{ Appid: core.String("wx123"), TimeStamp: core.String("1712345678"), NonceStr: core.String("nonce-123"), Package: core.String("prepay_id=wx_prepay_123"), SignType: core.String("RSA"), PaySign: core.String("signed-payload"), }, nil, nil } wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) { nativeCalls++ return &native.PrepayResponse{}, nil, nil } wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) { h5Calls++ return &h5.PrepayResponse{}, nil, nil } provider := &Wxpay{ config: map[string]string{ "appId": "wx123", "mchId": "mch123", }, coreClient: &core.Client{}, } resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{ OrderID: "sub2_88", Amount: "66.88", PaymentType: payment.TypeWxpay, NotifyURL: "https://merchant.example/payment/notify", OpenID: "openid-123", ClientIP: "203.0.113.10", }) if err != nil { t.Fatalf("unexpected error: %v", err) } if jsapiCalls != 1 { t.Fatalf("jsapi prepay calls = %d, want 1", jsapiCalls) } if nativeCalls != 0 { t.Fatalf("native prepay calls = %d, want 0", nativeCalls) } if h5Calls != 0 { t.Fatalf("h5 prepay calls = %d, want 0", h5Calls) } if resp.ResultType != payment.CreatePaymentResultJSAPIReady { t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultJSAPIReady) } if resp.JSAPI == nil { t.Fatal("expected jsapi payload, got nil") } if resp.JSAPI.AppID != "wx123" { t.Fatalf("jsapi appId = %q, want %q", resp.JSAPI.AppID, "wx123") } if resp.JSAPI.TimeStamp != "1712345678" { t.Fatalf("jsapi timeStamp = %q, want %q", resp.JSAPI.TimeStamp, "1712345678") } if resp.JSAPI.NonceStr != "nonce-123" { t.Fatalf("jsapi nonceStr = %q, want %q", resp.JSAPI.NonceStr, "nonce-123") } if resp.JSAPI.Package != "prepay_id=wx_prepay_123" { t.Fatalf("jsapi package = %q, want %q", resp.JSAPI.Package, "prepay_id=wx_prepay_123") } if resp.JSAPI.SignType != "RSA" { t.Fatalf("jsapi signType = %q, want %q", resp.JSAPI.SignType, "RSA") } if resp.JSAPI.PaySign != "signed-payload" { t.Fatalf("jsapi paySign = %q, want %q", resp.JSAPI.PaySign, "signed-payload") } } func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) { origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment origNativePrepay := wxpayNativePrepay origH5Prepay := wxpayH5Prepay t.Cleanup(func() { wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay wxpayNativePrepay = origNativePrepay wxpayH5Prepay = origH5Prepay }) jsapiCalls := 0 nativeCalls := 0 h5Calls := 0 wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) { jsapiCalls++ return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil } wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) { nativeCalls++ return &native.PrepayResponse{}, nil, nil } wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) { h5Calls++ if req.SceneInfo == nil { t.Fatal("expected scene_info, got nil") } if got := wxSV(req.SceneInfo.PayerClientIp); got != "203.0.113.10" { t.Fatalf("scene_info payer_client_ip = %q, want %q", got, "203.0.113.10") } if req.SceneInfo.H5Info == nil { t.Fatal("expected scene_info.h5_info, got nil") } if got := wxSV(req.SceneInfo.H5Info.Type); got != wxpayH5Type { t.Fatalf("scene_info.h5_info.type = %q, want %q", got, wxpayH5Type) } if got := wxSV(req.SceneInfo.H5Info.AppName); got != "Sub2API" { t.Fatalf("scene_info.h5_info.app_name = %q, want %q", got, "Sub2API") } if got := wxSV(req.SceneInfo.H5Info.AppUrl); got != "https://app.example.com" { t.Fatalf("scene_info.h5_info.app_url = %q, want %q", got, "https://app.example.com") } return &h5.PrepayResponse{ H5Url: core.String("https://wx.tenpay.example/h5pay?prepay_id=1"), }, nil, nil } provider := &Wxpay{ config: map[string]string{ "appId": "wx123", "mchId": "mch123", "h5AppName": "Sub2API", "h5AppUrl": "https://app.example.com", }, coreClient: &core.Client{}, } resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{ OrderID: "sub2_99", Amount: "66.88", PaymentType: payment.TypeWxpay, Subject: "Balance Recharge", NotifyURL: "https://merchant.example/payment/notify", ReturnURL: "https://merchant.example/payment/result?resume_token=resume-99", ClientIP: "203.0.113.10", IsMobile: true, }) if err != nil { t.Fatalf("unexpected error: %v", err) } if jsapiCalls != 0 { t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls) } if nativeCalls != 0 { t.Fatalf("native prepay calls = %d, want 0", nativeCalls) } if h5Calls != 1 { t.Fatalf("h5 prepay calls = %d, want 1", h5Calls) } if !strings.Contains(resp.PayURL, "redirect_url=") { t.Fatalf("pay_url = %q, want redirect_url query appended", resp.PayURL) } }