fix(review): harden payment, oauth, and migration paths

This commit is contained in:
IanShaw027
2026-04-22 10:26:22 +08:00
parent 7fbd5177c2
commit c229f33e9e
33 changed files with 704 additions and 79 deletions

View File

@@ -1202,7 +1202,7 @@ func setDefaults() {
viper.SetDefault("linuxdo_connect.redirect_url", "")
viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback")
viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post")
viper.SetDefault("linuxdo_connect.use_pkce", false)
viper.SetDefault("linuxdo_connect.use_pkce", true)
viper.SetDefault("linuxdo_connect.userinfo_email_path", "")
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
@@ -1222,7 +1222,7 @@ func setDefaults() {
viper.SetDefault("oidc_connect.redirect_url", "")
viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback")
viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post")
viper.SetDefault("oidc_connect.use_pkce", false)
viper.SetDefault("oidc_connect.use_pkce", true)
viper.SetDefault("oidc_connect.validate_id_token", true)
viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256")
viper.SetDefault("oidc_connect.clock_skew_seconds", 120)

View File

@@ -937,7 +937,19 @@ func clearOAuthBindAccessTokenCookie(c *gin.Context, secure bool) {
Value: "",
Path: oauthBindAccessTokenCookiePath,
MaxAge: -1,
HttpOnly: false,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func setOAuthBindAccessTokenCookie(c *gin.Context, token string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: oauthBindAccessTokenCookieName,
Value: url.QueryEscape(strings.TrimSpace(token)),
Path: oauthBindAccessTokenCookiePath,
MaxAge: linuxDoOAuthCookieMaxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
@@ -1021,6 +1033,26 @@ func (h *AuthHandler) buildOAuthBindUserCookieFromContext(c *gin.Context) (strin
return buildOAuthBindUserCookieValue(*userID, h.oauthBindCookieSecret())
}
func (h *AuthHandler) PrepareOAuthBindAccessTokenCookie(c *gin.Context) {
const bearerPrefix = "Bearer "
authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
if !strings.HasPrefix(strings.ToLower(authHeader), strings.ToLower(bearerPrefix)) {
response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
return
}
token := strings.TrimSpace(authHeader[len(bearerPrefix):])
if token == "" {
response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
return
}
setOAuthBindAccessTokenCookie(c, token, isRequestHTTPS(c))
c.Status(http.StatusNoContent)
c.Writer.WriteHeaderNow()
}
func (h *AuthHandler) resolveOAuthBindTargetUserID(c *gin.Context) (*int64, error) {
if subject, ok := servermiddleware.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
return &subject.UserID, nil

View File

@@ -5,6 +5,7 @@ import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
@@ -226,6 +227,27 @@ func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) {
require.Equal(t, -1, accessTokenCookie.MaxAge)
}
func TestPrepareOAuthBindAccessTokenCookieSetsHttpOnlyCookie(t *testing.T) {
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/bind-token", nil)
req.Header.Set("Authorization", "Bearer access-token-value")
c.Request = req
handler.PrepareOAuthBindAccessTokenCookie(c)
require.Equal(t, http.StatusNoContent, recorder.Code)
accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName)
require.NotNil(t, accessTokenCookie)
require.Equal(t, oauthBindAccessTokenCookiePath, accessTokenCookie.Path)
require.Equal(t, linuxDoOAuthCookieMaxAgeSec, accessTokenCookie.MaxAge)
require.True(t, accessTokenCookie.HttpOnly)
require.Equal(t, url.QueryEscape("access-token-value"), accessTokenCookie.Value)
}
func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {

View File

@@ -4,6 +4,7 @@ import (
"encoding/hex"
"fmt"
"log/slog"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -19,11 +20,22 @@ type EncryptionKey []byte
// When the key is non-empty but invalid (bad hex or wrong length), an error is returned
// to prevent startup with a misconfigured encryption key.
func ProvideEncryptionKey(cfg *config.Config) (EncryptionKey, error) {
if cfg.Totp.EncryptionKey == "" {
if cfg == nil {
slog.Warn("payment encryption key not configured — encrypted payment config and resume signing will be unavailable")
return nil, nil
}
keyHex := strings.TrimSpace(cfg.Totp.EncryptionKey)
if keyHex == "" {
slog.Warn("payment encryption key not configured — encrypted payment config will be unavailable")
return nil, nil
}
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
// Reject auto-generated TOTP keys for payment signing.
// They change across restarts/instances and can silently break resume-token flows.
if !cfg.Totp.EncryptionKeyConfigured {
slog.Warn("payment encryption/signing key is not explicitly configured; set TOTP_ENCRYPTION_KEY to enable payment resume tokens")
return nil, nil
}
key, err := hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("invalid payment encryption key (hex decode): %w", err)
}

View File

@@ -0,0 +1,62 @@
package payment
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
func TestProvideEncryptionKeySkipsAutoGeneratedTotpKey(t *testing.T) {
t.Parallel()
cfg := &config.Config{
Totp: config.TotpConfig{
EncryptionKey: strings.Repeat("a", 64),
EncryptionKeyConfigured: false,
},
}
key, err := ProvideEncryptionKey(cfg)
if err != nil {
t.Fatalf("ProvideEncryptionKey returned error: %v", err)
}
if len(key) != 0 {
t.Fatalf("encryption key len = %d, want 0", len(key))
}
}
func TestProvideEncryptionKeyUsesConfiguredTotpKey(t *testing.T) {
t.Parallel()
cfg := &config.Config{
Totp: config.TotpConfig{
EncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
EncryptionKeyConfigured: true,
},
}
key, err := ProvideEncryptionKey(cfg)
if err != nil {
t.Fatalf("ProvideEncryptionKey returned error: %v", err)
}
if len(key) != 32 {
t.Fatalf("encryption key len = %d, want 32", len(key))
}
}
func TestProvideEncryptionKeyRejectsConfiguredInvalidLength(t *testing.T) {
t.Parallel()
cfg := &config.Config{
Totp: config.TotpConfig{
EncryptionKey: "abcd",
EncryptionKeyConfigured: true,
},
}
_, err := ProvideEncryptionKey(cfg)
if err == nil {
t.Fatal("expected error for invalid key length")
}
}

View File

@@ -164,6 +164,7 @@ func RegisterAuthRoutes(
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 撤销所有会话(需要认证)
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
authenticated.POST("/auth/oauth/bind-token", h.Auth.PrepareOAuthBindAccessTokenCookie)
authenticated.GET("/auth/oauth/linuxdo/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")

View File

@@ -80,21 +80,25 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo
})
return err
}
// Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount).
// Also skip if paid is NaN/Inf (malformed provider data).
if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) {
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
}
if !isValidProviderAmount(paid) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", pk, map[string]any{
"expected": o.PayAmount,
"paid": paid,
"tradeNo": tradeNo,
})
return fmt.Errorf("invalid paid amount from provider: %v", paid)
}
// Use order's expected amount when provider didn't report one
if paid <= 0 || math.IsNaN(paid) || math.IsInf(paid, 0) {
paid = o.PayAmount
if math.Abs(paid-o.PayAmount) > amountToleranceCNY {
s.writeAuditLog(ctx, o.ID, "PAYMENT_AMOUNT_MISMATCH", pk, map[string]any{"expected": o.PayAmount, "paid": paid, "tradeNo": tradeNo})
return fmt.Errorf("amount mismatch: expected %.2f, got %.2f", o.PayAmount, paid)
}
return s.toPaid(ctx, o, tradeNo, paid, pk)
}
func isValidProviderAmount(amount float64) bool {
return amount > 0 && !math.IsNaN(amount) && !math.IsInf(amount, 0)
}
func validateProviderNotificationMetadata(order *dbent.PaymentOrder, providerKey string, metadata map[string]string) error {
return validateProviderSnapshotMetadata(order, providerKey, metadata)
}

View File

@@ -5,6 +5,7 @@ package service
import (
"context"
"errors"
"math"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -322,6 +323,16 @@ func TestParseLegacyPaymentOrderID(t *testing.T) {
assert.False(t, ok)
}
func TestIsValidProviderAmount(t *testing.T) {
t.Parallel()
assert.True(t, isValidProviderAmount(0.01))
assert.False(t, isValidProviderAmount(0))
assert.False(t, isValidProviderAmount(-1))
assert.False(t, isValidProviderAmount(math.NaN()))
assert.False(t, isValidProviderAmount(math.Inf(1)))
}
func TestValidateProviderNotificationMetadataRejectsAlipaySnapshotMismatch(t *testing.T) {
t.Parallel()

View File

@@ -139,6 +139,10 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
tm = defaultOrderTimeoutMin
}
exp := time.Now().Add(time.Duration(tm) * time.Minute)
outTradeNo, err := s.allocateOutTradeNo(ctx, tx)
if err != nil {
return nil, err
}
providerSnapshot := buildPaymentOrderProviderSnapshot(sel, req)
selectedInstanceID := ""
selectedProviderKey := ""
@@ -155,7 +159,7 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
SetPayAmount(payAmount).
SetFeeRate(feeRate).
SetRechargeCode("").
SetOutTradeNo(generateOutTradeNo()).
SetOutTradeNo(outTradeNo).
SetPaymentType(req.PaymentType).
SetPaymentTradeNo("").
SetOrderType(req.OrderType).
@@ -193,6 +197,21 @@ func (s *PaymentService) createOrderInTx(ctx context.Context, req CreateOrderReq
return order, nil
}
func (s *PaymentService) allocateOutTradeNo(ctx context.Context, tx *dbent.Tx) (string, error) {
const maxAttempts = 5
for attempt := 0; attempt < maxAttempts; attempt++ {
candidate := generateOutTradeNo()
exists, err := tx.PaymentOrder.Query().Where(paymentorder.OutTradeNo(candidate)).Exist(ctx)
if err != nil {
return "", fmt.Errorf("check out_trade_no uniqueness: %w", err)
}
if !exists {
return candidate, nil
}
}
return "", fmt.Errorf("generate unique out_trade_no: exhausted %d attempts", maxAttempts)
}
func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, userID int64, max int) error {
if max <= 0 {
max = defaultMaxPendingOrders
@@ -366,7 +385,10 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
}
resumeToken := ""
if resume := s.paymentResume(); resume != nil {
if resume.isSigningConfigured() {
if canonicalReturnURL != "" {
if err := resume.ensureSigningKey(); err != nil {
return nil, err
}
resumeToken, err = resume.CreateToken(ResumeTokenClaims{
OrderID: order.ID,
UserID: order.UserID,
@@ -482,6 +504,9 @@ func (s *PaymentService) buildWeChatOAuthRequiredResponse(ctx context.Context, r
if err != nil {
return nil, err
}
if err := s.paymentResume().ensureSigningKey(); err != nil {
return nil, err
}
authorizeURL, err := buildWeChatPaymentOAuthStartURL(req, "snsapi_base")
if err != nil {

View File

@@ -150,6 +150,16 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
return ""
}
if resp.Status == payment.ProviderStatusPaid {
if !isValidProviderAmount(resp.Amount) {
s.writeAuditLog(ctx, o.ID, "PAYMENT_INVALID_AMOUNT", prov.ProviderKey(), map[string]any{
"expected": o.PayAmount,
"paid": resp.Amount,
"tradeNo": resp.TradeNo,
"queryRef": queryRef,
})
slog.Warn("query upstream returned invalid paid amount", "orderID", o.ID, "queryRef", queryRef, "paid", resp.Amount)
return ""
}
notificationTradeNo := o.PaymentTradeNo
if upstreamTradeNo := strings.TrimSpace(resp.TradeNo); paymentOrderShouldPersistUpstreamTradeNo(queryRef, upstreamTradeNo, notificationTradeNo) {
if _, updateErr := s.entClient.PaymentOrder.Update().

View File

@@ -234,6 +234,97 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) {
require.Equal(t, user.ID, redeemRepo.useCalls[0].userID)
}
func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)
user, err := client.User.Create().
SetEmail("checkpaid-zero-amount@example.com").
SetPasswordHash("hash").
SetUsername("checkpaid-zero-amount-user").
Save(ctx)
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(88).
SetFeeRate(0).
SetRechargeCode("CHECKPAID-ZERO-AMOUNT").
SetOutTradeNo("sub2_checkpaid_zero_amount").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("").
SetOrderType(payment.OrderTypeBalance).
SetStatus(OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(ctx)
require.NoError(t, err)
userRepo := &mockUserRepo{
getByIDUser: &User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
Balance: 0,
},
}
redeemRepo := &paymentOrderLifecycleRedeemRepo{
codesByCode: map[string]*RedeemCode{
order.RechargeCode: {
ID: 1,
Code: order.RechargeCode,
Type: RedeemTypeBalance,
Value: order.Amount,
Status: StatusUnused,
},
},
}
redeemService := NewRedeemService(
redeemRepo,
userRepo,
nil,
nil,
nil,
client,
nil,
)
registry := payment.NewRegistry()
provider := &paymentOrderLifecycleQueryProvider{
resp: &payment.QueryOrderResponse{
TradeNo: "upstream-trade-zero",
Status: payment.ProviderStatusPaid,
Amount: 0,
},
}
registry.Register(provider)
svc := &PaymentService{
entClient: client,
registry: registry,
redeemService: redeemService,
userRepo: userRepo,
providersLoaded: true,
}
got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
require.NoError(t, err)
require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
require.Equal(t, OrderStatusPending, got.Status)
require.Empty(t, got.PaymentTradeNo)
reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
require.NoError(t, err)
require.Equal(t, OrderStatusPending, reloaded.Status)
require.Empty(t, reloaded.PaymentTradeNo)
require.Equal(t, 0.0, userRepo.getByIDUser.Balance)
require.Empty(t, redeemRepo.useCalls)
}
func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)

View File

@@ -159,6 +159,45 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testin
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseRequiresResumeSigningKey(t *testing.T) {
t.Parallel()
svc := &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{
SettingKeyWeChatConnectEnabled: "true",
SettingKeyWeChatConnectAppID: "wx123456",
SettingKeyWeChatConnectAppSecret: "wechat-secret",
SettingKeyWeChatConnectMode: "mp",
SettingKeyWeChatConnectScopes: "snsapi_base",
SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
}},
// Intentionally missing payment resume signing key.
encryptionKey: nil,
},
}
resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{
Amount: 12.5,
PaymentType: payment.TypeWxpay,
IsWeChatBrowser: true,
SrcURL: "https://merchant.example/payment?from=wechat",
OrderType: payment.OrderTypeBalance,
}, 12.5, 12.88, 0.03)
if resp != nil {
t.Fatalf("expected nil response, got %+v", resp)
}
if err == nil {
t.Fatal("expected error, got nil")
}
appErr := infraerrors.FromError(err)
if appErr.Reason != "PAYMENT_RESUME_NOT_CONFIGURED" {
t.Fatalf("reason = %q, want %q", appErr.Reason, "PAYMENT_RESUME_NOT_CONFIGURED")
}
}
func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) {
svc := newWeChatPaymentOAuthTestService(map[string]string{
SettingKeyWeChatConnectEnabled: "true",
@@ -189,7 +228,8 @@ func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t
func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService {
return &PaymentService{
configService: &PaymentConfigService{
settingRepo: &paymentConfigSettingRepoStub{values: values},
settingRepo: &paymentConfigSettingRepoStub{values: values},
encryptionKey: []byte("0123456789abcdef0123456789abcdef"),
},
}
}