Tighten WeChat payment resume flow

This commit is contained in:
IanShaw027
2026-04-21 00:33:23 +08:00
parent 1521d50399
commit 55e8dd550a
15 changed files with 514 additions and 98 deletions

View File

@@ -435,24 +435,34 @@ func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) {
scope = strings.TrimSpace(tokenResp.Scope)
}
resumeToken, err := h.wechatPaymentResumeService().CreateWeChatPaymentResumeToken(service.WeChatPaymentResumeClaims{
OpenID: openid,
PaymentType: paymentContext.PaymentType,
Amount: paymentContext.Amount,
OrderType: paymentContext.OrderType,
PlanID: paymentContext.PlanID,
RedirectTo: redirectTo,
Scope: scope,
})
if err != nil {
redirectOAuthError(c, frontendCallback, "invalid_context", "failed to encode payment resume context", "")
return
}
fragment := url.Values{}
fragment.Set("openid", openid)
fragment.Set("state", state)
fragment.Set("scope", scope)
fragment.Set("payment_type", paymentContext.PaymentType)
if paymentContext.Amount != "" {
fragment.Set("amount", paymentContext.Amount)
}
if paymentContext.OrderType != "" {
fragment.Set("order_type", paymentContext.OrderType)
}
if paymentContext.PlanID > 0 {
fragment.Set("plan_id", strconv.FormatInt(paymentContext.PlanID, 10))
}
fragment.Set("wechat_resume_token", resumeToken)
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
}
func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService {
key, err := payment.ProvideEncryptionKey(h.cfg)
if err != nil {
return service.NewPaymentResumeService(nil)
}
return service.NewPaymentResumeService([]byte(key))
}
type completeWeChatOAuthRequest struct {
InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`

View File

@@ -21,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -175,6 +176,66 @@ func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) {
require.Zero(t, count)
}
func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret")
originalAccessTokenURL := wechatOAuthAccessTokenURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_base"}`))
return
}
http.NotFound(w, r)
}))
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
handler, client := newWeChatOAuthTestHandler(t, false)
defer client.Close()
handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-123", nil)
req.Host = "api.example.com"
req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-123"))
req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat"))
req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"12.5","order_type":"subscription","plan_id":7}`))
req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base"))
c.Request = req
handler.WeChatPaymentOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
parsed, err := url.Parse(location)
require.NoError(t, err)
fragment, err := url.ParseQuery(parsed.Fragment)
require.NoError(t, err)
require.Equal(t, "/purchase?from=wechat", fragment.Get("redirect"))
require.NotEmpty(t, fragment.Get("wechat_resume_token"))
require.Empty(t, fragment.Get("openid"))
require.Empty(t, fragment.Get("payment_type"))
require.Empty(t, fragment.Get("amount"))
require.Empty(t, fragment.Get("order_type"))
require.Empty(t, fragment.Get("plan_id"))
claims, err := handler.wechatPaymentResumeService().ParseWeChatPaymentResumeToken(fragment.Get("wechat_resume_token"))
require.NoError(t, err)
require.Equal(t, "openid-123", claims.OpenID)
require.Equal(t, payment.TypeWxpay, claims.PaymentType)
require.Equal(t, "12.5", claims.Amount)
require.Equal(t, payment.OrderTypeSubscription, claims.OrderType)
require.EqualValues(t, 7, claims.PlanID)
require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
}
func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) {
testCases := []struct {
name string

View File

@@ -1,9 +1,12 @@
package handler
import (
"fmt"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -202,14 +205,15 @@ func (h *PaymentHandler) GetLimits(c *gin.Context) {
// CreateOrderRequest is the request body for creating a payment order.
type CreateOrderRequest struct {
Amount float64 `json:"amount"`
PaymentType string `json:"payment_type" binding:"required"`
OpenID string `json:"openid"`
ReturnURL string `json:"return_url"`
PaymentSource string `json:"payment_source"`
OrderType string `json:"order_type"`
PlanID int64 `json:"plan_id"`
IsMobile *bool `json:"is_mobile,omitempty"`
Amount float64 `json:"amount"`
PaymentType string `json:"payment_type" binding:"required"`
OpenID string `json:"openid"`
WechatResumeToken string `json:"wechat_resume_token"`
ReturnURL string `json:"return_url"`
PaymentSource string `json:"payment_source"`
OrderType string `json:"order_type"`
PlanID int64 `json:"plan_id"`
IsMobile *bool `json:"is_mobile,omitempty"`
}
// CreateOrder creates a new payment order.
@@ -225,6 +229,17 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if strings.TrimSpace(req.WechatResumeToken) != "" {
claims, err := h.paymentService.ParseWeChatPaymentResumeToken(req.WechatResumeToken)
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyWeChatPaymentResumeClaims(&req, claims); err != nil {
response.ErrorFrom(c, err)
return
}
}
mobile := isMobile(c)
if req.IsMobile != nil {
@@ -253,6 +268,44 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) {
response.Success(c, result)
}
func applyWeChatPaymentResumeClaims(req *CreateOrderRequest, claims *service.WeChatPaymentResumeClaims) error {
if req == nil || claims == nil {
return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume context is missing")
}
openid := strings.TrimSpace(claims.OpenID)
if openid == "" {
return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
}
paymentType := service.NormalizeVisibleMethod(claims.PaymentType)
if paymentType == "" {
paymentType = payment.TypeWxpay
}
if req.PaymentType != "" {
requestPaymentType := service.NormalizeVisibleMethod(req.PaymentType)
if requestPaymentType != "" && requestPaymentType != paymentType {
return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payment type mismatch")
}
}
req.PaymentType = paymentType
req.OpenID = openid
if strings.TrimSpace(claims.Amount) != "" {
amount, err := strconv.ParseFloat(strings.TrimSpace(claims.Amount), 64)
if err != nil || amount <= 0 {
return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", fmt.Sprintf("invalid resume amount: %s", claims.Amount))
}
req.Amount = amount
}
if claims.OrderType != "" {
req.OrderType = claims.OrderType
}
if claims.PlanID > 0 {
req.PlanID = claims.PlanID
}
return nil
}
// GetMyOrders returns the authenticated user's orders.
// GET /api/v1/payment/orders/my
func (h *PaymentHandler) GetMyOrders(c *gin.Context) {

View File

@@ -0,0 +1,61 @@
//go:build unit
package handler
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestApplyWeChatPaymentResumeClaims(t *testing.T) {
t.Parallel()
req := CreateOrderRequest{
Amount: 0,
PaymentType: payment.TypeWxpay,
OrderType: payment.OrderTypeBalance,
}
err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
Amount: "12.50",
OrderType: payment.OrderTypeSubscription,
PlanID: 7,
})
if err != nil {
t.Fatalf("applyWeChatPaymentResumeClaims returned error: %v", err)
}
if req.OpenID != "openid-123" {
t.Fatalf("openid = %q, want %q", req.OpenID, "openid-123")
}
if req.Amount != 12.5 {
t.Fatalf("amount = %v, want 12.5", req.Amount)
}
if req.OrderType != payment.OrderTypeSubscription {
t.Fatalf("order_type = %q, want %q", req.OrderType, payment.OrderTypeSubscription)
}
if req.PlanID != 7 {
t.Fatalf("plan_id = %d, want 7", req.PlanID)
}
}
func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T) {
t.Parallel()
req := CreateOrderRequest{
PaymentType: payment.TypeAlipay,
}
err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
OpenID: "openid-123",
PaymentType: payment.TypeWxpay,
Amount: "12.50",
OrderType: payment.OrderTypeBalance,
})
if err == nil {
t.Fatal("applyWeChatPaymentResumeClaims should reject mismatched payment types")
}
}