fix(auth): harden oauth identity upgrade paths
This commit is contained in:
@@ -14,10 +14,14 @@ import (
|
||||
|
||||
func normalizeOAuthSignupSource(signupSource string) string {
|
||||
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
|
||||
if signupSource == "" {
|
||||
switch signupSource {
|
||||
case "", "email":
|
||||
return "email"
|
||||
case "linuxdo", "wechat", "oidc":
|
||||
return signupSource
|
||||
default:
|
||||
return "email"
|
||||
}
|
||||
return signupSource
|
||||
}
|
||||
|
||||
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
|
||||
@@ -136,10 +140,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
|
||||
return nil, nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
|
||||
if signupSource == "" {
|
||||
signupSource = "email"
|
||||
}
|
||||
signupSource = normalizeOAuthSignupSource(signupSource)
|
||||
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
|
||||
|
||||
user := &User{
|
||||
@@ -149,6 +150,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
Status: StatusActive,
|
||||
SignupSource: signupSource,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
|
||||
@@ -191,6 +191,80 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai
|
||||
require.Empty(t, redeemRepo.updateCalls)
|
||||
}
|
||||
|
||||
func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *testing.T) {
|
||||
userRepo := &userRepoStub{nextID: 42}
|
||||
emailCache := &emailCacheStub{
|
||||
data: &VerificationCodeData{
|
||||
Code: "246810",
|
||||
Attempts: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
||||
},
|
||||
}
|
||||
authService := newOAuthEmailFlowAuthService(
|
||||
userRepo,
|
||||
&redeemCodeRepoStub{},
|
||||
&refreshTokenCacheStub{},
|
||||
map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
},
|
||||
emailCache,
|
||||
)
|
||||
|
||||
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
||||
context.Background(),
|
||||
"fresh@example.com",
|
||||
"secret-123",
|
||||
"246810",
|
||||
"",
|
||||
" OIDC ",
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokenPair)
|
||||
require.NotNil(t, user)
|
||||
require.Len(t, userRepo.created, 1)
|
||||
require.Equal(t, "oidc", userRepo.created[0].SignupSource)
|
||||
}
|
||||
|
||||
func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
|
||||
userRepo := &userRepoStub{nextID: 43}
|
||||
emailCache := &emailCacheStub{
|
||||
data: &VerificationCodeData{
|
||||
Code: "246810",
|
||||
Attempts: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
|
||||
},
|
||||
}
|
||||
authService := newOAuthEmailFlowAuthService(
|
||||
userRepo,
|
||||
&redeemCodeRepoStub{},
|
||||
&refreshTokenCacheStub{},
|
||||
map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
},
|
||||
emailCache,
|
||||
)
|
||||
|
||||
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
|
||||
context.Background(),
|
||||
"fallback@example.com",
|
||||
"secret-123",
|
||||
"246810",
|
||||
"",
|
||||
"github",
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokenPair)
|
||||
require.NotNil(t, user)
|
||||
require.Len(t, userRepo.created, 1)
|
||||
require.Equal(t, "email", userRepo.created[0].SignupSource)
|
||||
}
|
||||
|
||||
func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
|
||||
userRepo := &userRepoStub{}
|
||||
redeemRepo := &redeemCodeRepoStub{
|
||||
|
||||
@@ -5,10 +5,15 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||
@@ -75,6 +80,122 @@ type AuthPendingIdentityService struct {
|
||||
entClient *dbent.Client
|
||||
}
|
||||
|
||||
var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry()
|
||||
|
||||
type authPendingIdentityScopedKeyLockRegistry struct {
|
||||
mu sync.Mutex
|
||||
locks map[string]*authPendingIdentityScopedKeyLockEntry
|
||||
}
|
||||
|
||||
type authPendingIdentityScopedKeyLockEntry struct {
|
||||
mu sync.Mutex
|
||||
refs int
|
||||
}
|
||||
|
||||
func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry {
|
||||
return &authPendingIdentityScopedKeyLockRegistry{
|
||||
locks: make(map[string]*authPendingIdentityScopedKeyLockEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() {
|
||||
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
|
||||
if len(normalized) == 0 {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized))
|
||||
r.mu.Lock()
|
||||
for _, key := range normalized {
|
||||
entry := r.locks[key]
|
||||
if entry == nil {
|
||||
entry = &authPendingIdentityScopedKeyLockEntry{}
|
||||
r.locks[key] = entry
|
||||
}
|
||||
entry.refs++
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
for _, entry := range entries {
|
||||
entry.mu.Lock()
|
||||
}
|
||||
|
||||
return func() {
|
||||
for i := len(entries) - 1; i >= 0; i-- {
|
||||
entries[i].mu.Unlock()
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for idx, key := range normalized {
|
||||
entry := entries[idx]
|
||||
entry.refs--
|
||||
if entry.refs == 0 {
|
||||
delete(r.locks, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeAuthPendingIdentityLockKeys(keys ...string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
deduped := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
trimmed := strings.TrimSpace(key)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
deduped[trimmed] = struct{}{}
|
||||
}
|
||||
if len(deduped) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := make([]string, 0, len(deduped))
|
||||
for key := range deduped {
|
||||
normalized = append(normalized, key)
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
return normalized
|
||||
}
|
||||
|
||||
func authPendingIdentityAdvisoryLockHash(key string) int64 {
|
||||
hasher := fnv.New64a()
|
||||
_, _ = hasher.Write([]byte(key))
|
||||
return int64(hasher.Sum64())
|
||||
}
|
||||
|
||||
func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) {
|
||||
release := authPendingIdentityScopedKeyLocks.lock(keys...)
|
||||
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
|
||||
if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres {
|
||||
return release, nil
|
||||
}
|
||||
|
||||
for _, key := range normalized {
|
||||
var rows entsql.Rows
|
||||
if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil {
|
||||
release()
|
||||
return nil, err
|
||||
}
|
||||
_ = rows.Close()
|
||||
}
|
||||
|
||||
return release, nil
|
||||
}
|
||||
|
||||
func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
|
||||
keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)}
|
||||
if identityID != nil && *identityID > 0 {
|
||||
keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID))
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
|
||||
return &AuthPendingIdentityService{entClient: entClient}
|
||||
}
|
||||
@@ -324,8 +445,29 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
|
||||
return nil, fmt.Errorf("pending auth ent client is not configured")
|
||||
}
|
||||
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := s.entClient
|
||||
txCtx := ctx
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
client = tx.Client()
|
||||
txCtx = dbent.NewTxContext(ctx, tx)
|
||||
} else if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
|
||||
client = existingTx.Client()
|
||||
}
|
||||
|
||||
releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer releaseLocks()
|
||||
|
||||
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||
if _, err := s.entClient.IdentityAdoptionDecision.Update().
|
||||
if _, err := client.IdentityAdoptionDecision.Update().
|
||||
Where(
|
||||
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
|
||||
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
|
||||
@@ -337,36 +479,40 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
|
||||
}),
|
||||
).
|
||||
ClearIdentityID().
|
||||
Save(ctx); err != nil {
|
||||
Save(txCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
existing, err := s.entClient.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
if existing == nil {
|
||||
create := s.entClient.IdentityAdoptionDecision.Create().
|
||||
SetPendingAuthSessionID(input.PendingAuthSessionID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar).
|
||||
SetDecidedAt(time.Now().UTC())
|
||||
if input.IdentityID != nil {
|
||||
create = create.SetIdentityID(*input.IdentityID)
|
||||
}
|
||||
return create.Save(ctx)
|
||||
create := client.IdentityAdoptionDecision.Create().
|
||||
SetPendingAuthSessionID(input.PendingAuthSessionID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar).
|
||||
SetDecidedAt(time.Now().UTC())
|
||||
if input.IdentityID != nil && *input.IdentityID > 0 {
|
||||
create = create.SetIdentityID(*input.IdentityID)
|
||||
}
|
||||
|
||||
update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar)
|
||||
if input.IdentityID != nil {
|
||||
update = update.SetIdentityID(*input.IdentityID)
|
||||
decisionID, err := create.
|
||||
OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
|
||||
UpdateNewValues().
|
||||
ID(txCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return update.Save(ctx)
|
||||
|
||||
decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return decision, nil
|
||||
}
|
||||
|
||||
func copyPendingMap(in map[string]any) map[string]any {
|
||||
|
||||
@@ -5,6 +5,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -259,6 +260,107 @@ func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIden
|
||||
require.Nil(t, reloadedFirst.IdentityID)
|
||||
}
|
||||
|
||||
func TestAuthPendingIdentityService_UpsertAdoptionDecision_IsIdempotentUnderConcurrency(t *testing.T) {
|
||||
svc, client := newAuthPendingIdentityServiceTestClient(t)
|
||||
ctx := context.Background()
|
||||
|
||||
user, err := client.User.Create().
|
||||
SetEmail("adoption-concurrent@example.com").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(RoleUser).
|
||||
SetStatus(StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
identity, err := client.AuthIdentity.Create().
|
||||
SetUserID(user.ID).
|
||||
SetProviderType("wechat").
|
||||
SetProviderKey("wechat-main").
|
||||
SetProviderSubject("union-concurrent").
|
||||
SetMetadata(map[string]any{}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
|
||||
Intent: "bind_current_user",
|
||||
Identity: PendingAuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-main",
|
||||
ProviderSubject: "union-concurrent",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
firstCreateStarted := make(chan struct{})
|
||||
releaseFirstCreate := make(chan struct{})
|
||||
var firstCreate sync.Once
|
||||
client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
|
||||
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
|
||||
blocked := false
|
||||
if m.Op().Is(dbent.OpCreate) {
|
||||
firstCreate.Do(func() {
|
||||
blocked = true
|
||||
close(firstCreateStarted)
|
||||
})
|
||||
}
|
||||
if blocked {
|
||||
<-releaseFirstCreate
|
||||
}
|
||||
return next.Mutate(ctx, m)
|
||||
})
|
||||
})
|
||||
|
||||
type adoptionResult struct {
|
||||
decision *dbent.IdentityAdoptionDecision
|
||||
err error
|
||||
}
|
||||
|
||||
input := PendingIdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: session.ID,
|
||||
IdentityID: &identity.ID,
|
||||
AdoptDisplayName: true,
|
||||
AdoptAvatar: true,
|
||||
}
|
||||
|
||||
results := make(chan adoptionResult, 2)
|
||||
go func() {
|
||||
decision, err := svc.UpsertAdoptionDecision(ctx, input)
|
||||
results <- adoptionResult{decision: decision, err: err}
|
||||
}()
|
||||
|
||||
<-firstCreateStarted
|
||||
|
||||
go func() {
|
||||
decision, err := svc.UpsertAdoptionDecision(ctx, input)
|
||||
results <- adoptionResult{decision: decision, err: err}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
close(releaseFirstCreate)
|
||||
|
||||
first := <-results
|
||||
second := <-results
|
||||
|
||||
require.NoError(t, first.err)
|
||||
require.NoError(t, second.err)
|
||||
require.NotNil(t, first.decision)
|
||||
require.NotNil(t, second.decision)
|
||||
require.Equal(t, first.decision.ID, second.decision.ID)
|
||||
|
||||
count, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
loaded, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, loaded.IdentityID)
|
||||
require.Equal(t, identity.ID, *loaded.IdentityID)
|
||||
}
|
||||
|
||||
func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) {
|
||||
t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL")
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -489,6 +490,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
Status: StatusActive,
|
||||
SignupSource: signupSource,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
@@ -599,6 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
Balance: grantPlan.Balance,
|
||||
Concurrency: grantPlan.Concurrency,
|
||||
Status: StatusActive,
|
||||
SignupSource: signupSource,
|
||||
}
|
||||
|
||||
if s.entClient != nil && invitationRedeemCode != nil {
|
||||
@@ -1048,7 +1051,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
|
||||
UserID: user.ID,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
TokenVersion: user.TokenVersion,
|
||||
TokenVersion: resolvedTokenVersion(user),
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
@@ -1114,7 +1117,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
|
||||
|
||||
// Security: Check TokenVersion to prevent refreshing revoked tokens
|
||||
// This ensures tokens issued before a password change cannot be refreshed
|
||||
if claims.TokenVersion != user.TokenVersion {
|
||||
if claims.TokenVersion != resolvedTokenVersion(user) {
|
||||
return "", ErrTokenRevoked
|
||||
}
|
||||
|
||||
@@ -1342,7 +1345,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
|
||||
|
||||
data := &RefreshTokenData{
|
||||
UserID: user.ID,
|
||||
TokenVersion: user.TokenVersion,
|
||||
TokenVersion: resolvedTokenVersion(user),
|
||||
FamilyID: familyID,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(ttl),
|
||||
@@ -1422,7 +1425,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
|
||||
}
|
||||
|
||||
// 检查TokenVersion(密码更改后所有Token失效)
|
||||
if data.TokenVersion != user.TokenVersion {
|
||||
if data.TokenVersion != resolvedTokenVersion(user) {
|
||||
// TokenVersion不匹配,撤销整个Token家族
|
||||
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
|
||||
return nil, ErrTokenRevoked
|
||||
@@ -1492,3 +1495,14 @@ func hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func resolvedTokenVersion(user *User) int64 {
|
||||
if user == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash
|
||||
sum := sha256.Sum256([]byte(material))
|
||||
fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff)
|
||||
return user.TokenVersion ^ fingerprint
|
||||
}
|
||||
|
||||
@@ -814,6 +814,20 @@ func parseCustomMenuItemURLs(raw string) []string {
|
||||
return urls
|
||||
}
|
||||
|
||||
func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool {
|
||||
if base.UsePKCEExplicit {
|
||||
return base.UsePKCE
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool {
|
||||
if base.ValidateIDTokenExplicit {
|
||||
return base.ValidateIDToken
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
|
||||
updates, err := s.buildSystemSettingsUpdates(ctx, settings)
|
||||
@@ -1479,6 +1493,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
return fmt.Errorf("check existing settings: %w", err)
|
||||
}
|
||||
|
||||
oidcUsePKCEDefault := true
|
||||
oidcValidateIDTokenDefault := true
|
||||
if s != nil && s.cfg != nil {
|
||||
if s.cfg.OIDC.UsePKCEExplicit {
|
||||
oidcUsePKCEDefault = s.cfg.OIDC.UsePKCE
|
||||
}
|
||||
if s.cfg.OIDC.ValidateIDTokenExplicit {
|
||||
oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化默认设置
|
||||
defaults := map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
@@ -1523,8 +1548,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyOIDCConnectRedirectURL: "",
|
||||
SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
|
||||
SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
|
||||
SettingKeyOIDCConnectUsePKCE: "true",
|
||||
SettingKeyOIDCConnectValidateIDToken: "true",
|
||||
SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault),
|
||||
SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault),
|
||||
SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
|
||||
SettingKeyOIDCConnectClockSkewSeconds: "120",
|
||||
SettingKeyOIDCConnectRequireEmailVerified: "false",
|
||||
@@ -1767,12 +1792,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
|
||||
result.OIDCConnectUsePKCE = raw == "true"
|
||||
} else {
|
||||
result.OIDCConnectUsePKCE = oidcBase.UsePKCE
|
||||
result.OIDCConnectUsePKCE = oidcUsePKCECompatibilityDefault(oidcBase)
|
||||
}
|
||||
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
|
||||
result.OIDCConnectValidateIDToken = raw == "true"
|
||||
} else {
|
||||
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
|
||||
result.OIDCConnectValidateIDToken = oidcValidateIDTokenCompatibilityDefault(oidcBase)
|
||||
}
|
||||
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
|
||||
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
|
||||
@@ -2482,9 +2507,13 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
|
||||
}
|
||||
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
|
||||
effective.UsePKCE = raw == "true"
|
||||
} else {
|
||||
effective.UsePKCE = oidcUsePKCECompatibilityDefault(effective)
|
||||
}
|
||||
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
|
||||
effective.ValidateIDToken = raw == "true"
|
||||
} else {
|
||||
effective.ValidateIDToken = oidcValidateIDTokenCompatibilityDefault(effective)
|
||||
}
|
||||
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
|
||||
effective.AllowedSigningAlgs = strings.TrimSpace(v)
|
||||
|
||||
@@ -118,8 +118,10 @@ func TestSettingService_ParseSettings_PreservesOptionalOIDCCompatibilityFlags(t
|
||||
func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValues(t *testing.T) {
|
||||
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
|
||||
OIDC: config.OIDCConnectConfig{
|
||||
UsePKCE: true,
|
||||
ValidateIDToken: true,
|
||||
UsePKCE: true,
|
||||
UsePKCEExplicit: true,
|
||||
ValidateIDToken: true,
|
||||
ValidateIDTokenExplicit: true,
|
||||
},
|
||||
})
|
||||
|
||||
@@ -131,6 +133,22 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue
|
||||
require.True(t, got.OIDCConnectValidateIDToken)
|
||||
}
|
||||
|
||||
func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
|
||||
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
|
||||
OIDC: config.OIDCConnectConfig{
|
||||
UsePKCE: true,
|
||||
ValidateIDToken: true,
|
||||
},
|
||||
})
|
||||
|
||||
got := svc.parseSettings(map[string]string{
|
||||
SettingKeyOIDCConnectEnabled: "true",
|
||||
})
|
||||
|
||||
require.False(t, got.OIDCConnectUsePKCE)
|
||||
require.False(t, got.OIDCConnectValidateIDToken)
|
||||
}
|
||||
|
||||
func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
OIDC: config.OIDCConnectConfig{
|
||||
@@ -163,6 +181,42 @@ func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTok
|
||||
}
|
||||
|
||||
func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
OIDC: config.OIDCConnectConfig{
|
||||
Enabled: true,
|
||||
ProviderName: "OIDC",
|
||||
ClientID: "oidc-client",
|
||||
ClientSecret: "oidc-secret",
|
||||
IssuerURL: "https://issuer.example.com",
|
||||
AuthorizeURL: "https://issuer.example.com/auth",
|
||||
TokenURL: "https://issuer.example.com/token",
|
||||
UserInfoURL: "https://issuer.example.com/userinfo",
|
||||
JWKSURL: "https://issuer.example.com/jwks",
|
||||
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
|
||||
FrontendRedirectURL: "/auth/oidc/callback",
|
||||
Scopes: "openid email profile",
|
||||
TokenAuthMethod: "client_secret_post",
|
||||
UsePKCE: true,
|
||||
UsePKCEExplicit: true,
|
||||
ValidateIDToken: true,
|
||||
ValidateIDTokenExplicit: true,
|
||||
AllowedSigningAlgs: "RS256",
|
||||
ClockSkewSeconds: 120,
|
||||
},
|
||||
}
|
||||
|
||||
repo := &settingOIDCRepoStub{values: map[string]string{
|
||||
SettingKeyOIDCConnectEnabled: "true",
|
||||
}}
|
||||
svc := NewSettingService(repo, cfg)
|
||||
|
||||
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.True(t, got.UsePKCE)
|
||||
require.True(t, got.ValidateIDToken)
|
||||
}
|
||||
|
||||
func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
OIDC: config.OIDCConnectConfig{
|
||||
Enabled: true,
|
||||
@@ -192,6 +246,6 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t
|
||||
|
||||
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.True(t, got.UsePKCE)
|
||||
require.True(t, got.ValidateIDToken)
|
||||
require.False(t, got.UsePKCE)
|
||||
require.False(t, got.ValidateIDToken)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user