Merge branch 'dev' into release
This commit is contained in:
@@ -3,17 +3,24 @@ package repository
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"database/sql"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/ent"
|
"github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
|
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
const securitySecretKeyJWT = "jwt_secret"
|
const (
|
||||||
|
securitySecretKeyJWT = "jwt_secret"
|
||||||
|
securitySecretReadRetryMax = 5
|
||||||
|
securitySecretReadRetryWait = 10 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
var readRandomBytes = rand.Read
|
var readRandomBytes = rand.Read
|
||||||
|
|
||||||
@@ -27,9 +34,14 @@ func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config
|
|||||||
|
|
||||||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||||||
if cfg.JWT.Secret != "" {
|
if cfg.JWT.Secret != "" {
|
||||||
if err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret); err != nil {
|
storedSecret, err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret)
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("persist jwt secret: %w", err)
|
return fmt.Errorf("persist jwt secret: %w", err)
|
||||||
}
|
}
|
||||||
|
if storedSecret != cfg.JWT.Secret {
|
||||||
|
log.Println("Warning: configured JWT secret mismatches persisted value; using persisted secret for cross-instance consistency.")
|
||||||
|
}
|
||||||
|
cfg.JWT.Secret = storedSecret
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,10 +81,12 @@ func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client,
|
|||||||
OnConflictColumns(securitysecret.FieldKey).
|
OnConflictColumns(securitysecret.FieldKey).
|
||||||
DoNothing().
|
DoNothing().
|
||||||
Exec(ctx); err != nil {
|
Exec(ctx); err != nil {
|
||||||
return "", false, err
|
if !isSQLNoRowsError(err) {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
|
stored, err := querySecuritySecretWithRetry(ctx, client, key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", false, err
|
return "", false, err
|
||||||
}
|
}
|
||||||
@@ -83,17 +97,72 @@ func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client,
|
|||||||
return value, value == generated, nil
|
return value, value == generated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) error {
|
func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) (string, error) {
|
||||||
value = strings.TrimSpace(value)
|
value = strings.TrimSpace(value)
|
||||||
if len([]byte(value)) < 32 {
|
if len([]byte(value)) < 32 {
|
||||||
return fmt.Errorf("secret %q must be at least 32 bytes", key)
|
return "", fmt.Errorf("secret %q must be at least 32 bytes", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := client.SecuritySecret.Create().SetKey(key).SetValue(value).Save(ctx)
|
if err := client.SecuritySecret.Create().
|
||||||
if err == nil || ent.IsConstraintError(err) {
|
SetKey(key).
|
||||||
return nil
|
SetValue(value).
|
||||||
|
OnConflictColumns(securitysecret.FieldKey).
|
||||||
|
DoNothing().
|
||||||
|
Exec(ctx); err != nil {
|
||||||
|
if !isSQLNoRowsError(err) {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
|
stored, err := querySecuritySecretWithRetry(ctx, client, key)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
storedValue := strings.TrimSpace(stored.Value)
|
||||||
|
if len([]byte(storedValue)) < 32 {
|
||||||
|
return "", fmt.Errorf("stored secret %q must be at least 32 bytes", key)
|
||||||
|
}
|
||||||
|
return storedValue, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func querySecuritySecretWithRetry(ctx context.Context, client *ent.Client, key string) (*ent.SecuritySecret, error) {
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 0; attempt <= securitySecretReadRetryMax; attempt++ {
|
||||||
|
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
|
||||||
|
if err == nil {
|
||||||
|
return stored, nil
|
||||||
|
}
|
||||||
|
if !isSecretNotFoundError(err) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
if attempt == securitySecretReadRetryMax {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
timer := time.NewTimer(securitySecretReadRetryWait)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
timer.Stop()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSecretNotFoundError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return ent.IsNotFound(err) || isSQLNoRowsError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSQLNoRowsError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set")
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateHexSecret(byteLength int) (string, error) {
|
func generateHexSecret(byteLength int) (string, error) {
|
||||||
|
|||||||
@@ -124,6 +124,7 @@ func TestEnsureBootstrapSecretsConfiguredSecretDuplicateIgnored(t *testing.T) {
|
|||||||
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
|
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", stored.Value)
|
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", stored.Value)
|
||||||
|
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetOrCreateGeneratedSecuritySecretTrimmedExistingValue(t *testing.T) {
|
func TestGetOrCreateGeneratedSecuritySecretTrimmedExistingValue(t *testing.T) {
|
||||||
@@ -215,15 +216,17 @@ func TestGetOrCreateGeneratedSecuritySecretGenerateError(t *testing.T) {
|
|||||||
func TestCreateSecuritySecretIfAbsent(t *testing.T) {
|
func TestCreateSecuritySecretIfAbsent(t *testing.T) {
|
||||||
client := newSecuritySecretTestClient(t)
|
client := newSecuritySecretTestClient(t)
|
||||||
|
|
||||||
err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "short")
|
_, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "short")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "at least 32 bytes")
|
require.Contains(t, err.Error(), "at least 32 bytes")
|
||||||
|
|
||||||
err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "valid-jwt-secret-value-32bytes-long")
|
stored, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "valid-jwt-secret-value-32bytes-long")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored)
|
||||||
|
|
||||||
err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "another-valid-secret-value-32bytes")
|
stored, err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "another-valid-secret-value-32bytes")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored)
|
||||||
|
|
||||||
count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ("abc")).Count(context.Background())
|
count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ("abc")).Count(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -232,7 +235,7 @@ func TestCreateSecuritySecretIfAbsent(t *testing.T) {
|
|||||||
|
|
||||||
func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) {
|
func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) {
|
||||||
client := newSecuritySecretTestClient(t)
|
client := newSecuritySecretTestClient(t)
|
||||||
err := createSecuritySecretIfAbsent(
|
_, err := createSecuritySecretIfAbsent(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
client,
|
client,
|
||||||
strings.Repeat("k", 101),
|
strings.Repeat("k", 101),
|
||||||
@@ -241,6 +244,68 @@ func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) {
|
|||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateSecuritySecretIfAbsentExecError(t *testing.T) {
|
||||||
|
client := newSecuritySecretTestClient(t)
|
||||||
|
require.NoError(t, client.Close())
|
||||||
|
|
||||||
|
_, err := createSecuritySecretIfAbsent(context.Background(), client, "closed-client-key", "valid-jwt-secret-value-32bytes-long")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuerySecuritySecretWithRetrySuccess(t *testing.T) {
|
||||||
|
client := newSecuritySecretTestClient(t)
|
||||||
|
created, err := client.SecuritySecret.Create().
|
||||||
|
SetKey("retry_success_key").
|
||||||
|
SetValue("retry-success-jwt-secret-value-32!!").
|
||||||
|
Save(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
got, err := querySecuritySecretWithRetry(context.Background(), client, "retry_success_key")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, created.ID, got.ID)
|
||||||
|
require.Equal(t, "retry-success-jwt-secret-value-32!!", got.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuerySecuritySecretWithRetryExhausted(t *testing.T) {
|
||||||
|
client := newSecuritySecretTestClient(t)
|
||||||
|
|
||||||
|
_, err := querySecuritySecretWithRetry(context.Background(), client, "retry_missing_key")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.True(t, isSecretNotFoundError(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuerySecuritySecretWithRetryContextCanceled(t *testing.T) {
|
||||||
|
client := newSecuritySecretTestClient(t)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), securitySecretReadRetryWait/2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := querySecuritySecretWithRetry(ctx, client, "retry_ctx_cancel_key")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuerySecuritySecretWithRetryNonNotFoundError(t *testing.T) {
|
||||||
|
client := newSecuritySecretTestClient(t)
|
||||||
|
require.NoError(t, client.Close())
|
||||||
|
|
||||||
|
_, err := querySecuritySecretWithRetry(context.Background(), client, "retry_closed_client_key")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.False(t, isSecretNotFoundError(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSecretNotFoundHelpers(t *testing.T) {
|
||||||
|
require.False(t, isSecretNotFoundError(nil))
|
||||||
|
require.False(t, isSQLNoRowsError(nil))
|
||||||
|
|
||||||
|
require.True(t, isSQLNoRowsError(sql.ErrNoRows))
|
||||||
|
require.True(t, isSQLNoRowsError(fmt.Errorf("wrapped: %w", sql.ErrNoRows)))
|
||||||
|
require.True(t, isSQLNoRowsError(errors.New("sql: no rows in result set")))
|
||||||
|
|
||||||
|
require.True(t, isSecretNotFoundError(sql.ErrNoRows))
|
||||||
|
require.True(t, isSecretNotFoundError(errors.New("sql: no rows in result set")))
|
||||||
|
require.False(t, isSecretNotFoundError(errors.New("some other error")))
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateHexSecretReadError(t *testing.T) {
|
func TestGenerateHexSecretReadError(t *testing.T) {
|
||||||
originalRead := readRandomBytes
|
originalRead := readRandomBytes
|
||||||
readRandomBytes = func([]byte) (int, error) {
|
readRandomBytes = func([]byte) (int, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user