From 09c8380b3df4b01ec7625f1dd1b680bde6f7f5db Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Thu, 12 Feb 2026 12:04:13 +0800 Subject: [PATCH] =?UTF-8?q?fix(repository):=20=E4=BF=AE=E5=A4=8D=20JWT=20?= =?UTF-8?q?=E5=AF=86=E9=92=A5=E5=BC=95=E5=AF=BC=E5=86=B2=E7=AA=81=E4=B8=80?= =?UTF-8?q?=E8=87=B4=E6=80=A7=E4=B8=8E=E5=B9=B6=E5=8F=91=E8=AF=BB=E5=8F=96?= =?UTF-8?q?=E7=AB=9E=E6=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../repository/security_secret_bootstrap.go | 89 ++++++++++++++++--- 1 file changed, 79 insertions(+), 10 deletions(-) diff --git a/backend/internal/repository/security_secret_bootstrap.go b/backend/internal/repository/security_secret_bootstrap.go index 85fdbf08..e773c238 100644 --- a/backend/internal/repository/security_secret_bootstrap.go +++ b/backend/internal/repository/security_secret_bootstrap.go @@ -3,17 +3,24 @@ package repository import ( "context" "crypto/rand" + "database/sql" "encoding/hex" + "errors" "fmt" "log" "strings" + "time" "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/internal/config" ) -const securitySecretKeyJWT = "jwt_secret" +const ( + securitySecretKeyJWT = "jwt_secret" + securitySecretReadRetryMax = 5 + securitySecretReadRetryWait = 10 * time.Millisecond +) var readRandomBytes = rand.Read @@ -27,9 +34,14 @@ func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) if cfg.JWT.Secret != "" { - if err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret); err != nil { + storedSecret, err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret) + if err != nil { return fmt.Errorf("persist jwt secret: %w", err) } + if storedSecret != cfg.JWT.Secret { + log.Println("Warning: configured JWT secret mismatches persisted value; using persisted secret for cross-instance consistency.") + } + cfg.JWT.Secret = storedSecret return nil } @@ -69,10 +81,12 @@ func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client, OnConflictColumns(securitysecret.FieldKey). DoNothing(). Exec(ctx); err != nil { - return "", false, err + if !isSQLNoRowsError(err) { + return "", false, err + } } - stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx) + stored, err := querySecuritySecretWithRetry(ctx, client, key) if err != nil { return "", false, err } @@ -83,17 +97,72 @@ func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client, return value, value == generated, nil } -func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) error { +func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) (string, error) { value = strings.TrimSpace(value) if len([]byte(value)) < 32 { - return fmt.Errorf("secret %q must be at least 32 bytes", key) + return "", fmt.Errorf("secret %q must be at least 32 bytes", key) } - _, err := client.SecuritySecret.Create().SetKey(key).SetValue(value).Save(ctx) - if err == nil || ent.IsConstraintError(err) { - return nil + if err := client.SecuritySecret.Create(). + SetKey(key). + SetValue(value). + OnConflictColumns(securitysecret.FieldKey). + DoNothing(). + Exec(ctx); err != nil { + if !isSQLNoRowsError(err) { + return "", err + } } - return err + + stored, err := querySecuritySecretWithRetry(ctx, client, key) + if err != nil { + return "", err + } + storedValue := strings.TrimSpace(stored.Value) + if len([]byte(storedValue)) < 32 { + return "", fmt.Errorf("stored secret %q must be at least 32 bytes", key) + } + return storedValue, nil +} + +func querySecuritySecretWithRetry(ctx context.Context, client *ent.Client, key string) (*ent.SecuritySecret, error) { + var lastErr error + for attempt := 0; attempt <= securitySecretReadRetryMax; attempt++ { + stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx) + if err == nil { + return stored, nil + } + if !isSecretNotFoundError(err) { + return nil, err + } + lastErr = err + if attempt == securitySecretReadRetryMax { + break + } + + timer := time.NewTimer(securitySecretReadRetryWait) + select { + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + case <-timer.C: + } + } + return nil, lastErr +} + +func isSecretNotFoundError(err error) bool { + if err == nil { + return false + } + return ent.IsNotFound(err) || isSQLNoRowsError(err) +} + +func isSQLNoRowsError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set") } func generateHexSecret(byteLength int) (string, error) {