fix(auth): harden oauth identity upgrade paths

This commit is contained in:
IanShaw027
2026-04-22 14:56:56 +08:00
parent 3d29f7c2fa
commit 36aed35957
32 changed files with 2365 additions and 262 deletions

View File

@@ -14,10 +14,14 @@ import (
func normalizeOAuthSignupSource(signupSource string) string {
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
if signupSource == "" {
switch signupSource {
case "", "email":
return "email"
case "linuxdo", "wechat", "oidc":
return signupSource
default:
return "email"
}
return signupSource
}
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
@@ -136,10 +140,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, fmt.Errorf("hash password: %w", err)
}
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
if signupSource == "" {
signupSource = "email"
}
signupSource = normalizeOAuthSignupSource(signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
user := &User{
@@ -149,6 +150,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
SignupSource: signupSource,
}
if err := s.userRepo.Create(ctx, user); err != nil {

View File

@@ -191,6 +191,80 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai
require.Empty(t, redeemRepo.updateCalls)
}
func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *testing.T) {
userRepo := &userRepoStub{nextID: 42}
emailCache := &emailCacheStub{
data: &VerificationCodeData{
Code: "246810",
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
&redeemCodeRepoStub{},
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
context.Background(),
"fresh@example.com",
"secret-123",
"246810",
"",
" OIDC ",
)
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
require.Len(t, userRepo.created, 1)
require.Equal(t, "oidc", userRepo.created[0].SignupSource)
}
func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
userRepo := &userRepoStub{nextID: 43}
emailCache := &emailCacheStub{
data: &VerificationCodeData{
Code: "246810",
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
&redeemCodeRepoStub{},
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
context.Background(),
"fallback@example.com",
"secret-123",
"246810",
"",
"github",
)
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
require.Len(t, userRepo.created, 1)
require.Equal(t, "email", userRepo.created[0].SignupSource)
}
func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
userRepo := &userRepoStub{}
redeemRepo := &redeemCodeRepoStub{

View File

@@ -5,10 +5,15 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"hash/fnv"
"sort"
"strings"
"sync"
"time"
"entgo.io/ent/dialect"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
@@ -75,6 +80,122 @@ type AuthPendingIdentityService struct {
entClient *dbent.Client
}
var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry()
type authPendingIdentityScopedKeyLockRegistry struct {
mu sync.Mutex
locks map[string]*authPendingIdentityScopedKeyLockEntry
}
type authPendingIdentityScopedKeyLockEntry struct {
mu sync.Mutex
refs int
}
func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry {
return &authPendingIdentityScopedKeyLockRegistry{
locks: make(map[string]*authPendingIdentityScopedKeyLockEntry),
}
}
func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() {
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
if len(normalized) == 0 {
return func() {}
}
entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized))
r.mu.Lock()
for _, key := range normalized {
entry := r.locks[key]
if entry == nil {
entry = &authPendingIdentityScopedKeyLockEntry{}
r.locks[key] = entry
}
entry.refs++
entries = append(entries, entry)
}
r.mu.Unlock()
for _, entry := range entries {
entry.mu.Lock()
}
return func() {
for i := len(entries) - 1; i >= 0; i-- {
entries[i].mu.Unlock()
}
r.mu.Lock()
defer r.mu.Unlock()
for idx, key := range normalized {
entry := entries[idx]
entry.refs--
if entry.refs == 0 {
delete(r.locks, key)
}
}
}
}
func normalizeAuthPendingIdentityLockKeys(keys ...string) []string {
if len(keys) == 0 {
return nil
}
deduped := make(map[string]struct{}, len(keys))
for _, key := range keys {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
continue
}
deduped[trimmed] = struct{}{}
}
if len(deduped) == 0 {
return nil
}
normalized := make([]string, 0, len(deduped))
for key := range deduped {
normalized = append(normalized, key)
}
sort.Strings(normalized)
return normalized
}
func authPendingIdentityAdvisoryLockHash(key string) int64 {
hasher := fnv.New64a()
_, _ = hasher.Write([]byte(key))
return int64(hasher.Sum64())
}
func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) {
release := authPendingIdentityScopedKeyLocks.lock(keys...)
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres {
return release, nil
}
for _, key := range normalized {
var rows entsql.Rows
if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil {
release()
return nil, err
}
_ = rows.Close()
}
return release, nil
}
func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)}
if identityID != nil && *identityID > 0 {
keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID))
}
return keys
}
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
return &AuthPendingIdentityService{entClient: entClient}
}
@@ -324,8 +445,29 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
return nil, fmt.Errorf("pending auth ent client is not configured")
}
tx, err := s.entClient.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return nil, err
}
client := s.entClient
txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
client = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
client = existingTx.Client()
}
releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...)
if err != nil {
return nil, err
}
defer releaseLocks()
if input.IdentityID != nil && *input.IdentityID > 0 {
if _, err := s.entClient.IdentityAdoptionDecision.Update().
if _, err := client.IdentityAdoptionDecision.Update().
Where(
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
@@ -337,36 +479,40 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
}),
).
ClearIdentityID().
Save(ctx); err != nil {
Save(txCtx); err != nil {
return nil, err
}
}
existing, err := s.entClient.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
if existing == nil {
create := s.entClient.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil {
create = create.SetIdentityID(*input.IdentityID)
}
return create.Save(ctx)
create := client.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil && *input.IdentityID > 0 {
create = create.SetIdentityID(*input.IdentityID)
}
update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar)
if input.IdentityID != nil {
update = update.SetIdentityID(*input.IdentityID)
decisionID, err := create.
OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
UpdateNewValues().
ID(txCtx)
if err != nil {
return nil, err
}
return update.Save(ctx)
decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID)
if err != nil {
return nil, err
}
if tx != nil {
if err := tx.Commit(); err != nil {
return nil, err
}
}
return decision, nil
}
func copyPendingMap(in map[string]any) map[string]any {

View File

@@ -5,6 +5,7 @@ package service
import (
"context"
"database/sql"
"sync"
"testing"
"time"
@@ -259,6 +260,107 @@ func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIden
require.Nil(t, reloadedFirst.IdentityID)
}
func TestAuthPendingIdentityService_UpsertAdoptionDecision_IsIdempotentUnderConcurrency(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("adoption-concurrent@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
identity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("union-concurrent").
SetMetadata(map[string]any{}).
Save(ctx)
require.NoError(t, err)
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-concurrent",
},
})
require.NoError(t, err)
firstCreateStarted := make(chan struct{})
releaseFirstCreate := make(chan struct{})
var firstCreate sync.Once
client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
blocked := false
if m.Op().Is(dbent.OpCreate) {
firstCreate.Do(func() {
blocked = true
close(firstCreateStarted)
})
}
if blocked {
<-releaseFirstCreate
}
return next.Mutate(ctx, m)
})
})
type adoptionResult struct {
decision *dbent.IdentityAdoptionDecision
err error
}
input := PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
}
results := make(chan adoptionResult, 2)
go func() {
decision, err := svc.UpsertAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
<-firstCreateStarted
go func() {
decision, err := svc.UpsertAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
time.Sleep(100 * time.Millisecond)
close(releaseFirstCreate)
first := <-results
second := <-results
require.NoError(t, first.err)
require.NoError(t, second.err)
require.NotNil(t, first.decision)
require.NotNil(t, second.decision)
require.Equal(t, first.decision.ID, second.decision.ID)
count, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, count)
loaded, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, loaded.IdentityID)
require.Equal(t, identity.ID, *loaded.IdentityID)
}
func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) {
t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL")

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
@@ -489,6 +490,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
SignupSource: signupSource,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -599,6 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
SignupSource: signupSource,
}
if s.entClient != nil && invitationRedeemCode != nil {
@@ -1048,7 +1051,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
UserID: user.ID,
Email: user.Email,
Role: user.Role,
TokenVersion: user.TokenVersion,
TokenVersion: resolvedTokenVersion(user),
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
@@ -1114,7 +1117,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// Security: Check TokenVersion to prevent refreshing revoked tokens
// This ensures tokens issued before a password change cannot be refreshed
if claims.TokenVersion != user.TokenVersion {
if claims.TokenVersion != resolvedTokenVersion(user) {
return "", ErrTokenRevoked
}
@@ -1342,7 +1345,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
data := &RefreshTokenData{
UserID: user.ID,
TokenVersion: user.TokenVersion,
TokenVersion: resolvedTokenVersion(user),
FamilyID: familyID,
CreatedAt: now,
ExpiresAt: now.Add(ttl),
@@ -1422,7 +1425,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
}
// 检查TokenVersion密码更改后所有Token失效
if data.TokenVersion != user.TokenVersion {
if data.TokenVersion != resolvedTokenVersion(user) {
// TokenVersion不匹配撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrTokenRevoked
@@ -1492,3 +1495,14 @@ func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
func resolvedTokenVersion(user *User) int64 {
if user == nil {
return 0
}
material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash
sum := sha256.Sum256([]byte(material))
fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff)
return user.TokenVersion ^ fingerprint
}

View File

@@ -814,6 +814,20 @@ func parseCustomMenuItemURLs(raw string) []string {
return urls
}
func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool {
if base.UsePKCEExplicit {
return base.UsePKCE
}
return false
}
func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool {
if base.ValidateIDTokenExplicit {
return base.ValidateIDToken
}
return false
}
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
updates, err := s.buildSystemSettingsUpdates(ctx, settings)
@@ -1479,6 +1493,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
return fmt.Errorf("check existing settings: %w", err)
}
oidcUsePKCEDefault := true
oidcValidateIDTokenDefault := true
if s != nil && s.cfg != nil {
if s.cfg.OIDC.UsePKCEExplicit {
oidcUsePKCEDefault = s.cfg.OIDC.UsePKCE
}
if s.cfg.OIDC.ValidateIDTokenExplicit {
oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken
}
}
// 初始化默认设置
defaults := map[string]string{
SettingKeyRegistrationEnabled: "true",
@@ -1523,8 +1548,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOIDCConnectRedirectURL: "",
SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
SettingKeyOIDCConnectUsePKCE: "true",
SettingKeyOIDCConnectValidateIDToken: "true",
SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault),
SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault),
SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
SettingKeyOIDCConnectClockSkewSeconds: "120",
SettingKeyOIDCConnectRequireEmailVerified: "false",
@@ -1767,12 +1792,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
result.OIDCConnectUsePKCE = raw == "true"
} else {
result.OIDCConnectUsePKCE = oidcBase.UsePKCE
result.OIDCConnectUsePKCE = oidcUsePKCECompatibilityDefault(oidcBase)
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
result.OIDCConnectValidateIDToken = raw == "true"
} else {
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
result.OIDCConnectValidateIDToken = oidcValidateIDTokenCompatibilityDefault(oidcBase)
}
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
@@ -2482,9 +2507,13 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
}
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
effective.UsePKCE = raw == "true"
} else {
effective.UsePKCE = oidcUsePKCECompatibilityDefault(effective)
}
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
effective.ValidateIDToken = raw == "true"
} else {
effective.ValidateIDToken = oidcValidateIDTokenCompatibilityDefault(effective)
}
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
effective.AllowedSigningAlgs = strings.TrimSpace(v)

View File

@@ -118,8 +118,10 @@ func TestSettingService_ParseSettings_PreservesOptionalOIDCCompatibilityFlags(t
func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValues(t *testing.T) {
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
OIDC: config.OIDCConnectConfig{
UsePKCE: true,
ValidateIDToken: true,
UsePKCE: true,
UsePKCEExplicit: true,
ValidateIDToken: true,
ValidateIDTokenExplicit: true,
},
})
@@ -131,6 +133,22 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue
require.True(t, got.OIDCConnectValidateIDToken)
}
func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
OIDC: config.OIDCConnectConfig{
UsePKCE: true,
ValidateIDToken: true,
},
})
got := svc.parseSettings(map[string]string{
SettingKeyOIDCConnectEnabled: "true",
})
require.False(t, got.OIDCConnectUsePKCE)
require.False(t, got.OIDCConnectValidateIDToken)
}
func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) {
cfg := &config.Config{
OIDC: config.OIDCConnectConfig{
@@ -163,6 +181,42 @@ func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTok
}
func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *testing.T) {
cfg := &config.Config{
OIDC: config.OIDCConnectConfig{
Enabled: true,
ProviderName: "OIDC",
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: "https://issuer.example.com/auth",
TokenURL: "https://issuer.example.com/token",
UserInfoURL: "https://issuer.example.com/userinfo",
JWKSURL: "https://issuer.example.com/jwks",
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
Scopes: "openid email profile",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
UsePKCEExplicit: true,
ValidateIDToken: true,
ValidateIDTokenExplicit: true,
AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120,
},
}
repo := &settingOIDCRepoStub{values: map[string]string{
SettingKeyOIDCConnectEnabled: "true",
}}
svc := NewSettingService(repo, cfg)
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
require.NoError(t, err)
require.True(t, got.UsePKCE)
require.True(t, got.ValidateIDToken)
}
func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
cfg := &config.Config{
OIDC: config.OIDCConnectConfig{
Enabled: true,
@@ -192,6 +246,6 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
require.NoError(t, err)
require.True(t, got.UsePKCE)
require.True(t, got.ValidateIDToken)
require.False(t, got.UsePKCE)
require.False(t, got.ValidateIDToken)
}