fix(profile): stabilize binding compatibility and frontend checks

This commit is contained in:
IanShaw027
2026-04-22 14:57:47 +08:00
parent 1aab084ecb
commit ca4e38aa01
30 changed files with 1072 additions and 97 deletions

View File

@@ -249,7 +249,7 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) {
return
}
updatedUser, err := h.userService.UnbindUserAuthProvider(
updatedUser, unbound, err := h.userService.UnbindUserAuthProviderWithResult(
c.Request.Context(),
subject.UserID,
c.Param("provider"),
@@ -258,7 +258,7 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if h.authService != nil {
if unbound && h.authService != nil {
if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil {
response.ErrorFrom(c, err)
return
@@ -512,7 +512,7 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
var avatarSource *userProfileSourceContext
avatarValue := strings.TrimSpace(user.AvatarURL)
for _, summary := range thirdParty {
if avatarValue != "" && avatarValue == strings.TrimSpace(summary.DisplayName) {
if avatarValue != "" && avatarValue == strings.TrimSpace(summary.AvatarURL) {
avatarSource = buildUserProfileSourceContext(summary.Provider)
break
}

View File

@@ -636,6 +636,50 @@ func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigure
require.Equal(t, int64(5), repo.user.TokenVersion)
}
func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 24,
Email: "identity@example.com",
Username: "identity-user",
Role: service.RoleUser,
Status: service.StatusActive,
TokenVersion: 4,
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "email",
ProviderKey: "email",
ProviderSubject: "identity@example.com",
},
},
}
refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 24})
c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
handler.UnbindIdentity(c)
require.Equal(t, http.StatusOK, recorder.Code)
require.Empty(t, repo.unbound)
require.Empty(t, refreshTokenCache.revokedUserIDs)
require.Equal(t, int64(4), repo.user.TokenVersion)
}
func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -728,7 +772,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
require.Equal(t, "wechat", resp.Data.Provider)
require.Equal(t, "GET", resp.Data.Method)
require.True(t, resp.Data.UseBrowserRedirect)
require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/start")
require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/bind/start")
require.Contains(t, resp.Data.AuthorizeURL, "intent=bind_current_user")
require.Contains(t, resp.Data.AuthorizeURL, "redirect=%2Fsettings%2Fprofile")
}

View File

@@ -85,7 +85,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
@@ -93,7 +93,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
@@ -101,7 +101,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"identity_bindings": {
@@ -122,7 +122,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
@@ -130,7 +130,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
@@ -138,7 +138,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"auth_bindings": {
@@ -159,7 +159,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
@@ -167,7 +167,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
@@ -175,7 +175,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"run_mode": "standard"

View File

@@ -63,8 +63,20 @@ func RegisterAuthRoutes(
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.LinuxDoOAuthStart(c)
})
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart)
auth.GET("/oauth/wechat/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.WeChatOAuthStart(c)
})
auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback)
auth.GET("/oauth/wechat/payment/start", h.Auth.WeChatPaymentOAuthStart)
auth.GET("/oauth/wechat/payment/callback", h.Auth.WeChatPaymentOAuthCallback)
@@ -129,6 +141,12 @@ func RegisterAuthRoutes(
h.Auth.CreateWeChatOAuthAccount,
)
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.OIDCOAuthStart(c)
})
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
rateLimiter.LimitWithOptions("oauth-oidc-complete", 10, time.Minute, middleware.RateLimitOptions{
@@ -165,23 +183,5 @@ func RegisterAuthRoutes(
// 撤销所有会话(需要认证)
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
authenticated.POST("/auth/oauth/bind-token", h.Auth.PrepareOAuthBindAccessTokenCookie)
authenticated.GET("/auth/oauth/linuxdo/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.LinuxDoOAuthStart(c)
})
authenticated.GET("/auth/oauth/oidc/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.OIDCOAuthStart(c)
})
authenticated.GET("/auth/oauth/wechat/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.WeChatOAuthStart(c)
})
}
}

View File

@@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// BindEmailIdentity verifies and binds a local email/password identity to the
@@ -69,6 +70,7 @@ func (s *AuthService) BindEmailIdentity(
if err := s.updateBoundEmailIdentityTx(ctx, currentUser, normalizedEmail, hashedPassword, firstRealEmailBind); err != nil {
return nil, err
}
s.revokeEmailIdentitySessions(ctx, userID)
return currentUser, nil
}
@@ -87,6 +89,7 @@ func (s *AuthService) BindEmailIdentity(
}
}
s.revokeEmailIdentitySessions(ctx, userID)
return currentUser, nil
}
@@ -219,6 +222,12 @@ func (s *AuthService) updateBoundEmailIdentityWithClient(
return nil
}
func (s *AuthService) revokeEmailIdentitySessions(ctx context.Context, userID int64) {
if err := s.RevokeAllUserSessions(ctx, userID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after email identity bind for user %d: %v", userID, err)
}
}
func replaceBoundEmailAuthIdentityWithClient(
ctx context.Context,
client *dbent.Client,

View File

@@ -6,6 +6,7 @@ import (
"context"
"database/sql"
"errors"
"sync"
"testing"
"time"
@@ -13,6 +14,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
@@ -54,6 +56,16 @@ func newAuthServiceForEmailBind(
settings map[string]string,
emailCache service.EmailCache,
defaultSubAssigner service.DefaultSubscriptionAssigner,
) (*service.AuthService, service.UserRepository, *dbent.Client) {
return newAuthServiceForEmailBindWithRefreshCache(t, settings, emailCache, defaultSubAssigner, nil)
}
func newAuthServiceForEmailBindWithRefreshCache(
t *testing.T,
settings map[string]string,
emailCache service.EmailCache,
defaultSubAssigner service.DefaultSubscriptionAssigner,
refreshTokenCache service.RefreshTokenCache,
) (*service.AuthService, service.UserRepository, *dbent.Client) {
t.Helper()
@@ -98,7 +110,7 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
emailSvc = service.NewEmailService(settingRepo, emailCache)
}
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner)
return svc, repo, client
}
@@ -427,6 +439,61 @@ func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t
require.Equal(t, 0, newIdentityCount)
}
func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *testing.T) {
ctx := context.Background()
cache := &emailBindCacheStub{
data: &service.VerificationCodeData{
Code: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
},
}
refreshTokenCache := newEmailBindRefreshTokenCacheStub()
userRepo := newEmailBindUserRepoStub(&service.User{
ID: 41,
Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
Username: "legacy-user",
PasswordHash: "old-hash",
Role: service.RoleUser,
Status: service.StatusActive,
TokenVersion: 4,
})
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-bind-email-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
}
emailService := service.NewEmailService(nil, cache)
svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil)
oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
ID: 41,
Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
Role: service.RoleUser,
Status: service.StatusActive,
TokenVersion: 4,
}, "")
require.NoError(t, err)
updatedUser, err := svc.BindEmailIdentity(ctx, 41, "new@example.com", "123456", "new-password")
require.NoError(t, err)
require.NotNil(t, updatedUser)
storedUser, err := userRepo.GetByID(ctx, 41)
require.NoError(t, err)
require.Equal(t, "new@example.com", storedUser.Email)
require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
_, err = svc.RefreshToken(ctx, oldTokenPair.AccessToken)
require.ErrorIs(t, err, service.ErrTokenRevoked)
_, err = svc.RefreshTokenPair(ctx, oldTokenPair.RefreshToken)
require.True(t, errors.Is(err, service.ErrTokenRevoked) || errors.Is(err, service.ErrRefreshTokenInvalid))
}
type emailBindSettingRepoStub struct {
values map[string]string
}
@@ -527,3 +594,260 @@ func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int6
func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
return 0, nil
}
type emailBindRefreshTokenCacheStub struct {
mu sync.Mutex
tokens map[string]*service.RefreshTokenData
userSets map[int64]map[string]struct{}
families map[string]map[string]struct{}
}
func newEmailBindRefreshTokenCacheStub() *emailBindRefreshTokenCacheStub {
return &emailBindRefreshTokenCacheStub{
tokens: make(map[string]*service.RefreshTokenData),
userSets: make(map[int64]map[string]struct{}),
families: make(map[string]map[string]struct{}),
}
}
func (s *emailBindRefreshTokenCacheStub) StoreRefreshToken(_ context.Context, tokenHash string, data *service.RefreshTokenData, _ time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
cloned := *data
s.tokens[tokenHash] = &cloned
return nil
}
func (s *emailBindRefreshTokenCacheStub) GetRefreshToken(_ context.Context, tokenHash string) (*service.RefreshTokenData, error) {
s.mu.Lock()
defer s.mu.Unlock()
data, ok := s.tokens[tokenHash]
if !ok {
return nil, service.ErrRefreshTokenNotFound
}
cloned := *data
return &cloned, nil
}
func (s *emailBindRefreshTokenCacheStub) DeleteRefreshToken(_ context.Context, tokenHash string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.tokens, tokenHash)
for _, tokenSet := range s.userSets {
delete(tokenSet, tokenHash)
}
for _, tokenSet := range s.families {
delete(tokenSet, tokenHash)
}
return nil
}
func (s *emailBindRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error {
s.mu.Lock()
defer s.mu.Unlock()
for tokenHash := range s.userSets[userID] {
delete(s.tokens, tokenHash)
for _, tokenSet := range s.families {
delete(tokenSet, tokenHash)
}
}
delete(s.userSets, userID)
return nil
}
func (s *emailBindRefreshTokenCacheStub) DeleteTokenFamily(_ context.Context, familyID string) error {
s.mu.Lock()
defer s.mu.Unlock()
for tokenHash := range s.families[familyID] {
delete(s.tokens, tokenHash)
for _, tokenSet := range s.userSets {
delete(tokenSet, tokenHash)
}
}
delete(s.families, familyID)
return nil
}
func (s *emailBindRefreshTokenCacheStub) AddToUserTokenSet(_ context.Context, userID int64, tokenHash string, _ time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.userSets[userID] == nil {
s.userSets[userID] = make(map[string]struct{})
}
s.userSets[userID][tokenHash] = struct{}{}
return nil
}
func (s *emailBindRefreshTokenCacheStub) AddToFamilyTokenSet(_ context.Context, familyID string, tokenHash string, _ time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.families[familyID] == nil {
s.families[familyID] = make(map[string]struct{})
}
s.families[familyID][tokenHash] = struct{}{}
return nil
}
func (s *emailBindRefreshTokenCacheStub) GetUserTokenHashes(_ context.Context, userID int64) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
tokenSet := s.userSets[userID]
out := make([]string, 0, len(tokenSet))
for tokenHash := range tokenSet {
out = append(out, tokenHash)
}
return out, nil
}
func (s *emailBindRefreshTokenCacheStub) GetFamilyTokenHashes(_ context.Context, familyID string) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
tokenSet := s.families[familyID]
out := make([]string, 0, len(tokenSet))
for tokenHash := range tokenSet {
out = append(out, tokenHash)
}
return out, nil
}
func (s *emailBindRefreshTokenCacheStub) IsTokenInFamily(_ context.Context, familyID string, tokenHash string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
_, ok := s.families[familyID][tokenHash]
return ok, nil
}
type emailBindUserRepoStub struct {
mu sync.Mutex
usersByID map[int64]*service.User
usersByEmail map[string]*service.User
}
func newEmailBindUserRepoStub(user *service.User) *emailBindUserRepoStub {
cloned := cloneEmailBindUser(user)
return &emailBindUserRepoStub{
usersByID: map[int64]*service.User{
cloned.ID: cloned,
},
usersByEmail: map[string]*service.User{
cloned.Email: cloned,
},
}
}
func (s *emailBindUserRepoStub) Create(context.Context, *service.User) error { return nil }
func (s *emailBindUserRepoStub) GetByID(_ context.Context, id int64) (*service.User, error) {
s.mu.Lock()
defer s.mu.Unlock()
user, ok := s.usersByID[id]
if !ok {
return nil, service.ErrUserNotFound
}
return cloneEmailBindUser(user), nil
}
func (s *emailBindUserRepoStub) GetByEmail(_ context.Context, email string) (*service.User, error) {
s.mu.Lock()
defer s.mu.Unlock()
user, ok := s.usersByEmail[email]
if !ok {
return nil, service.ErrUserNotFound
}
return cloneEmailBindUser(user), nil
}
func (s *emailBindUserRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
panic("unexpected GetFirstAdmin call")
}
func (s *emailBindUserRepoStub) Update(_ context.Context, user *service.User) error {
s.mu.Lock()
defer s.mu.Unlock()
existing, ok := s.usersByID[user.ID]
if !ok {
return service.ErrUserNotFound
}
delete(s.usersByEmail, existing.Email)
cloned := cloneEmailBindUser(user)
s.usersByID[user.ID] = cloned
s.usersByEmail[cloned.Email] = cloned
return nil
}
func (s *emailBindUserRepoStub) Delete(context.Context, int64) error { return nil }
func (s *emailBindUserRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
return nil, nil
}
func (s *emailBindUserRepoStub) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
}
func (s *emailBindUserRepoStub) DeleteUserAvatar(context.Context, int64) error {
panic("unexpected DeleteUserAvatar call")
}
func (s *emailBindUserRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *emailBindUserRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *emailBindUserRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
return map[int64]*time.Time{}, nil
}
func (s *emailBindUserRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
return nil, nil
}
func (s *emailBindUserRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
return nil
}
func (s *emailBindUserRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
_, ok := s.usersByEmail[email]
return ok, nil
}
func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (s *emailBindUserRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *emailBindUserRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *emailBindUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) {
return nil, nil
}
func (s *emailBindUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
return nil
}
func (s *emailBindUserRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (s *emailBindUserRepoStub) EnableTotp(context.Context, int64) error { return nil }
func (s *emailBindUserRepoStub) DisableTotp(context.Context, int64) error { return nil }
func cloneEmailBindUser(user *service.User) *service.User {
if user == nil {
return nil
}
cloned := *user
return &cloned
}

View File

@@ -127,6 +127,7 @@ type UserIdentitySummary struct {
Bound bool `json:"bound"`
BoundCount int `json:"bound_count"`
DisplayName string `json:"display_name,omitempty"`
AvatarURL string `json:"-"`
SubjectHint string `json:"subject_hint,omitempty"`
ProviderKey string `json:"provider_key,omitempty"`
VerifiedAt *time.Time `json:"verified_at,omitempty"`
@@ -228,6 +229,7 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
normalizeLoadedUserTokenVersion(user)
if err := s.hydrateUserAvatar(ctx, user); err != nil {
return nil, fmt.Errorf("get user avatar: %w", err)
}
@@ -323,29 +325,34 @@ func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUs
}
func (s *UserService) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) (*User, error) {
user, _, err := s.UnbindUserAuthProviderWithResult(ctx, userID, provider)
return user, err
}
func (s *UserService) UnbindUserAuthProviderWithResult(ctx context.Context, userID int64, provider string) (*User, bool, error) {
provider = normalizeUserIdentityProvider(provider)
if provider == "" || provider == "email" {
return nil, ErrIdentityProviderInvalid
return nil, false, ErrIdentityProviderInvalid
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
return nil, false, fmt.Errorf("get user: %w", err)
}
records, err := s.listUserAuthIdentities(ctx, userID)
if err != nil {
return nil, err
return nil, false, err
}
if len(filterUserAuthIdentities(records, provider)) == 0 {
return user, nil
return user, false, nil
}
if !s.canUnbindProvider(provider, user, records) {
return nil, ErrIdentityUnbindLastMethod
return nil, false, ErrIdentityUnbindLastMethod
}
if err := s.userRepo.UnbindUserAuthProvider(ctx, userID, provider); err != nil {
return nil, err
return nil, false, err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
@@ -353,9 +360,9 @@ func (s *UserService) UnbindUserAuthProvider(ctx context.Context, userID int64,
updatedUser, err := s.GetProfile(ctx, userID)
if err != nil {
return nil, err
return nil, false, err
}
return updatedUser, nil
return updatedUser, true, nil
}
// UpdateProfile 更新用户资料
@@ -655,6 +662,7 @@ func (s *UserService) buildProviderIdentitySummary(provider string, user *User,
summary.Bound = true
summary.BoundCount = len(filtered)
summary.DisplayName = userAuthIdentityDisplayName(primary)
summary.AvatarURL = strings.TrimSpace(firstStringIdentityValue(primary.Metadata, "avatar_url", "suggested_avatar_url", "headimgurl"))
summary.SubjectHint = maskOpaqueIdentity(primary.ProviderSubject)
summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
summary.VerifiedAt = primary.VerifiedAt
@@ -672,7 +680,7 @@ func (s *UserService) canUnbindProvider(provider string, user *User, records []U
return false
}
if s.buildEmailIdentitySummary(user, records).Bound {
if s.canUseEmailAsSignInMethod(user, records) {
return true
}
@@ -688,6 +696,44 @@ func (s *UserService) canUnbindProvider(provider string, user *User, records []U
return false
}
func (s *UserService) canUseEmailAsSignInMethod(user *User, records []UserAuthIdentityRecord) bool {
if user == nil {
return false
}
email := strings.ToLower(strings.TrimSpace(user.Email))
if email == "" || isReservedEmail(email) {
return false
}
if emailSignupSourceAllowsLogin(user.SignupSource) {
return true
}
for _, record := range filterUserAuthIdentities(records, "email") {
if emailIdentitySupportsSignIn(record) {
return true
}
}
return false
}
func emailSignupSourceAllowsLogin(signupSource string) bool {
signupSource = strings.ToLower(strings.TrimSpace(signupSource))
return signupSource == "" || signupSource == "email"
}
func emailIdentitySupportsSignIn(record UserAuthIdentityRecord) bool {
source := strings.TrimSpace(firstStringIdentityValue(record.Metadata, "source"))
switch source {
case "auth_service_email_bind", "auth_service_login_backfill", "auth_service_dual_write":
return true
default:
return false
}
}
func (s *UserService) listUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
if userID <= 0 || s == nil || s.userRepo == nil {
return nil, nil
@@ -709,11 +755,11 @@ func buildUserIdentityBindAuthorizeURL(provider, redirectTo string) (string, err
path := ""
switch provider {
case "linuxdo":
path = "/api/v1/auth/oauth/linuxdo/start"
path = "/api/v1/auth/oauth/linuxdo/bind/start"
case "oidc":
path = "/api/v1/auth/oauth/oidc/start"
path = "/api/v1/auth/oauth/oidc/bind/start"
case "wechat":
path = "/api/v1/auth/oauth/wechat/start"
path = "/api/v1/auth/oauth/wechat/bind/start"
default:
return "", ErrIdentityProviderInvalid
}
@@ -889,12 +935,20 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
normalizeLoadedUserTokenVersion(user)
if err := s.hydrateUserAvatar(ctx, user); err != nil {
return nil, fmt.Errorf("get user avatar: %w", err)
}
return user, nil
}
func normalizeLoadedUserTokenVersion(user *User) {
if user == nil {
return
}
user.TokenVersion = resolvedTokenVersion(user)
}
// TouchLastActive 通过防抖更新 users.last_active_at减少鉴权热路径写放大。
// 该操作为尽力而为,不应中断正常请求。
func (s *UserService) TouchLastActive(ctx context.Context, userID int64) {

View File

@@ -387,6 +387,70 @@ func TestUnbindUserAuthProviderRejectsLastRemainingLoginMethod(t *testing.T) {
require.Empty(t, repo.unboundProviders)
}
func TestGetProfileIdentitySummaries_DoesNotTreatOAuthOnlyCompatEmailAsAlternativeLoginMethod(t *testing.T) {
repo := &mockUserRepo{
getByIDUser: &User{
ID: 10,
Email: "oauth-only@example.com",
SignupSource: "oidc",
},
identities: []UserAuthIdentityRecord{
{
ProviderType: "oidc",
ProviderKey: "https://issuer.example.com",
ProviderSubject: "oidc-only-subject",
},
},
}
svc := NewUserService(repo, nil, nil, nil)
summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 10, repo.getByIDUser)
require.NoError(t, err)
require.False(t, summaries.OIDC.CanUnbind)
_, err = svc.UnbindUserAuthProvider(context.Background(), 10, "oidc")
require.ErrorIs(t, err, ErrIdentityUnbindLastMethod)
require.Empty(t, repo.unboundProviders)
}
func TestGetProfileIdentitySummaries_DoesNotTreatCompatBackfilledEmailIdentityAsAlternativeLoginMethod(t *testing.T) {
repo := &mockUserRepo{
getByIDUser: &User{
ID: 11,
Email: "oauth-only@example.com",
SignupSource: "wechat",
},
identities: []UserAuthIdentityRecord{
{
ProviderType: "email",
ProviderKey: "email",
ProviderSubject: "oauth-only@example.com",
Metadata: map[string]any{
"backfill_source": "users.email",
"migration": "109_auth_identity_compat_backfill",
},
},
{
ProviderType: "wechat",
ProviderKey: "wechat",
ProviderSubject: "wechat-only-subject",
},
},
}
svc := NewUserService(repo, nil, nil, nil)
summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 11, repo.getByIDUser)
require.NoError(t, err)
require.True(t, summaries.Email.Bound)
require.False(t, summaries.WeChat.CanUnbind)
_, err = svc.UnbindUserAuthProvider(context.Background(), 11, "wechat")
require.ErrorIs(t, err, ErrIdentityUnbindLastMethod)
require.Empty(t, repo.unboundProviders)
}
func TestUnbindUserAuthProviderRemovesProviderAndReturnsUpdatedProfile(t *testing.T) {
repo := &mockUserRepo{
getByIDUser: &User{
@@ -451,6 +515,42 @@ func TestGetProfileIdentitySummaries_HidesBindActionWhenProviderExplicitlyDisabl
require.Empty(t, summaries.LinuxDo.BindStartPath)
}
func TestGetProfileIdentitySummaries_UsesBindStartRoute(t *testing.T) {
repo := &mockUserRepo{
getByIDUser: &User{
ID: 16,
Email: "alice@example.com",
},
identities: []UserAuthIdentityRecord{
{
ProviderType: "email",
ProviderKey: "email",
ProviderSubject: "alice@example.com",
},
},
}
svc := NewUserService(repo, nil, nil, nil)
summaries, err := svc.GetProfileIdentitySummaries(context.Background(), 16, repo.getByIDUser)
require.NoError(t, err)
require.Equal(
t,
"/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile",
summaries.LinuxDo.BindStartPath,
)
require.Equal(
t,
"/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile",
summaries.OIDC.BindStartPath,
)
require.Equal(
t,
"/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile",
summaries.WeChat.BindStartPath,
)
}
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
repo := &mockUserRepo{}
svc := NewUserService(repo, nil, nil, nil) // billingCache = nil