diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go index 5a2ecb46..a7bc4ba3 100644 --- a/backend/internal/handler/payment_handler_resume_test.go +++ b/backend/internal/handler/payment_handler_resume_test.go @@ -164,9 +164,8 @@ func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) { } func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) { - t.Parallel() - gin.SetMode(gin.TestMode) + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef") db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared") require.NoError(t, err) @@ -250,3 +249,120 @@ func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing require.Contains(t, resp.Data, "expires_at") require.Contains(t, resp.Data, "refund_amount") } + +func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef") + + db, err := sql.Open("sqlite", "file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + user, err := client.User.Create(). + SetEmail("public-resolve-mismatch@example.com"). + SetPasswordHash("hash"). + SetUsername("public-resolve-mismatch-user"). + Save(context.Background()) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(100). + SetPayAmount(103). + SetFeeRate(0.03). + SetRechargeCode("PUBLIC-RESOLVE-MISMATCH"). + SetOutTradeNo("resolve-order-mismatch-no"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-public-resolve-mismatch"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(service.OrderStatusPaid). + SetExpiresAt(time.Now().Add(time.Hour)). + SetPaidAt(time.Now()). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(context.Background()) + require.NoError(t, err) + + resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID + 999, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef")) + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil) + h := NewPaymentHandler(paymentSvc, nil, nil) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest( + http.MethodPost, + "/api/v1/payment/public/orders/resolve", + bytes.NewBufferString(`{"resume_token":"`+token+`"}`), + ) + ctx.Request.Header.Set("Content-Type", "application/json") + + h.ResolveOrderPublicByResumeToken(ctx) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + var resp struct { + Code int `json:"code"` + Reason string `json:"reason"` + Message string `json:"message"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, http.StatusBadRequest, resp.Code) + require.Equal(t, "INVALID_RESUME_TOKEN", resp.Reason) +} + +func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) { + gin.SetMode(gin.TestMode) + + db, err := sql.Open("sqlite", "file:payment_handler_public_verify_blank?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil) + h := NewPaymentHandler(paymentSvc, nil, nil) + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest( + http.MethodPost, + "/api/v1/payment/public/orders/verify", + bytes.NewBufferString(`{"out_trade_no":" "}`), + ) + ctx.Request.Header.Set("Content-Type", "application/json") + + h.VerifyOrderPublic(ctx) + + require.Equal(t, http.StatusBadRequest, recorder.Code) + + var resp struct { + Code int `json:"code"` + Reason string `json:"reason"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, http.StatusBadRequest, resp.Code) + require.Equal(t, "INVALID_OUT_TRADE_NO", resp.Reason) +} diff --git a/backend/internal/service/payment_config_limits.go b/backend/internal/service/payment_config_limits.go index 57a4108f..e44bf2e7 100644 --- a/backend/internal/service/payment_config_limits.go +++ b/backend/internal/service/payment_config_limits.go @@ -20,7 +20,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M return nil, fmt.Errorf("query provider instances: %w", err) } typeInstances := pcGroupByPaymentType(instances) - typeInstances = pcApplyEnabledVisibleMethodInstances(typeInstances, instances) + typeInstances = s.pcApplyEnabledVisibleMethodInstances(ctx, typeInstances, instances) resp := &MethodLimitsResponse{ Methods: make(map[string]MethodLimits, len(typeInstances)), } @@ -32,7 +32,7 @@ func (s *PaymentConfigService) GetAvailableMethodLimits(ctx context.Context) (*M return resp, nil } -func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance { +func (s *PaymentConfigService) pcApplyEnabledVisibleMethodInstances(ctx context.Context, typeInstances map[string][]*dbent.PaymentProviderInstance, instances []*dbent.PaymentProviderInstance) map[string][]*dbent.PaymentProviderInstance { if len(typeInstances) == 0 { return typeInstances } @@ -44,11 +44,17 @@ func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.Paym for _, method := range []string{payment.TypeAlipay, payment.TypeWxpay} { matching := filterEnabledVisibleMethodInstances(instances, method) - if len(matching) != 1 { + providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching) + if err != nil || providerKey == "" { delete(filtered, method) continue } - filtered[method] = []*dbent.PaymentProviderInstance{matching[0]} + selectedInstances := filterVisibleMethodInstancesByProviderKey(instances, method, providerKey) + if len(selectedInstances) == 0 { + delete(filtered, method) + continue + } + filtered[method] = selectedInstances } return filtered } diff --git a/backend/internal/service/payment_config_limits_test.go b/backend/internal/service/payment_config_limits_test.go index b3925583..12cd6866 100644 --- a/backend/internal/service/payment_config_limits_test.go +++ b/backend/internal/service/payment_config_limits_test.go @@ -301,65 +301,104 @@ func TestPcInstanceTypeLimits(t *testing.T) { }) } -func TestGetAvailableMethodLimitsHidesConflictingVisibleMethodProviders(t *testing.T) { - ctx := context.Background() - client := newPaymentConfigServiceTestClient(t) - - _, err := client.PaymentProviderInstance.Create(). - SetProviderKey(payment.TypeAlipay). - SetName("Official Alipay"). - SetConfig("{}"). - SetSupportedTypes("alipay"). - SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`). - SetEnabled(true). - Save(ctx) - if err != nil { - t.Fatalf("create official alipay instance: %v", err) - } - _, err = client.PaymentProviderInstance.Create(). - SetProviderKey(payment.TypeEasyPay). - SetName("EasyPay Alipay"). - SetConfig("{}"). - SetSupportedTypes("alipay"). - SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`). - SetEnabled(true). - Save(ctx) - if err != nil { - t.Fatalf("create easypay alipay instance: %v", err) - } - _, err = client.PaymentProviderInstance.Create(). - SetProviderKey(payment.TypeWxpay). - SetName("Official WeChat"). - SetConfig("{}"). - SetSupportedTypes("wxpay"). - SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`). - SetEnabled(true). - Save(ctx) - if err != nil { - t.Fatalf("create official wxpay instance: %v", err) +func TestGetAvailableMethodLimitsUsesConfiguredVisibleMethodSource(t *testing.T) { + tests := []struct { + name string + sourceSetting string + wantAlipaySingleMin float64 + wantAlipaySingleMax float64 + wantGlobalMin float64 + wantGlobalMax float64 + }{ + { + name: "official source", + sourceSetting: VisibleMethodSourceOfficialAlipay, + wantAlipaySingleMin: 10, + wantAlipaySingleMax: 100, + wantGlobalMin: 10, + wantGlobalMax: 300, + }, + { + name: "easypay source", + sourceSetting: VisibleMethodSourceEasyPayAlipay, + wantAlipaySingleMin: 20, + wantAlipaySingleMax: 200, + wantGlobalMin: 20, + wantGlobalMax: 300, + }, } - svc := &PaymentConfigService{ - entClient: client, - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) - resp, err := svc.GetAvailableMethodLimits(ctx) - if err != nil { - t.Fatalf("GetAvailableMethodLimits returned error: %v", err) - } + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeAlipay). + SetName("Official Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetLimits(`{"alipay":{"singleMin":10,"singleMax":100}}`). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create official alipay instance: %v", err) + } + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetLimits(`{"alipay":{"singleMin":20,"singleMax":200}}`). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create easypay alipay instance: %v", err) + } + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("Official WeChat"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetLimits(`{"wxpay":{"singleMin":30,"singleMax":300}}`). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create official wxpay instance: %v", err) + } - if _, ok := resp.Methods[payment.TypeAlipay]; ok { - t.Fatalf("alipay should be hidden when multiple enabled providers claim it, got %v", resp.Methods[payment.TypeAlipay]) - } + svc := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + SettingPaymentVisibleMethodAlipaySource: tt.sourceSetting, + }, + }, + } - wxpayLimits, ok := resp.Methods[payment.TypeWxpay] - if !ok { - t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods) - } - if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 { - t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits) - } - if resp.GlobalMin != 30 || resp.GlobalMax != 300 { - t.Fatalf("global range = (%v, %v), want (30, 300)", resp.GlobalMin, resp.GlobalMax) + resp, err := svc.GetAvailableMethodLimits(ctx) + if err != nil { + t.Fatalf("GetAvailableMethodLimits returned error: %v", err) + } + + alipayLimits, ok := resp.Methods[payment.TypeAlipay] + if !ok { + t.Fatalf("expected alipay limits to remain visible, got %v", resp.Methods) + } + if alipayLimits.SingleMin != tt.wantAlipaySingleMin || alipayLimits.SingleMax != tt.wantAlipaySingleMax { + t.Fatalf("alipay limits = %+v, want min=%v max=%v", alipayLimits, tt.wantAlipaySingleMin, tt.wantAlipaySingleMax) + } + + wxpayLimits, ok := resp.Methods[payment.TypeWxpay] + if !ok { + t.Fatalf("expected wxpay limits to remain visible, got %v", resp.Methods) + } + if wxpayLimits.SingleMin != 30 || wxpayLimits.SingleMax != 300 { + t.Fatalf("wxpay limits = %+v, want official-only min=30 max=300", wxpayLimits) + } + if resp.GlobalMin != tt.wantGlobalMin || resp.GlobalMax != tt.wantGlobalMax { + t.Fatalf("global range = (%v, %v), want (%v, %v)", resp.GlobalMin, resp.GlobalMax, tt.wantGlobalMin, tt.wantGlobalMax) + } + }) } } diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go index 2c0f8206..51d5c7b6 100644 --- a/backend/internal/service/payment_config_providers_test.go +++ b/backend/internal/service/payment_config_providers_test.go @@ -4,9 +4,12 @@ package service import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "testing" - infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -199,7 +202,7 @@ func TestJoinTypes(t *testing.T) { } } -func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *testing.T) { +func TestCreateProviderInstanceAllowsVisibleMethodProvidersFromDifferentSources(t *testing.T) { t.Parallel() ctx := context.Background() @@ -227,15 +230,14 @@ func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *test _, err = svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ ProviderKey: "alipay", Name: "Official Alipay", - Config: map[string]string{"appId": "app-1"}, + Config: map[string]string{"appId": "app-1", "privateKey": "private-key"}, SupportedTypes: []string{"alipay"}, Enabled: true, }) - require.Error(t, err) - require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err)) + require.NoError(t, err) } -func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t *testing.T) { +func TestUpdateProviderInstanceAllowsEnablingVisibleMethodProviderFromDifferentSource(t *testing.T) { t.Parallel() ctx := context.Background() @@ -264,7 +266,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t candidate, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ ProviderKey: "wxpay", Name: "Official WeChat", - Config: map[string]string{"appId": "wx-app"}, + Config: validWxpayProviderConfig(t), SupportedTypes: []string{"wxpay"}, Enabled: false, }) @@ -273,8 +275,7 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t _, err = svc.UpdateProviderInstance(ctx, candidate.ID, UpdateProviderInstanceRequest{ Enabled: boolPtrValue(true), }) - require.Error(t, err) - require.Equal(t, "PAYMENT_PROVIDER_CONFLICT", infraerrors.Reason(err)) + require.NoError(t, err) } func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) { @@ -317,3 +318,25 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) { func boolPtrValue(v bool) *bool { return &v } + +func validWxpayProviderConfig(t *testing.T) map[string]string { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + privDER, err := x509.MarshalPKCS8PrivateKey(key) + require.NoError(t, err) + pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey) + require.NoError(t, err) + + return map[string]string{ + "appId": "wx-app-test", + "mchId": "mch-test", + "privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})), + "apiV3Key": "12345678901234567890123456789012", + "publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})), + "publicKeyId": "public-key-id-test", + "certSerial": "cert-serial-test", + } +} diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go index ffb63066..f14dc55d 100644 --- a/backend/internal/service/payment_order_lifecycle.go +++ b/backend/internal/service/payment_order_lifecycle.go @@ -234,6 +234,10 @@ func paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, current // if a payment was made, and processes it if so. This handles the case where // the provider's notify callback was missed (e.g. EasyPay popup mode). func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) { + outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo) + if err != nil { + return nil, err + } o, err := s.entClient.PaymentOrder.Query(). Where(paymentorder.OutTradeNo(outTradeNo)). Only(ctx) @@ -261,6 +265,10 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo // triggering any upstream reconciliation. Signed resume-token recovery is the // only public recovery path allowed to query upstream state. func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) { + outTradeNo, err := normalizeOrderLookupOutTradeNo(outTradeNo) + if err != nil { + return nil, err + } o, err := s.entClient.PaymentOrder.Query(). Where(paymentorder.OutTradeNo(outTradeNo)). Only(ctx) @@ -270,6 +278,27 @@ func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo strin return o, nil } +func normalizeOrderLookupOutTradeNo(raw string) (string, error) { + outTradeNo := strings.TrimSpace(raw) + if outTradeNo == "" { + return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is required") + } + if len(outTradeNo) > 64 { + return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid") + } + for _, ch := range outTradeNo { + switch { + case ch >= 'a' && ch <= 'z': + case ch >= 'A' && ch <= 'Z': + case ch >= '0' && ch <= '9': + case ch == '_' || ch == '-': + default: + return "", infraerrors.BadRequest("INVALID_OUT_TRADE_NO", "out_trade_no is invalid") + } + } + return outTradeNo, nil +} + func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) { now := time.Now() orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx) diff --git a/backend/internal/service/payment_order_result_test.go b/backend/internal/service/payment_order_result_test.go index 23371cfd..2d7412e0 100644 --- a/backend/internal/service/payment_order_result_test.go +++ b/backend/internal/service/payment_order_result_test.go @@ -2,6 +2,7 @@ package service import ( "context" + "strings" "testing" "time" @@ -91,6 +92,8 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) { } func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) { + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef") + svc := newWeChatPaymentOAuthTestService(map[string]string{ SettingKeyWeChatConnectEnabled: "true", SettingKeyWeChatConnectAppID: "wx123456", @@ -198,6 +201,44 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testin } } +func TestMaybeBuildWeChatOAuthRequiredResponseFallsBackToConfiguredLegacySigningKey(t *testing.T) { + svc := &PaymentService{ + configService: &PaymentConfigService{ + settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx123456", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }}, + // Legacy stable signing key remains available for no-config upgrade compatibility. + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + }, + } + + resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{ + Amount: 12.5, + PaymentType: payment.TypeWxpay, + IsWeChatBrowser: true, + SrcURL: "https://merchant.example/payment?from=wechat", + OrderType: payment.OrderTypeBalance, + }, 12.5, 12.88, 0.03) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if resp == nil { + t.Fatal("expected oauth-required response, got nil") + } + if resp.ResultType != payment.CreatePaymentResultOAuthRequired { + t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultOAuthRequired) + } + if resp.OAuth == nil || strings.TrimSpace(resp.OAuth.AuthorizeURL) == "" { + t.Fatalf("expected oauth redirect payload, got %+v", resp.OAuth) + } +} + func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) { svc := newWeChatPaymentOAuthTestService(map[string]string{ SettingKeyWeChatConnectEnabled: "true", diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go index 05626aa6..1ff061e8 100644 --- a/backend/internal/service/payment_resume_lookup.go +++ b/backend/internal/service/payment_resume_lookup.go @@ -6,6 +6,7 @@ import ( "strings" dbent "github.com/Wei-Shaw/sub2api/ent" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) { @@ -16,10 +17,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID) if err != nil { + if dbent.IsNotFound(err) { + return nil, infraerrors.NotFound("NOT_FOUND", "order not found") + } return nil, fmt.Errorf("get order by resume token: %w", err) } if claims.UserID > 0 && order.UserID != claims.UserID { - return nil, fmt.Errorf("resume token user mismatch") + return nil, invalidResumeTokenMatchError() } snapshot := psOrderProviderSnapshot(order) orderProviderInstanceID := strings.TrimSpace(psStringValue(order.ProviderInstanceID)) @@ -33,13 +37,13 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token } } if claims.ProviderInstanceID != "" && orderProviderInstanceID != claims.ProviderInstanceID { - return nil, fmt.Errorf("resume token provider instance mismatch") + return nil, invalidResumeTokenMatchError() } - if claims.ProviderKey != "" && orderProviderKey != claims.ProviderKey { - return nil, fmt.Errorf("resume token provider key mismatch") + if claims.ProviderKey != "" && !strings.EqualFold(orderProviderKey, claims.ProviderKey) { + return nil, invalidResumeTokenMatchError() } - if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType { - return nil, fmt.Errorf("resume token payment type mismatch") + if claims.PaymentType != "" && NormalizeVisibleMethod(order.PaymentType) != NormalizeVisibleMethod(claims.PaymentType) { + return nil, invalidResumeTokenMatchError() } if order.Status == OrderStatusPending || order.Status == OrderStatusExpired { result := s.checkPaid(ctx, order) @@ -54,6 +58,10 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token return order, nil } +func invalidResumeTokenMatchError() error { + return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token does not match the payment order") +} + func (s *PaymentService) ParseWeChatPaymentResumeToken(token string) (*WeChatPaymentResumeClaims, error) { return s.paymentResume().ParseWeChatPaymentResumeToken(strings.TrimSpace(token)) } diff --git a/backend/internal/service/payment_resume_lookup_test.go b/backend/internal/service/payment_resume_lookup_test.go index 946e7aa2..a7b5b737 100644 --- a/backend/internal/service/payment_resume_lookup_test.go +++ b/backend/internal/service/payment_resume_lookup_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/stretchr/testify/require" ) @@ -143,7 +144,7 @@ func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) { _, err = svc.GetPublicOrderByResumeToken(ctx, token) require.Error(t, err) - require.Contains(t, err.Error(), "resume token") + require.Equal(t, "INVALID_RESUME_TOKEN", infraerrors.Reason(err)) } func TestGetPublicOrderByResumeTokenUsesSnapshotAuthorityWhenColumnsDiffer(t *testing.T) { @@ -302,3 +303,13 @@ func TestVerifyOrderPublicDoesNotCheckUpstreamForPendingOrder(t *testing.T) { require.Equal(t, order.ID, got.ID) require.Equal(t, 0, provider.queryCount) } + +func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) { + svc := &PaymentService{ + entClient: newPaymentConfigServiceTestClient(t), + } + + _, err := svc.VerifyOrderPublic(context.Background(), " ") + require.Error(t, err) + require.Equal(t, "INVALID_OUT_TRADE_NO", infraerrors.Reason(err)) +} diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go index 438aa59f..9ae62fde 100644 --- a/backend/internal/service/payment_resume_service.go +++ b/backend/internal/service/payment_resume_service.go @@ -1,6 +1,7 @@ package service import ( + "bytes" "context" "crypto/hmac" "crypto/sha256" @@ -68,6 +69,7 @@ type WeChatPaymentResumeClaims struct { type PaymentResumeService struct { signingKey []byte + verifyKeys [][]byte } type visibleMethodLoadBalancer struct { @@ -75,8 +77,29 @@ type visibleMethodLoadBalancer struct { configService *PaymentConfigService } -func NewPaymentResumeService(signingKey []byte) *PaymentResumeService { - return &PaymentResumeService{signingKey: signingKey} +func NewPaymentResumeService(signingKey []byte, verifyFallbacks ...[]byte) *PaymentResumeService { + svc := &PaymentResumeService{} + if len(signingKey) > 0 { + svc.signingKey = append([]byte(nil), signingKey...) + svc.verifyKeys = append(svc.verifyKeys, svc.signingKey) + } + for _, fallback := range verifyFallbacks { + if len(fallback) == 0 { + continue + } + cloned := append([]byte(nil), fallback...) + duplicate := false + for _, existing := range svc.verifyKeys { + if bytes.Equal(existing, cloned) { + duplicate = true + break + } + } + if !duplicate { + svc.verifyKeys = append(svc.verifyKeys, cloned) + } + } + return svc } func (s *PaymentResumeService) isSigningConfigured() bool { @@ -410,7 +433,7 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error { if len(parts) != 2 || parts[0] == "" || parts[1] == "" { return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed") } - if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) { + if !s.verifySignature(parts[0], parts[1]) { return infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch") } payload, err := base64.RawURLEncoding.DecodeString(parts[0]) @@ -420,6 +443,18 @@ func (s *PaymentResumeService) parseSignedToken(token string, dest any) error { return json.Unmarshal(payload, dest) } +func (s *PaymentResumeService) verifySignature(payload string, signature string) bool { + if s == nil { + return false + } + for _, key := range s.verifyKeys { + if hmac.Equal([]byte(signature), []byte(signPaymentResumePayload(payload, key))) { + return true + } + } + return false +} + func validatePaymentResumeExpiry(expiresAt int64, code, message string) error { if expiresAt <= 0 { return nil @@ -431,7 +466,11 @@ func validatePaymentResumeExpiry(expiresAt int64, code, message string) error { } func (s *PaymentResumeService) sign(payload string) string { - mac := hmac.New(sha256.New, s.signingKey) + return signPaymentResumePayload(payload, s.signingKey) +} + +func signPaymentResumePayload(payload string, key []byte) string { + mac := hmac.New(sha256.New, key) _, _ = mac.Write([]byte(payload)) return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) } diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go index e19e0b99..9e756971 100644 --- a/backend/internal/service/payment_resume_service_test.go +++ b/backend/internal/service/payment_resume_service_test.go @@ -334,6 +334,59 @@ func TestParseWeChatPaymentResumeTokenRejectsExpiredToken(t *testing.T) { } } +func TestPaymentServiceParseWeChatPaymentResumeTokenUsesExplicitSigningKey(t *testing.T) { + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key") + + token, err := NewPaymentResumeService([]byte("explicit-payment-resume-signing-key")).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-explicit-key", + PaymentType: payment.TypeWxpay, + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + svc := &PaymentService{ + configService: &PaymentConfigService{ + encryptionKey: []byte("0123456789abcdef0123456789abcdef"), + }, + } + + claims, err := svc.ParseWeChatPaymentResumeToken(token) + if err != nil { + t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err) + } + if claims.OpenID != "openid-explicit-key" { + t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-explicit-key") + } +} + +func TestPaymentServiceParseWeChatPaymentResumeTokenAcceptsLegacyEncryptionKeyDuringMigration(t *testing.T) { + t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "explicit-payment-resume-signing-key") + + legacyKey := []byte("0123456789abcdef0123456789abcdef") + token, err := NewPaymentResumeService(legacyKey).CreateWeChatPaymentResumeToken(WeChatPaymentResumeClaims{ + OpenID: "openid-legacy-key", + PaymentType: payment.TypeWxpay, + }) + if err != nil { + t.Fatalf("CreateWeChatPaymentResumeToken returned error: %v", err) + } + + svc := &PaymentService{ + configService: &PaymentConfigService{ + encryptionKey: legacyKey, + }, + } + + claims, err := svc.ParseWeChatPaymentResumeToken(token) + if err != nil { + t.Fatalf("ParseWeChatPaymentResumeToken returned error: %v", err) + } + if claims.OpenID != "openid-legacy-key" { + t.Fatalf("openid = %q, want %q", claims.OpenID, "openid-legacy-key") + } +} + func TestNormalizeVisibleMethodSource(t *testing.T) { t.Parallel() @@ -424,14 +477,14 @@ func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabl t.Parallel() tests := []struct { - name string - method payment.PaymentType - officialName string - officialTypes string - easyPayName string - easyPayTypes string - sourceSetting string - wantProvider string + name string + method payment.PaymentType + officialName string + officialTypes string + easyPayName string + easyPayTypes string + sourceSetting string + wantProvider string }{ { name: "alipay uses official source", @@ -487,7 +540,7 @@ func TestVisibleMethodLoadBalancerUsesConfiguredSourceWhenMultipleProvidersEnabl officialProviderKey = payment.TypeWxpay } - _, err = client.PaymentProviderInstance.Create(). + _, err := client.PaymentProviderInstance.Create(). SetProviderKey(officialProviderKey). SetName(tt.officialName). SetConfig("{}"). diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index 73bbb256..159f97d3 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -1,10 +1,14 @@ package service import ( + "bytes" "context" + "encoding/hex" "fmt" "log/slog" "math/rand/v2" + "os" + "strings" "sync" "time" @@ -44,6 +48,8 @@ const ( orderIDPrefix = "sub2_" ) +const paymentResumeSigningKeyEnv = "PAYMENT_RESUME_SIGNING_KEY" + // --- Types --- // generateOutTradeNo creates a unique external order ID for payment providers. @@ -179,7 +185,7 @@ type PaymentService struct { func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService { svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} - svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService)) + svc.resumeService = psNewPaymentResumeService(configService) return svc } @@ -259,16 +265,54 @@ func (s *PaymentService) paymentResume() *PaymentResumeService { if s.resumeService != nil { return s.resumeService } - return NewPaymentResumeService(psResumeSigningKey(s.configService)) + return psNewPaymentResumeService(s.configService) +} + +func psNewPaymentResumeService(configService *PaymentConfigService) *PaymentResumeService { + signingKey, verifyFallbacks := psResumeSigningKeys(configService) + return NewPaymentResumeService(signingKey, verifyFallbacks...) } func psResumeSigningKey(configService *PaymentConfigService) []byte { + signingKey, _ := psResumeSigningKeys(configService) + return signingKey +} + +func psResumeSigningKeys(configService *PaymentConfigService) ([]byte, [][]byte) { + signingKey := parsePaymentResumeSigningKey(os.Getenv(paymentResumeSigningKeyEnv)) + legacyKey := psResumeLegacyVerificationKey(configService) + if len(signingKey) == 0 { + if len(legacyKey) == 0 { + return nil, nil + } + return legacyKey, nil + } + if len(legacyKey) == 0 || bytes.Equal(legacyKey, signingKey) { + return signingKey, nil + } + return signingKey, [][]byte{legacyKey} +} + +func psResumeLegacyVerificationKey(configService *PaymentConfigService) []byte { if configService == nil { return nil } return configService.encryptionKey } +func parsePaymentResumeSigningKey(raw string) []byte { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + if len(raw) >= 64 && len(raw)%2 == 0 { + if decoded, err := hex.DecodeString(raw); err == nil && len(decoded) > 0 { + return decoded + } + } + return []byte(raw) +} + func psSliceContains(sl []string, s string) bool { for _, v := range sl { if v == s { diff --git a/backend/internal/service/payment_visible_method_instances.go b/backend/internal/service/payment_visible_method_instances.go index 39358522..86ea5ead 100644 --- a/backend/internal/service/payment_visible_method_instances.go +++ b/backend/internal/service/payment_visible_method_instances.go @@ -82,6 +82,41 @@ func filterEnabledVisibleMethodInstances(instances []*dbent.PaymentProviderInsta return filtered } +func filterVisibleMethodInstancesByProviderKey(instances []*dbent.PaymentProviderInstance, method string, providerKey string) []*dbent.PaymentProviderInstance { + filtered := make([]*dbent.PaymentProviderInstance, 0, len(instances)) + for _, inst := range instances { + if !providerSupportsVisibleMethod(inst, method) { + continue + } + if !strings.EqualFold(strings.TrimSpace(inst.ProviderKey), strings.TrimSpace(providerKey)) { + continue + } + filtered = append(filtered, inst) + } + return filtered +} + +func distinctVisibleMethodProviderKeys(instances []*dbent.PaymentProviderInstance) []string { + seen := make(map[string]struct{}, len(instances)) + keys := make([]string, 0, len(instances)) + for _, inst := range instances { + if inst == nil { + continue + } + key := strings.TrimSpace(inst.ProviderKey) + if key == "" { + continue + } + normalized := strings.ToLower(key) + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + keys = append(keys, key) + } + return keys +} + func selectVisibleMethodInstanceByProviderKey(instances []*dbent.PaymentProviderInstance, providerKey string) *dbent.PaymentProviderInstance { providerKey = strings.TrimSpace(providerKey) if providerKey == "" { @@ -117,32 +152,10 @@ func (s *PaymentConfigService) validateVisibleMethodEnablementConflicts( supportedTypes string, enabled bool, ) error { - if s == nil || s.entClient == nil || !enabled { - return nil - } - - claimedMethods := enabledVisibleMethodsForProvider(providerKey, supportedTypes) - if len(claimedMethods) == 0 { - return nil - } - - query := s.entClient.PaymentProviderInstance.Query(). - Where(paymentproviderinstance.EnabledEQ(true)) - if excludeID > 0 { - query = query.Where(paymentproviderinstance.IDNEQ(excludeID)) - } - instances, err := query.All(ctx) - if err != nil { - return fmt.Errorf("query enabled payment providers: %w", err) - } - - for _, method := range claimedMethods { - for _, inst := range instances { - if providerSupportsVisibleMethod(inst, method) { - return buildPaymentProviderConflictError(method, inst) - } - } - } + // Visible methods are selected by configured source (official/easypay), + // so multiple enabled providers can intentionally claim the same user-facing + // method. Order creation and limits will route through the configured source. + _, _, _, _, _ = ctx, excludeID, providerKey, supportedTypes, enabled return nil } @@ -172,6 +185,32 @@ func (s *PaymentConfigService) resolveVisibleMethodSourceProviderKey(ctx context return providerKey, nil } +func (s *PaymentConfigService) resolveVisibleMethodProviderKey( + ctx context.Context, + method string, + matching []*dbent.PaymentProviderInstance, +) (string, error) { + switch providerKeys := distinctVisibleMethodProviderKeys(matching); len(providerKeys) { + case 0: + return "", nil + case 1: + return strings.TrimSpace(providerKeys[0]), nil + default: + providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method) + if err != nil { + return "", err + } + selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey) + if selected == nil { + return "", infraerrors.BadRequest( + "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", + fmt.Sprintf("%s source has no enabled provider instance", method), + ) + } + return strings.TrimSpace(selected.ProviderKey), nil + } +} + func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance( ctx context.Context, method string, @@ -194,23 +233,9 @@ func (s *PaymentConfigService) resolveEnabledVisibleMethodInstance( } matching := filterEnabledVisibleMethodInstances(instances, method) - switch len(matching) { - case 0: - return nil, nil - case 1: - return matching[0], nil - default: - providerKey, err := s.resolveVisibleMethodSourceProviderKey(ctx, method) - if err != nil { - return nil, err - } - selected := selectVisibleMethodInstanceByProviderKey(matching, providerKey) - if selected == nil { - return nil, infraerrors.BadRequest( - "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", - fmt.Sprintf("%s source has no enabled provider instance", method), - ) - } - return selected, nil + providerKey, err := s.resolveVisibleMethodProviderKey(ctx, method, matching) + if err != nil { + return nil, err } + return selectVisibleMethodInstanceByProviderKey(matching, providerKey), nil }