diff --git a/backend/internal/repository/security_secret_bootstrap.go b/backend/internal/repository/security_secret_bootstrap.go index 85fdbf08..e773c238 100644 --- a/backend/internal/repository/security_secret_bootstrap.go +++ b/backend/internal/repository/security_secret_bootstrap.go @@ -3,17 +3,24 @@ package repository import ( "context" "crypto/rand" + "database/sql" "encoding/hex" + "errors" "fmt" "log" "strings" + "time" "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/internal/config" ) -const securitySecretKeyJWT = "jwt_secret" +const ( + securitySecretKeyJWT = "jwt_secret" + securitySecretReadRetryMax = 5 + securitySecretReadRetryWait = 10 * time.Millisecond +) var readRandomBytes = rand.Read @@ -27,9 +34,14 @@ func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) if cfg.JWT.Secret != "" { - if err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret); err != nil { + storedSecret, err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret) + if err != nil { return fmt.Errorf("persist jwt secret: %w", err) } + if storedSecret != cfg.JWT.Secret { + log.Println("Warning: configured JWT secret mismatches persisted value; using persisted secret for cross-instance consistency.") + } + cfg.JWT.Secret = storedSecret return nil } @@ -69,10 +81,12 @@ func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client, OnConflictColumns(securitysecret.FieldKey). DoNothing(). Exec(ctx); err != nil { - return "", false, err + if !isSQLNoRowsError(err) { + return "", false, err + } } - stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx) + stored, err := querySecuritySecretWithRetry(ctx, client, key) if err != nil { return "", false, err } @@ -83,17 +97,72 @@ func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client, return value, value == generated, nil } -func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) error { +func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) (string, error) { value = strings.TrimSpace(value) if len([]byte(value)) < 32 { - return fmt.Errorf("secret %q must be at least 32 bytes", key) + return "", fmt.Errorf("secret %q must be at least 32 bytes", key) } - _, err := client.SecuritySecret.Create().SetKey(key).SetValue(value).Save(ctx) - if err == nil || ent.IsConstraintError(err) { - return nil + if err := client.SecuritySecret.Create(). + SetKey(key). + SetValue(value). + OnConflictColumns(securitysecret.FieldKey). + DoNothing(). + Exec(ctx); err != nil { + if !isSQLNoRowsError(err) { + return "", err + } } - return err + + stored, err := querySecuritySecretWithRetry(ctx, client, key) + if err != nil { + return "", err + } + storedValue := strings.TrimSpace(stored.Value) + if len([]byte(storedValue)) < 32 { + return "", fmt.Errorf("stored secret %q must be at least 32 bytes", key) + } + return storedValue, nil +} + +func querySecuritySecretWithRetry(ctx context.Context, client *ent.Client, key string) (*ent.SecuritySecret, error) { + var lastErr error + for attempt := 0; attempt <= securitySecretReadRetryMax; attempt++ { + stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx) + if err == nil { + return stored, nil + } + if !isSecretNotFoundError(err) { + return nil, err + } + lastErr = err + if attempt == securitySecretReadRetryMax { + break + } + + timer := time.NewTimer(securitySecretReadRetryWait) + select { + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + case <-timer.C: + } + } + return nil, lastErr +} + +func isSecretNotFoundError(err error) bool { + if err == nil { + return false + } + return ent.IsNotFound(err) || isSQLNoRowsError(err) +} + +func isSQLNoRowsError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set") } func generateHexSecret(byteLength int) (string, error) { diff --git a/backend/internal/repository/security_secret_bootstrap_test.go b/backend/internal/repository/security_secret_bootstrap_test.go index f56810e9..288edf33 100644 --- a/backend/internal/repository/security_secret_bootstrap_test.go +++ b/backend/internal/repository/security_secret_bootstrap_test.go @@ -124,6 +124,7 @@ func TestEnsureBootstrapSecretsConfiguredSecretDuplicateIgnored(t *testing.T) { stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background()) require.NoError(t, err) require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", stored.Value) + require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret) } func TestGetOrCreateGeneratedSecuritySecretTrimmedExistingValue(t *testing.T) { @@ -215,15 +216,17 @@ func TestGetOrCreateGeneratedSecuritySecretGenerateError(t *testing.T) { func TestCreateSecuritySecretIfAbsent(t *testing.T) { client := newSecuritySecretTestClient(t) - err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "short") + _, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "short") require.Error(t, err) require.Contains(t, err.Error(), "at least 32 bytes") - err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "valid-jwt-secret-value-32bytes-long") + stored, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "valid-jwt-secret-value-32bytes-long") require.NoError(t, err) + require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored) - err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "another-valid-secret-value-32bytes") + stored, err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "another-valid-secret-value-32bytes") require.NoError(t, err) + require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored) count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ("abc")).Count(context.Background()) require.NoError(t, err) @@ -232,7 +235,7 @@ func TestCreateSecuritySecretIfAbsent(t *testing.T) { func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) { client := newSecuritySecretTestClient(t) - err := createSecuritySecretIfAbsent( + _, err := createSecuritySecretIfAbsent( context.Background(), client, strings.Repeat("k", 101), @@ -241,6 +244,68 @@ func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) { require.Error(t, err) } +func TestCreateSecuritySecretIfAbsentExecError(t *testing.T) { + client := newSecuritySecretTestClient(t) + require.NoError(t, client.Close()) + + _, err := createSecuritySecretIfAbsent(context.Background(), client, "closed-client-key", "valid-jwt-secret-value-32bytes-long") + require.Error(t, err) +} + +func TestQuerySecuritySecretWithRetrySuccess(t *testing.T) { + client := newSecuritySecretTestClient(t) + created, err := client.SecuritySecret.Create(). + SetKey("retry_success_key"). + SetValue("retry-success-jwt-secret-value-32!!"). + Save(context.Background()) + require.NoError(t, err) + + got, err := querySecuritySecretWithRetry(context.Background(), client, "retry_success_key") + require.NoError(t, err) + require.Equal(t, created.ID, got.ID) + require.Equal(t, "retry-success-jwt-secret-value-32!!", got.Value) +} + +func TestQuerySecuritySecretWithRetryExhausted(t *testing.T) { + client := newSecuritySecretTestClient(t) + + _, err := querySecuritySecretWithRetry(context.Background(), client, "retry_missing_key") + require.Error(t, err) + require.True(t, isSecretNotFoundError(err)) +} + +func TestQuerySecuritySecretWithRetryContextCanceled(t *testing.T) { + client := newSecuritySecretTestClient(t) + ctx, cancel := context.WithTimeout(context.Background(), securitySecretReadRetryWait/2) + defer cancel() + + _, err := querySecuritySecretWithRetry(ctx, client, "retry_ctx_cancel_key") + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestQuerySecuritySecretWithRetryNonNotFoundError(t *testing.T) { + client := newSecuritySecretTestClient(t) + require.NoError(t, client.Close()) + + _, err := querySecuritySecretWithRetry(context.Background(), client, "retry_closed_client_key") + require.Error(t, err) + require.False(t, isSecretNotFoundError(err)) +} + +func TestSecretNotFoundHelpers(t *testing.T) { + require.False(t, isSecretNotFoundError(nil)) + require.False(t, isSQLNoRowsError(nil)) + + require.True(t, isSQLNoRowsError(sql.ErrNoRows)) + require.True(t, isSQLNoRowsError(fmt.Errorf("wrapped: %w", sql.ErrNoRows))) + require.True(t, isSQLNoRowsError(errors.New("sql: no rows in result set"))) + + require.True(t, isSecretNotFoundError(sql.ErrNoRows)) + require.True(t, isSecretNotFoundError(errors.New("sql: no rows in result set"))) + require.False(t, isSecretNotFoundError(errors.New("some other error"))) +} + func TestGenerateHexSecretReadError(t *testing.T) { originalRead := readRandomBytes readRandomBytes = func([]byte) (int, error) {