feat(security): 启动时自动迁移并持久化JWT密钥
- 新增 security_secrets 表及 Ent schema 用于存储系统级密钥 - 启动阶段支持无 jwt.secret 配置并在数据库中自动生成持久化 - 在 Ent 初始化后补齐密钥并执行完整配置校验 - 增加并发与异常分支单元测试,覆盖密钥引导核心路径 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -5,6 +5,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -66,6 +67,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
|
||||
client := ent.NewClient(ent.Driver(drv))
|
||||
|
||||
// 启动阶段:从配置或数据库中确保系统密钥可用。
|
||||
if err := ensureBootstrapSecrets(migrationCtx, client, cfg); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 在密钥补齐后执行完整配置校验,避免空 jwt.secret 导致服务运行时失败。
|
||||
if err := cfg.Validate(); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, fmt.Errorf("validate config after secret bootstrap: %w", err)
|
||||
}
|
||||
|
||||
// SIMPLE 模式:启动时补齐各平台默认分组。
|
||||
// - anthropic/openai/gemini: 确保存在 <platform>-default
|
||||
// - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景)
|
||||
|
||||
@@ -48,6 +48,11 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
||||
require.True(t, settingsRegclass.Valid, "expected settings table to exist")
|
||||
|
||||
// security_secrets table should exist
|
||||
var securitySecretsRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.security_secrets')").Scan(&securitySecretsRegclass))
|
||||
require.True(t, securitySecretsRegclass.Valid, "expected security_secrets table to exist")
|
||||
|
||||
// user_allowed_groups table should exist
|
||||
var uagRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
|
||||
|
||||
108
backend/internal/repository/security_secret_bootstrap.go
Normal file
108
backend/internal/repository/security_secret_bootstrap.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
const securitySecretKeyJWT = "jwt_secret"
|
||||
|
||||
var readRandomBytes = rand.Read
|
||||
|
||||
func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config.Config) error {
|
||||
if client == nil {
|
||||
return fmt.Errorf("nil ent client")
|
||||
}
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("nil config")
|
||||
}
|
||||
|
||||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||||
if cfg.JWT.Secret != "" {
|
||||
if err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret); err != nil {
|
||||
return fmt.Errorf("persist jwt secret: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
secret, created, err := getOrCreateGeneratedSecuritySecret(ctx, client, securitySecretKeyJWT, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ensure jwt secret: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
|
||||
if created {
|
||||
log.Println("Warning: JWT secret auto-generated and persisted to database. Consider rotating to a managed secret for production.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client, key string, byteLength int) (string, bool, error) {
|
||||
existing, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
|
||||
if err == nil {
|
||||
value := strings.TrimSpace(existing.Value)
|
||||
if len([]byte(value)) < 32 {
|
||||
return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key)
|
||||
}
|
||||
return value, false, nil
|
||||
}
|
||||
if !ent.IsNotFound(err) {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
generated, err := generateHexSecret(byteLength)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
if err := client.SecuritySecret.Create().
|
||||
SetKey(key).
|
||||
SetValue(generated).
|
||||
OnConflictColumns(securitysecret.FieldKey).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
value := strings.TrimSpace(stored.Value)
|
||||
if len([]byte(value)) < 32 {
|
||||
return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key)
|
||||
}
|
||||
return value, value == generated, nil
|
||||
}
|
||||
|
||||
func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) error {
|
||||
value = strings.TrimSpace(value)
|
||||
if len([]byte(value)) < 32 {
|
||||
return fmt.Errorf("secret %q must be at least 32 bytes", key)
|
||||
}
|
||||
|
||||
_, err := client.SecuritySecret.Create().SetKey(key).SetValue(value).Save(ctx)
|
||||
if err == nil || ent.IsConstraintError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func generateHexSecret(byteLength int) (string, error) {
|
||||
if byteLength <= 0 {
|
||||
byteLength = 32
|
||||
}
|
||||
buf := make([]byte, byteLength)
|
||||
if _, err := readRandomBytes(buf); err != nil {
|
||||
return "", fmt.Errorf("generate random secret: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
272
backend/internal/repository/security_secret_bootstrap_test.go
Normal file
272
backend/internal/repository/security_secret_bootstrap_test.go
Normal file
@@ -0,0 +1,272 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func newSecuritySecretTestClient(t *testing.T) *dbent.Client {
|
||||
t.Helper()
|
||||
name := strings.ReplaceAll(t.Name(), "/", "_")
|
||||
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", name)
|
||||
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
return client
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsNilInputs(t *testing.T) {
|
||||
err := ensureBootstrapSecrets(context.Background(), nil, &config.Config{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "nil ent client")
|
||||
|
||||
client := newSecuritySecretTestClient(t)
|
||||
err = ensureBootstrapSecrets(context.Background(), client, nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "nil config")
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsGenerateAndPersistJWTSecret(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
cfg := &config.Config{}
|
||||
|
||||
err := ensureBootstrapSecrets(context.Background(), client, cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, cfg.JWT.Secret)
|
||||
require.GreaterOrEqual(t, len([]byte(cfg.JWT.Secret)), 32)
|
||||
|
||||
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cfg.JWT.Secret, stored.Value)
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsLoadExistingJWTSecret(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
_, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("existing-jwt-secret-32bytes-long!!!!").Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{}
|
||||
err = ensureBootstrapSecrets(context.Background(), client, cfg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret)
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsRejectInvalidStoredSecret(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
_, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("too-short").Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{}
|
||||
err = ensureBootstrapSecrets(context.Background(), client, cfg)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "at least 32 bytes")
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsPersistConfiguredJWTSecret(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{Secret: "configured-jwt-secret-32bytes-long!!"},
|
||||
}
|
||||
|
||||
err := ensureBootstrapSecrets(context.Background(), client, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "configured-jwt-secret-32bytes-long!!", stored.Value)
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsConfiguredSecretTooShort(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "short"}}
|
||||
|
||||
err := ensureBootstrapSecrets(context.Background(), client, cfg)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "at least 32 bytes")
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsConfiguredSecretDuplicateIgnored(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
_, err := client.SecuritySecret.Create().
|
||||
SetKey(securitySecretKeyJWT).
|
||||
SetValue("existing-jwt-secret-32bytes-long!!!!").
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "another-configured-jwt-secret-32!!!!"}}
|
||||
err = ensureBootstrapSecrets(context.Background(), client, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", stored.Value)
|
||||
}
|
||||
|
||||
func TestGetOrCreateGeneratedSecuritySecretTrimmedExistingValue(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
_, err := client.SecuritySecret.Create().
|
||||
SetKey("trimmed_key").
|
||||
SetValue(" existing-trimmed-secret-32bytes-long!! ").
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
value, created, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "trimmed_key", 32)
|
||||
require.NoError(t, err)
|
||||
require.False(t, created)
|
||||
require.Equal(t, "existing-trimmed-secret-32bytes-long!!", value)
|
||||
}
|
||||
|
||||
func TestGetOrCreateGeneratedSecuritySecretQueryError(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
require.NoError(t, client.Close())
|
||||
|
||||
_, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "closed_client_key", 32)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGetOrCreateGeneratedSecuritySecretCreateValidationError(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
tooLongKey := strings.Repeat("k", 101)
|
||||
|
||||
_, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, tooLongKey, 32)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGetOrCreateGeneratedSecuritySecretConcurrentCreation(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
const goroutines = 8
|
||||
key := "concurrent_bootstrap_key"
|
||||
|
||||
values := make([]string, goroutines)
|
||||
createdFlags := make([]bool, goroutines)
|
||||
errs := make([]error, goroutines)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
values[idx], createdFlags[idx], errs[idx] = getOrCreateGeneratedSecuritySecret(context.Background(), client, key, 32)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i := range errs {
|
||||
require.NoError(t, errs[i])
|
||||
require.NotEmpty(t, values[i])
|
||||
}
|
||||
for i := 1; i < len(values); i++ {
|
||||
require.Equal(t, values[0], values[i])
|
||||
}
|
||||
|
||||
createdCount := 0
|
||||
for _, created := range createdFlags {
|
||||
if created {
|
||||
createdCount++
|
||||
}
|
||||
}
|
||||
require.GreaterOrEqual(t, createdCount, 1)
|
||||
require.LessOrEqual(t, createdCount, 1)
|
||||
|
||||
count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Count(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestGetOrCreateGeneratedSecuritySecretGenerateError(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
originalRead := readRandomBytes
|
||||
readRandomBytes = func([]byte) (int, error) {
|
||||
return 0, errors.New("boom")
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
readRandomBytes = originalRead
|
||||
})
|
||||
|
||||
_, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "gen_error_key", 32)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "boom")
|
||||
}
|
||||
|
||||
func TestCreateSecuritySecretIfAbsent(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
|
||||
err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "short")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "at least 32 bytes")
|
||||
|
||||
err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "valid-jwt-secret-value-32bytes-long")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "another-valid-secret-value-32bytes")
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ("abc")).Count(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
err := createSecuritySecretIfAbsent(
|
||||
context.Background(),
|
||||
client,
|
||||
strings.Repeat("k", 101),
|
||||
"valid-jwt-secret-value-32bytes-long",
|
||||
)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGenerateHexSecretReadError(t *testing.T) {
|
||||
originalRead := readRandomBytes
|
||||
readRandomBytes = func([]byte) (int, error) {
|
||||
return 0, errors.New("read random failed")
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
readRandomBytes = originalRead
|
||||
})
|
||||
|
||||
_, err := generateHexSecret(32)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "read random failed")
|
||||
}
|
||||
|
||||
func TestGenerateHexSecretLengths(t *testing.T) {
|
||||
v1, err := generateHexSecret(0)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v1, 64)
|
||||
_, err = hex.DecodeString(v1)
|
||||
require.NoError(t, err)
|
||||
|
||||
v2, err := generateHexSecret(16)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v2, 32)
|
||||
_, err = hex.DecodeString(v2)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEqual(t, v1, v2)
|
||||
}
|
||||
Reference in New Issue
Block a user