diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go index 37bd38b2..e7d8aab9 100644 --- a/backend/internal/payment/provider/easypay.go +++ b/backend/internal/payment/provider/easypay.go @@ -25,6 +25,7 @@ const ( easypayStatusPaid = 1 easypayHTTPTimeout = 10 * time.Second maxEasypayResponseSize = 1 << 20 // 1MB + maxEasypayErrorSummary = 512 tradeStatusSuccess = "TRADE_SUCCESS" signTypeMD5 = "MD5" paymentModePopup = "popup" @@ -42,17 +43,55 @@ type EasyPay struct { // config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) { for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} { - if config[k] == "" { + if strings.TrimSpace(config[k]) == "" { return nil, fmt.Errorf("easypay config missing required key: %s", k) } } + cfg := make(map[string]string, len(config)) + for k, v := range config { + cfg[k] = v + } + cfg["apiBase"] = normalizeEasyPayAPIBase(cfg["apiBase"]) return &EasyPay{ instanceID: instanceID, - config: config, + config: cfg, httpClient: &http.Client{Timeout: easypayHTTPTimeout}, }, nil } +func normalizeEasyPayAPIBase(apiBase string) string { + base := strings.TrimSpace(apiBase) + if base == "" { + return "" + } + if parsed, err := url.Parse(base); err == nil && parsed.Scheme != "" && parsed.Host != "" { + parsed.RawQuery = "" + parsed.Fragment = "" + parsed.RawPath = "" + parsed.Path = trimEasyPayEndpointPath(parsed.Path) + return strings.TrimRight(parsed.String(), "/") + } + return strings.TrimRight(trimEasyPayEndpointPath(base), "/") +} + +func trimEasyPayEndpointPath(path string) string { + path = strings.TrimRight(strings.TrimSpace(path), "/") + lower := strings.ToLower(path) + for _, endpoint := range []string{"/submit.php", "/mapi.php", "/api.php"} { + if strings.HasSuffix(lower, endpoint) { + return strings.TrimRight(path[:len(path)-len(endpoint)], "/") + } + } + return path +} + +func (e *EasyPay) apiBase() string { + if e == nil { + return "" + } + return normalizeEasyPayAPIBase(e.config["apiBase"]) +} + func (e *EasyPay) Name() string { return "EasyPay" } func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay } func (e *EasyPay) SupportedTypes() []payment.PaymentType { @@ -104,8 +143,7 @@ func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*paym for k, v := range params { q.Set(k, v) } - base := strings.TrimRight(e.config["apiBase"], "/") - payURL := base + "/submit.php?" + q.Encode() + payURL := e.apiBase() + "/submit.php?" + q.Encode() return &payment.CreatePaymentResponse{PayURL: payURL}, nil } @@ -127,7 +165,7 @@ func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymen params["sign"] = easyPaySign(params, e.config["pkey"]) params["sign_type"] = signTypeMD5 - body, err := e.post(ctx, strings.TrimRight(e.config["apiBase"], "/")+"/mapi.php", params) + body, err := e.post(ctx, e.apiBase()+"/mapi.php", params) if err != nil { return nil, fmt.Errorf("easypay create: %w", err) } @@ -171,7 +209,7 @@ func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Quer "act": "order", "pid": e.config["pid"], "key": e.config["pkey"], "out_trade_no": tradeNo, } - body, err := e.post(ctx, e.config["apiBase"]+"/api.php", params) + body, err := e.post(ctx, e.apiBase()+"/api.php", params) if err != nil { return nil, fmt.Errorf("easypay query: %w", err) } @@ -234,25 +272,128 @@ func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[st } func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) { - params := map[string]string{ - "pid": e.config["pid"], "key": e.config["pkey"], - "trade_no": req.TradeNo, "out_trade_no": req.OrderID, "money": req.Amount, + attempts := e.refundAttempts(req) + if len(attempts) == 0 { + return nil, fmt.Errorf("easypay refund missing order identifier") } - body, err := e.post(ctx, e.config["apiBase"]+"/api.php?act=refund", params) - if err != nil { - return nil, fmt.Errorf("easypay refund: %w", err) + var firstErr error + for i, attempt := range attempts { + body, status, err := e.postRaw(ctx, e.apiBase()+"/api.php?act=refund", attempt.params) + if err != nil { + return nil, fmt.Errorf("easypay refund request: %w", err) + } + if err := parseEasyPayRefundResponse(status, body); err != nil { + if firstErr == nil { + firstErr = err + } + if i+1 < len(attempts) && isEasyPayRefundOrderNotFound(err) { + continue + } + return nil, err + } + return &payment.RefundResponse{RefundID: attempt.refundID, Status: payment.ProviderStatusSuccess}, nil } + return nil, firstErr +} + +type easyPayRefundAttempt struct { + params map[string]string + refundID string +} + +func (e *EasyPay) refundAttempts(req payment.RefundRequest) []easyPayRefundAttempt { + base := map[string]string{ + "pid": e.config["pid"], "key": e.config["pkey"], "money": req.Amount, + } + var attempts []easyPayRefundAttempt + if orderID := strings.TrimSpace(req.OrderID); orderID != "" { + params := cloneStringMap(base) + params["out_trade_no"] = orderID + attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: orderID}) + } + if tradeNo := strings.TrimSpace(req.TradeNo); tradeNo != "" { + params := cloneStringMap(base) + params["trade_no"] = tradeNo + attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: tradeNo}) + } + return attempts +} + +func cloneStringMap(in map[string]string) map[string]string { + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func isEasyPayRefundOrderNotFound(err error) bool { + if err == nil { + return false + } + msg := err.Error() + lower := strings.ToLower(msg) + return strings.Contains(msg, "订单编号不存在") || + strings.Contains(msg, "订单不存在") || + strings.Contains(lower, "order not found") || + strings.Contains(lower, "not exist") +} + +func parseEasyPayRefundResponse(status int, body []byte) error { + summary := summarizeEasyPayResponse(body) + if status < http.StatusOK || status >= http.StatusMultipleChoices { + return fmt.Errorf("easypay refund HTTP %d: %s", status, summary) + } + + trimmed := strings.TrimSpace(string(body)) + if trimmed == "" { + return fmt.Errorf("easypay refund empty response (HTTP %d): %s", status, summary) + } + + lower := strings.ToLower(trimmed) + if strings.HasPrefix(lower, "" + } + if len(summary) > maxEasypayErrorSummary { + return summary[:maxEasypayErrorSummary] + "..." + } + return summary } func (e *EasyPay) resolveCID(paymentType string) string { @@ -269,21 +410,34 @@ func (e *EasyPay) resolveCID(paymentType string) string { } func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) { + body, _, err := e.postRaw(ctx, endpoint, params) + return body, err +} + +func (e *EasyPay) postRaw(ctx context.Context, endpoint string, params map[string]string) ([]byte, int, error) { form := url.Values{} for k, v := range params { form.Set(k, v) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode())) if err != nil { - return nil, err + return nil, 0, err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - resp, err := e.httpClient.Do(req) + client := e.httpClient + if client == nil { + client = &http.Client{Timeout: easypayHTTPTimeout} + } + resp, err := client.Do(req) if err != nil { - return nil, err + return nil, 0, err } defer func() { _ = resp.Body.Close() }() - return io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize)) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize)) + if err != nil { + return nil, resp.StatusCode, err + } + return body, resp.StatusCode, nil } func easyPaySign(params map[string]string, pkey string) string { diff --git a/backend/internal/payment/provider/easypay_refund_test.go b/backend/internal/payment/provider/easypay_refund_test.go new file mode 100644 index 00000000..9e0e4942 --- /dev/null +++ b/backend/internal/payment/provider/easypay_refund_test.go @@ -0,0 +1,196 @@ +package provider + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +func TestNormalizeEasyPayAPIBase(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + want string + }{ + {input: "https://zpayz.cn", want: "https://zpayz.cn"}, + {input: "https://zpayz.cn/", want: "https://zpayz.cn"}, + {input: "https://zpayz.cn/mapi.php", want: "https://zpayz.cn"}, + {input: "https://zpayz.cn/submit.php", want: "https://zpayz.cn"}, + {input: "https://zpayz.cn/api.php", want: "https://zpayz.cn"}, + {input: "https://zpayz.cn/api.php?act=refund", want: "https://zpayz.cn"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + if got := normalizeEasyPayAPIBase(tt.input); got != tt.want { + t.Fatalf("normalizeEasyPayAPIBase(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestEasyPayRefundNormalizesAPIBaseAndSendsOutTradeNoOnly(t *testing.T) { + t.Parallel() + + var gotPath string + var gotQuery url.Values + var gotForm url.Values + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotQuery = r.URL.Query() + if err := r.ParseForm(); err != nil { + t.Errorf("ParseForm: %v", err) + } + gotForm = r.PostForm + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`)) + })) + defer server.Close() + + provider := newTestEasyPay(t, server.URL+"/mapi.php") + resp, err := provider.Refund(context.Background(), payment.RefundRequest{ + TradeNo: "trade-123", + OrderID: "out-456", + Amount: "1.50", + }) + if err != nil { + t.Fatalf("Refund returned error: %v", err) + } + if resp == nil || resp.Status != payment.ProviderStatusSuccess { + t.Fatalf("Refund response = %+v, want success", resp) + } + if gotPath != "/api.php" { + t.Fatalf("refund path = %q, want /api.php", gotPath) + } + if gotQuery.Get("act") != "refund" { + t.Fatalf("refund act query = %q, want refund", gotQuery.Get("act")) + } + for key, want := range map[string]string{ + "pid": "pid-1", + "key": "pkey-1", + "out_trade_no": "out-456", + "money": "1.50", + } { + if got := gotForm.Get(key); got != want { + t.Fatalf("form[%s] = %q, want %q (form=%v)", key, got, want, gotForm) + } + } + if got := gotForm.Get("trade_no"); got != "" { + t.Fatalf("form[trade_no] = %q, want empty (form=%v)", got, gotForm) + } +} + +func TestEasyPayRefundRetriesWithTradeNoWhenOutTradeNoNotFound(t *testing.T) { + t.Parallel() + + var gotForms []url.Values + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api.php" { + t.Errorf("refund path = %q, want /api.php", r.URL.Path) + } + if r.URL.Query().Get("act") != "refund" { + t.Errorf("refund act query = %q, want refund", r.URL.Query().Get("act")) + } + if err := r.ParseForm(); err != nil { + t.Errorf("ParseForm: %v", err) + } + gotForms = append(gotForms, r.PostForm) + w.Header().Set("Content-Type", "application/json") + if len(gotForms) == 1 { + _, _ = w.Write([]byte(`{"code":0,"msg":"订单编号不存在!"}`)) + return + } + _, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`)) + })) + defer server.Close() + + provider := newTestEasyPay(t, server.URL+"/mapi.php") + resp, err := provider.Refund(context.Background(), payment.RefundRequest{ + TradeNo: "trade-123", + OrderID: "out-456", + Amount: "1.50", + }) + if err != nil { + t.Fatalf("Refund returned error: %v", err) + } + if resp == nil || resp.Status != payment.ProviderStatusSuccess || resp.RefundID != "trade-123" { + t.Fatalf("Refund response = %+v, want success with trade refund id", resp) + } + if len(gotForms) != 2 { + t.Fatalf("refund attempts = %d, want 2", len(gotForms)) + } + if got := gotForms[0].Get("out_trade_no"); got != "out-456" { + t.Fatalf("first form[out_trade_no] = %q, want out-456 (form=%v)", got, gotForms[0]) + } + if got := gotForms[0].Get("trade_no"); got != "" { + t.Fatalf("first form[trade_no] = %q, want empty (form=%v)", got, gotForms[0]) + } + if got := gotForms[1].Get("trade_no"); got != "trade-123" { + t.Fatalf("second form[trade_no] = %q, want trade-123 (form=%v)", got, gotForms[1]) + } + if got := gotForms[1].Get("out_trade_no"); got != "" { + t.Fatalf("second form[out_trade_no] = %q, want empty (form=%v)", got, gotForms[1]) + } +} + +func TestEasyPayRefundResponseErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCode int + body string + want string + }{ + {name: "html response", statusCode: http.StatusOK, body: "bad config", want: "non-JSON response (HTTP 200): bad config"}, + {name: "non json response", statusCode: http.StatusOK, body: "not json", want: "non-JSON response (HTTP 200): not json"}, + {name: "non 2xx response", statusCode: http.StatusBadGateway, body: "bad gateway", want: "HTTP 502: bad gateway"}, + {name: "empty response", statusCode: http.StatusOK, body: "", want: "empty response (HTTP 200): "}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(tt.statusCode) + _, _ = w.Write([]byte(tt.body)) + })) + defer server.Close() + + provider := newTestEasyPay(t, server.URL) + _, err := provider.Refund(context.Background(), payment.RefundRequest{ + OrderID: "out-456", + Amount: "1.50", + }) + if err == nil { + t.Fatal("Refund returned nil error") + } + if !strings.Contains(err.Error(), tt.want) { + t.Fatalf("Refund error = %q, want substring %q", err.Error(), tt.want) + } + }) + } +} + +func newTestEasyPay(t *testing.T, apiBase string) *EasyPay { + t.Helper() + + provider, err := NewEasyPay("test-instance", map[string]string{ + "pid": "pid-1", + "pkey": "pkey-1", + "apiBase": apiBase, + "notifyUrl": "https://example.com/notify", + "returnUrl": "https://example.com/return", + }) + if err != nil { + t.Fatalf("NewEasyPay: %v", err) + } + return provider +}