Merge branch 'dev' into release

This commit is contained in:
yangjianbo
2026-02-12 12:12:40 +08:00
2 changed files with 148 additions and 14 deletions

View File

@@ -3,17 +3,24 @@ package repository
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"database/sql"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"log" "log"
"strings" "strings"
"time"
"github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
) )
const securitySecretKeyJWT = "jwt_secret" const (
securitySecretKeyJWT = "jwt_secret"
securitySecretReadRetryMax = 5
securitySecretReadRetryWait = 10 * time.Millisecond
)
var readRandomBytes = rand.Read var readRandomBytes = rand.Read
@@ -27,9 +34,14 @@ func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
if cfg.JWT.Secret != "" { if cfg.JWT.Secret != "" {
if err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret); err != nil { storedSecret, err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret)
if err != nil {
return fmt.Errorf("persist jwt secret: %w", err) return fmt.Errorf("persist jwt secret: %w", err)
} }
if storedSecret != cfg.JWT.Secret {
log.Println("Warning: configured JWT secret mismatches persisted value; using persisted secret for cross-instance consistency.")
}
cfg.JWT.Secret = storedSecret
return nil return nil
} }
@@ -69,10 +81,12 @@ func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client,
OnConflictColumns(securitysecret.FieldKey). OnConflictColumns(securitysecret.FieldKey).
DoNothing(). DoNothing().
Exec(ctx); err != nil { Exec(ctx); err != nil {
return "", false, err if !isSQLNoRowsError(err) {
return "", false, err
}
} }
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx) stored, err := querySecuritySecretWithRetry(ctx, client, key)
if err != nil { if err != nil {
return "", false, err return "", false, err
} }
@@ -83,17 +97,72 @@ func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client,
return value, value == generated, nil return value, value == generated, nil
} }
func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) error { func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) (string, error) {
value = strings.TrimSpace(value) value = strings.TrimSpace(value)
if len([]byte(value)) < 32 { if len([]byte(value)) < 32 {
return fmt.Errorf("secret %q must be at least 32 bytes", key) return "", fmt.Errorf("secret %q must be at least 32 bytes", key)
} }
_, err := client.SecuritySecret.Create().SetKey(key).SetValue(value).Save(ctx) if err := client.SecuritySecret.Create().
if err == nil || ent.IsConstraintError(err) { SetKey(key).
return nil SetValue(value).
OnConflictColumns(securitysecret.FieldKey).
DoNothing().
Exec(ctx); err != nil {
if !isSQLNoRowsError(err) {
return "", err
}
} }
return err
stored, err := querySecuritySecretWithRetry(ctx, client, key)
if err != nil {
return "", err
}
storedValue := strings.TrimSpace(stored.Value)
if len([]byte(storedValue)) < 32 {
return "", fmt.Errorf("stored secret %q must be at least 32 bytes", key)
}
return storedValue, nil
}
func querySecuritySecretWithRetry(ctx context.Context, client *ent.Client, key string) (*ent.SecuritySecret, error) {
var lastErr error
for attempt := 0; attempt <= securitySecretReadRetryMax; attempt++ {
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
if err == nil {
return stored, nil
}
if !isSecretNotFoundError(err) {
return nil, err
}
lastErr = err
if attempt == securitySecretReadRetryMax {
break
}
timer := time.NewTimer(securitySecretReadRetryWait)
select {
case <-ctx.Done():
timer.Stop()
return nil, ctx.Err()
case <-timer.C:
}
}
return nil, lastErr
}
func isSecretNotFoundError(err error) bool {
if err == nil {
return false
}
return ent.IsNotFound(err) || isSQLNoRowsError(err)
}
func isSQLNoRowsError(err error) bool {
if err == nil {
return false
}
return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set")
} }
func generateHexSecret(byteLength int) (string, error) { func generateHexSecret(byteLength int) (string, error) {

View File

@@ -124,6 +124,7 @@ func TestEnsureBootstrapSecretsConfiguredSecretDuplicateIgnored(t *testing.T) {
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background()) stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", stored.Value) require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", stored.Value)
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret)
} }
func TestGetOrCreateGeneratedSecuritySecretTrimmedExistingValue(t *testing.T) { func TestGetOrCreateGeneratedSecuritySecretTrimmedExistingValue(t *testing.T) {
@@ -215,15 +216,17 @@ func TestGetOrCreateGeneratedSecuritySecretGenerateError(t *testing.T) {
func TestCreateSecuritySecretIfAbsent(t *testing.T) { func TestCreateSecuritySecretIfAbsent(t *testing.T) {
client := newSecuritySecretTestClient(t) client := newSecuritySecretTestClient(t)
err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "short") _, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "short")
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "at least 32 bytes") require.Contains(t, err.Error(), "at least 32 bytes")
err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "valid-jwt-secret-value-32bytes-long") stored, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "valid-jwt-secret-value-32bytes-long")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored)
err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "another-valid-secret-value-32bytes") stored, err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "another-valid-secret-value-32bytes")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored)
count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ("abc")).Count(context.Background()) count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ("abc")).Count(context.Background())
require.NoError(t, err) require.NoError(t, err)
@@ -232,7 +235,7 @@ func TestCreateSecuritySecretIfAbsent(t *testing.T) {
func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) { func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) {
client := newSecuritySecretTestClient(t) client := newSecuritySecretTestClient(t)
err := createSecuritySecretIfAbsent( _, err := createSecuritySecretIfAbsent(
context.Background(), context.Background(),
client, client,
strings.Repeat("k", 101), strings.Repeat("k", 101),
@@ -241,6 +244,68 @@ func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) {
require.Error(t, err) require.Error(t, err)
} }
func TestCreateSecuritySecretIfAbsentExecError(t *testing.T) {
client := newSecuritySecretTestClient(t)
require.NoError(t, client.Close())
_, err := createSecuritySecretIfAbsent(context.Background(), client, "closed-client-key", "valid-jwt-secret-value-32bytes-long")
require.Error(t, err)
}
func TestQuerySecuritySecretWithRetrySuccess(t *testing.T) {
client := newSecuritySecretTestClient(t)
created, err := client.SecuritySecret.Create().
SetKey("retry_success_key").
SetValue("retry-success-jwt-secret-value-32!!").
Save(context.Background())
require.NoError(t, err)
got, err := querySecuritySecretWithRetry(context.Background(), client, "retry_success_key")
require.NoError(t, err)
require.Equal(t, created.ID, got.ID)
require.Equal(t, "retry-success-jwt-secret-value-32!!", got.Value)
}
func TestQuerySecuritySecretWithRetryExhausted(t *testing.T) {
client := newSecuritySecretTestClient(t)
_, err := querySecuritySecretWithRetry(context.Background(), client, "retry_missing_key")
require.Error(t, err)
require.True(t, isSecretNotFoundError(err))
}
func TestQuerySecuritySecretWithRetryContextCanceled(t *testing.T) {
client := newSecuritySecretTestClient(t)
ctx, cancel := context.WithTimeout(context.Background(), securitySecretReadRetryWait/2)
defer cancel()
_, err := querySecuritySecretWithRetry(ctx, client, "retry_ctx_cancel_key")
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
}
func TestQuerySecuritySecretWithRetryNonNotFoundError(t *testing.T) {
client := newSecuritySecretTestClient(t)
require.NoError(t, client.Close())
_, err := querySecuritySecretWithRetry(context.Background(), client, "retry_closed_client_key")
require.Error(t, err)
require.False(t, isSecretNotFoundError(err))
}
func TestSecretNotFoundHelpers(t *testing.T) {
require.False(t, isSecretNotFoundError(nil))
require.False(t, isSQLNoRowsError(nil))
require.True(t, isSQLNoRowsError(sql.ErrNoRows))
require.True(t, isSQLNoRowsError(fmt.Errorf("wrapped: %w", sql.ErrNoRows)))
require.True(t, isSQLNoRowsError(errors.New("sql: no rows in result set")))
require.True(t, isSecretNotFoundError(sql.ErrNoRows))
require.True(t, isSecretNotFoundError(errors.New("sql: no rows in result set")))
require.False(t, isSecretNotFoundError(errors.New("some other error")))
}
func TestGenerateHexSecretReadError(t *testing.T) { func TestGenerateHexSecretReadError(t *testing.T) {
originalRead := readRandomBytes originalRead := readRandomBytes
readRandomBytes = func([]byte) (int, error) { readRandomBytes = func([]byte) (int, error) {