fix(repository): 修复 JWT 密钥引导冲突一致性与并发读取竞态
This commit is contained in:
@@ -3,17 +3,24 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
const securitySecretKeyJWT = "jwt_secret"
|
||||
const (
|
||||
securitySecretKeyJWT = "jwt_secret"
|
||||
securitySecretReadRetryMax = 5
|
||||
securitySecretReadRetryWait = 10 * time.Millisecond
|
||||
)
|
||||
|
||||
var readRandomBytes = rand.Read
|
||||
|
||||
@@ -27,9 +34,14 @@ func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config
|
||||
|
||||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||||
if cfg.JWT.Secret != "" {
|
||||
if err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret); err != nil {
|
||||
storedSecret, err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("persist jwt secret: %w", err)
|
||||
}
|
||||
if storedSecret != cfg.JWT.Secret {
|
||||
log.Println("Warning: configured JWT secret mismatches persisted value; using persisted secret for cross-instance consistency.")
|
||||
}
|
||||
cfg.JWT.Secret = storedSecret
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -69,10 +81,12 @@ func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client,
|
||||
OnConflictColumns(securitysecret.FieldKey).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
return "", false, err
|
||||
if !isSQLNoRowsError(err) {
|
||||
return "", false, err
|
||||
}
|
||||
}
|
||||
|
||||
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
|
||||
stored, err := querySecuritySecretWithRetry(ctx, client, key)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
@@ -83,17 +97,72 @@ func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client,
|
||||
return value, value == generated, nil
|
||||
}
|
||||
|
||||
func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) error {
|
||||
func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) (string, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
if len([]byte(value)) < 32 {
|
||||
return fmt.Errorf("secret %q must be at least 32 bytes", key)
|
||||
return "", fmt.Errorf("secret %q must be at least 32 bytes", key)
|
||||
}
|
||||
|
||||
_, err := client.SecuritySecret.Create().SetKey(key).SetValue(value).Save(ctx)
|
||||
if err == nil || ent.IsConstraintError(err) {
|
||||
return nil
|
||||
if err := client.SecuritySecret.Create().
|
||||
SetKey(key).
|
||||
SetValue(value).
|
||||
OnConflictColumns(securitysecret.FieldKey).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
if !isSQLNoRowsError(err) {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return err
|
||||
|
||||
stored, err := querySecuritySecretWithRetry(ctx, client, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
storedValue := strings.TrimSpace(stored.Value)
|
||||
if len([]byte(storedValue)) < 32 {
|
||||
return "", fmt.Errorf("stored secret %q must be at least 32 bytes", key)
|
||||
}
|
||||
return storedValue, nil
|
||||
}
|
||||
|
||||
func querySecuritySecretWithRetry(ctx context.Context, client *ent.Client, key string) (*ent.SecuritySecret, error) {
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= securitySecretReadRetryMax; attempt++ {
|
||||
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
|
||||
if err == nil {
|
||||
return stored, nil
|
||||
}
|
||||
if !isSecretNotFoundError(err) {
|
||||
return nil, err
|
||||
}
|
||||
lastErr = err
|
||||
if attempt == securitySecretReadRetryMax {
|
||||
break
|
||||
}
|
||||
|
||||
timer := time.NewTimer(securitySecretReadRetryWait)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil, ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func isSecretNotFoundError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return ent.IsNotFound(err) || isSQLNoRowsError(err)
|
||||
}
|
||||
|
||||
func isSQLNoRowsError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set")
|
||||
}
|
||||
|
||||
func generateHexSecret(byteLength int) (string, error) {
|
||||
|
||||
Reference in New Issue
Block a user