Files
sub2api/backend/internal/repository/security_secret_bootstrap.go

178 lines
4.5 KiB
Go

package repository
import (
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// Key name and read-retry tuning for secrets stored through the SecuritySecret ent client.
const (
// securitySecretKeyJWT is the key of the row that holds the JWT secret.
securitySecretKeyJWT = "jwt_secret"
// securitySecretReadRetryMax is the number of additional reads
// querySecuritySecretWithRetry performs after a first not-found result,
// so a lookup issues at most securitySecretReadRetryMax+1 queries.
securitySecretReadRetryMax = 5
// securitySecretReadRetryWait is the pause between two consecutive reads
// in querySecuritySecretWithRetry.
securitySecretReadRetryWait = 10 * time.Millisecond
)
// readRandomBytes is the randomness source used by generateHexSecret. It is
// bound to crypto/rand.Read through a package-level variable rather than called
// directly — presumably so tests can swap in a deterministic or failing reader
// (NOTE(review): confirm against the test suite).
var readRandomBytes = rand.Read
// ensureBootstrapSecrets makes sure cfg.JWT.Secret holds the secret persisted
// under securitySecretKeyJWT, writing it to the database first when needed.
//
// The configured secret is whitespace-trimmed in place. Then:
//   - if it is empty, a stored secret is loaded or a new random one is
//     generated and persisted, and cfg.JWT.Secret is set to the result;
//   - otherwise the configured secret is inserted only if no row exists yet,
//     and cfg.JWT.Secret is overwritten with whatever value is stored, which
//     may differ from the configured one (a warning is logged in that case).
//
// It returns an error when client or cfg is nil, or when the underlying
// persistence helper fails (wrapped with a short context prefix).
func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config.Config) error {
	switch {
	case client == nil:
		return fmt.Errorf("nil ent client")
	case cfg == nil:
		return fmt.Errorf("nil config")
	}

	configured := strings.TrimSpace(cfg.JWT.Secret)
	cfg.JWT.Secret = configured

	// Nothing configured: fall back to the stored secret, generating one if absent.
	if configured == "" {
		generated, isNew, err := getOrCreateGeneratedSecuritySecret(ctx, client, securitySecretKeyJWT, 32)
		if err != nil {
			return fmt.Errorf("ensure jwt secret: %w", err)
		}
		cfg.JWT.Secret = generated
		if isNew {
			log.Println("Warning: JWT secret auto-generated and persisted to database. Consider rotating to a managed secret for production.")
		}
		return nil
	}

	// A secret is configured: persist it unless a row already exists, then
	// adopt the stored value so the persisted secret always wins.
	persisted, err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, configured)
	if err != nil {
		return fmt.Errorf("persist jwt secret: %w", err)
	}
	if persisted != configured {
		log.Println("Warning: configured JWT secret mismatches persisted value; using persisted secret for cross-instance consistency.")
	}
	cfg.JWT.Secret = persisted
	return nil
}
// getOrCreateGeneratedSecuritySecret returns the secret stored under key,
// generating and persisting a random hex-encoded one when no row exists yet.
//
// byteLength is the number of random bytes to generate (generateHexSecret
// substitutes 32 for non-positive values); the resulting hex string is twice
// as long. The returned bool is true only when the value read back from the
// database equals the value generated by this call, i.e. this call created it.
//
// An error is returned when the lookup or insert fails, when random generation
// fails, or when the existing, generated, or re-read secret is shorter than
// 32 bytes after trimming.
func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client, key string, byteLength int) (string, bool, error) {
	existing, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
	if err == nil {
		value := strings.TrimSpace(existing.Value)
		if len(value) < 32 {
			return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key)
		}
		return value, false, nil
	}
	if !ent.IsNotFound(err) {
		return "", false, err
	}

	generated, err := generateHexSecret(byteLength)
	if err != nil {
		return "", false, err
	}
	// Reject a too-short secret BEFORE writing it. Otherwise a small byteLength
	// (1..15 bytes -> fewer than 32 hex characters) would be persisted and then
	// fail the length check on this and every later call.
	if len(generated) < 32 {
		return "", false, fmt.Errorf("generated secret %q must be at least 32 bytes (byteLength %d is too small)", key, byteLength)
	}

	// Insert unless a row with this key already exists. A "no rows" error from
	// Exec is tolerated — presumably that is how a skipped insert is reported
	// by the driver; the read below picks up whichever value is stored.
	if err := client.SecuritySecret.Create().
		SetKey(key).
		SetValue(generated).
		OnConflictColumns(securitysecret.FieldKey).
		DoNothing().
		Exec(ctx); err != nil && !isSQLNoRowsError(err) {
		return "", false, err
	}

	stored, err := querySecuritySecretWithRetry(ctx, client, key)
	if err != nil {
		return "", false, err
	}
	value := strings.TrimSpace(stored.Value)
	if len(value) < 32 {
		return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key)
	}
	return value, value == generated, nil
}
// createSecuritySecretIfAbsent stores value under key unless a row with that
// key already exists, then returns the value that is actually stored.
//
// value is whitespace-trimmed and must be at least 32 bytes long, as must the
// trimmed value read back. The returned string is the stored secret, which can
// differ from value when the row pre-existed. Insert and lookup errors are
// returned unwrapped; a "no rows" error from the insert is ignored.
func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) (string, error) {
	candidate := strings.TrimSpace(value)
	if len(candidate) < 32 {
		return "", fmt.Errorf("secret %q must be at least 32 bytes", key)
	}

	insertErr := client.SecuritySecret.Create().
		SetKey(key).
		SetValue(candidate).
		OnConflictColumns(securitysecret.FieldKey).
		DoNothing().
		Exec(ctx)
	// A "no rows" result is not treated as a failure; the read below decides
	// which value is in effect.
	if insertErr != nil && !isSQLNoRowsError(insertErr) {
		return "", insertErr
	}

	row, err := querySecuritySecretWithRetry(ctx, client, key)
	if err != nil {
		return "", err
	}

	persisted := strings.TrimSpace(row.Value)
	if len(persisted) < 32 {
		return "", fmt.Errorf("stored secret %q must be at least 32 bytes", key)
	}
	return persisted, nil
}
// querySecuritySecretWithRetry looks up the secret row for key, retrying while
// the row is reported as not found.
//
// It performs one initial read plus up to securitySecretReadRetryMax retries,
// waiting securitySecretReadRetryWait between reads. Any error other than
// not-found is returned immediately. If ctx is done while waiting, ctx.Err()
// is returned. When every read reports not-found, the last such error is
// returned.
func querySecuritySecretWithRetry(ctx context.Context, client *ent.Client, key string) (*ent.SecuritySecret, error) {
	retries := 0
	for {
		row, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
		switch {
		case err == nil:
			return row, nil
		case !isSecretNotFoundError(err):
			return nil, err
		case retries >= securitySecretReadRetryMax:
			// Retry budget exhausted: surface the final not-found error.
			return nil, err
		}
		retries++

		// Pause before the next read, but give up at once if ctx is cancelled.
		pause := time.NewTimer(securitySecretReadRetryWait)
		select {
		case <-pause.C:
		case <-ctx.Done():
			pause.Stop()
			return nil, ctx.Err()
		}
	}
}
// isSecretNotFoundError reports whether err means the secret row is missing:
// either ent classifies it as not-found or it is a SQL "no rows" error.
// A nil err yields false.
func isSecretNotFoundError(err error) bool {
	switch {
	case err == nil:
		return false
	case ent.IsNotFound(err):
		return true
	default:
		return isSQLNoRowsError(err)
	}
}
// isSQLNoRowsError reports whether err is, or wraps, sql.ErrNoRows, or whether
// its message contains the text "no rows in result set". The message check
// catches errors that carry that text without wrapping the sentinel.
// A nil err yields false.
func isSQLNoRowsError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, sql.ErrNoRows):
		return true
	default:
		return strings.Contains(err.Error(), "no rows in result set")
	}
}
// generateHexSecret fills a buffer of byteLength bytes from readRandomBytes and
// returns it hex-encoded, so the result is 2*byteLength characters long.
// A non-positive byteLength is replaced by 32. A failure of the random source
// is returned wrapped with the prefix "generate random secret".
func generateHexSecret(byteLength int) (string, error) {
	size := byteLength
	if size < 1 {
		size = 32
	}

	raw := make([]byte, size)
	_, err := readRandomBytes(raw)
	if err != nil {
		return "", fmt.Errorf("generate random secret: %w", err)
	}
	return hex.EncodeToString(raw), nil
}