Close profile identity and avatar loop

This commit is contained in:
IanShaw027
2026-04-21 00:11:03 +08:00
parent f73117f9b1
commit 9204145746
14 changed files with 801 additions and 35 deletions

View File

@@ -296,6 +296,7 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
c.Request.Context(),
h.entClient(),
h.authService,
h.userService,
pendingSession,
decision,
&user.ID,

View File

@@ -495,7 +495,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}

View File

@@ -852,6 +852,7 @@ func applyPendingOAuthBinding(
ctx context.Context,
client *dbent.Client,
authService *service.AuthService,
userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
@@ -938,6 +939,12 @@ func applyPendingOAuthBinding(
}
}
if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" && userService != nil {
if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil {
return err
}
}
return tx.Commit()
}
@@ -945,6 +952,7 @@ func applyPendingOAuthAdoption(
ctx context.Context,
client *dbent.Client,
authService *service.AuthService,
userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
@@ -953,6 +961,7 @@ func applyPendingOAuthAdoption(
ctx,
client,
authService,
userService,
session,
decision,
overrideUserID,
@@ -1092,7 +1101,7 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
})
return
}
if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID, true, true); err != nil {
if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID, true, true); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
@@ -1188,7 +1197,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, session, decision, &user.ID, true, false); err != nil {
if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
@@ -1278,7 +1287,7 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, session.TargetUserID); err != nil {
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, session.TargetUserID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}

View File

@@ -152,6 +152,11 @@ func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecisio
require.Equal(t, "Alice Example", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"])
avatar := loadUserAvatarRecord(t, client, userEntity.ID)
require.NotNil(t, avatar)
require.Equal(t, "remote_url", avatar.StorageProvider)
require.Equal(t, "https://cdn.example/alice.png", avatar.URL)
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
@@ -1242,6 +1247,18 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants (
UNIQUE(user_id, provider_type, grant_reason)
)`)
require.NoError(t, err)
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS user_avatars (
user_id INTEGER PRIMARY KEY,
storage_provider TEXT NOT NULL,
storage_key TEXT NOT NULL DEFAULT '',
url TEXT NOT NULL,
content_type TEXT NOT NULL DEFAULT '',
byte_size INTEGER NOT NULL DEFAULT 0,
sha256 TEXT NOT NULL DEFAULT '',
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)`)
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
@@ -1492,6 +1509,35 @@ func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[strin
return payload
}
type oauthPendingFlowAvatarRecord struct {
StorageProvider string
URL string
}
func loadUserAvatarRecord(t *testing.T, client *dbent.Client, userID int64) *oauthPendingFlowAvatarRecord {
t.Helper()
var rows entsql.Rows
err := client.Driver().Query(
context.Background(),
`SELECT storage_provider, url FROM user_avatars WHERE user_id = ?`,
[]any{userID},
&rows,
)
require.NoError(t, err)
defer rows.Close()
if !rows.Next() {
require.NoError(t, rows.Err())
return nil
}
var record oauthPendingFlowAvatarRecord
require.NoError(t, rows.Scan(&record.StorageProvider, &record.URL))
require.NoError(t, rows.Err())
return &record
}
func countProviderGrantRecords(
t *testing.T,
client *dbent.Client,
@@ -1604,16 +1650,95 @@ func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
return r.client.User.DeleteOneID(id).Exec(ctx)
}
func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
return nil, nil
func (r *oauthPendingFlowUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
driver := r.client.Driver()
if tx := dbent.TxFromContext(ctx); tx != nil {
driver = tx.Client().Driver()
}
var rows entsql.Rows
if err := driver.Query(
ctx,
`SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 FROM user_avatars WHERE user_id = ?`,
[]any{userID},
&rows,
); err != nil {
return nil, err
}
defer rows.Close()
if !rows.Next() {
return nil, rows.Err()
}
var avatar service.UserAvatar
if err := rows.Scan(
&avatar.StorageProvider,
&avatar.StorageKey,
&avatar.URL,
&avatar.ContentType,
&avatar.ByteSize,
&avatar.SHA256,
); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return &avatar, nil
}
func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
driver := r.client.Driver()
if tx := dbent.TxFromContext(ctx); tx != nil {
driver = tx.Client().Driver()
}
var result entsql.Result
if err := driver.Exec(
ctx,
`INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
ON CONFLICT(user_id) DO UPDATE SET
storage_provider = excluded.storage_provider,
storage_key = excluded.storage_key,
url = excluded.url,
content_type = excluded.content_type,
byte_size = excluded.byte_size,
sha256 = excluded.sha256,
updated_at = CURRENT_TIMESTAMP`,
[]any{
userID,
input.StorageProvider,
input.StorageKey,
input.URL,
input.ContentType,
input.ByteSize,
input.SHA256,
},
&result,
); err != nil {
return nil, err
}
return &service.UserAvatar{
StorageProvider: input.StorageProvider,
StorageKey: input.StorageKey,
URL: input.URL,
ContentType: input.ContentType,
ByteSize: input.ByteSize,
SHA256: input.SHA256,
}, nil
}
func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error {
return nil
func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
driver := r.client.Driver()
if tx := dbent.TxFromContext(ctx); tx != nil {
driver = tx.Client().Driver()
}
var result entsql.Result
return driver.Exec(ctx, `DELETE FROM user_avatars WHERE user_id = ?`, []any{userID}, &result)
}
func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
@@ -1636,6 +1761,14 @@ func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int
panic("unexpected UpdateConcurrency call")
}
func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
return map[int64]*time.Time{}, nil
}
func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
return nil, nil
}
func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx)
return count > 0, err

View File

@@ -537,7 +537,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}

View File

@@ -517,7 +517,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}

View File

@@ -2,6 +2,7 @@ package handler
import (
"context"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -43,8 +44,24 @@ type UpdateProfileRequest struct {
type userProfileResponse struct {
dto.User
AvatarURL string `json:"avatar_url,omitempty"`
Identities service.UserIdentitySummarySet `json:"identities"`
AvatarURL string `json:"avatar_url,omitempty"`
AvatarSource *userProfileSourceContext `json:"avatar_source,omitempty"`
UsernameSource *userProfileSourceContext `json:"username_source,omitempty"`
DisplayNameSource *userProfileSourceContext `json:"display_name_source,omitempty"`
NicknameSource *userProfileSourceContext `json:"nickname_source,omitempty"`
ProfileSources map[string]*userProfileSourceContext `json:"profile_sources,omitempty"`
Identities service.UserIdentitySummarySet `json:"identities"`
AuthBindings map[string]service.UserIdentitySummary `json:"auth_bindings"`
IdentityBindings map[string]service.UserIdentitySummary `json:"identity_bindings"`
EmailBound bool `json:"email_bound"`
LinuxDoBound bool `json:"linuxdo_bound"`
OIDCBound bool `json:"oidc_bound"`
WeChatBound bool `json:"wechat_bound"`
}
type userProfileSourceContext struct {
Provider string `json:"provider,omitempty"`
Source string `json:"source,omitempty"`
}
// GetProfile handles getting user profile
@@ -335,9 +352,94 @@ func userProfileResponseFromService(user *service.User, identities service.UserI
if base == nil {
return userProfileResponse{}
}
bindings := userProfileBindingMap(identities)
profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities)
return userProfileResponse{
User: *base,
AvatarURL: user.AvatarURL,
Identities: identities,
User: *base,
AvatarURL: user.AvatarURL,
AvatarSource: avatarSource,
UsernameSource: usernameSource,
DisplayNameSource: usernameSource,
NicknameSource: usernameSource,
ProfileSources: profileSources,
Identities: identities,
AuthBindings: bindings,
IdentityBindings: bindings,
EmailBound: identities.Email.Bound,
LinuxDoBound: identities.LinuxDo.Bound,
OIDCBound: identities.OIDC.Bound,
WeChatBound: identities.WeChat.Bound,
}
}
func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary {
return map[string]service.UserIdentitySummary{
"email": identities.Email,
"linuxdo": identities.LinuxDo,
"oidc": identities.OIDC,
"wechat": identities.WeChat,
}
}
func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) (
map[string]*userProfileSourceContext,
*userProfileSourceContext,
*userProfileSourceContext,
) {
if user == nil {
return nil, nil, nil
}
thirdParty := thirdPartyIdentityProviders(identities)
var avatarSource *userProfileSourceContext
if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 {
avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider)
}
usernameValue := strings.TrimSpace(user.Username)
var usernameSource *userProfileSourceContext
for _, summary := range thirdParty {
if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) {
usernameSource = buildUserProfileSourceContext(summary.Provider)
break
}
}
if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 {
usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider)
}
profileSources := map[string]*userProfileSourceContext{}
if avatarSource != nil {
profileSources["avatar"] = avatarSource
}
if usernameSource != nil {
profileSources["username"] = usernameSource
profileSources["display_name"] = usernameSource
profileSources["nickname"] = usernameSource
}
if len(profileSources) == 0 {
return nil, avatarSource, usernameSource
}
return profileSources, avatarSource, usernameSource
}
func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
out := make([]service.UserIdentitySummary, 0, 3)
for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
if summary.Bound {
out = append(out, summary)
}
}
return out
}
func buildUserProfileSourceContext(provider string) *userProfileSourceContext {
provider = strings.TrimSpace(provider)
if provider == "" {
return nil
}
return &userProfileSourceContext{
Provider: provider,
Source: provider,
}
}

View File

@@ -92,6 +92,12 @@ func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int6
func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *userHandlerRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
return map[int64]*time.Time{}, nil
}
func (s *userHandlerRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
return nil, nil
}
func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
@@ -230,6 +236,79 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/start")
}
func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
gin.SetMode(gin.TestMode)
verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 21,
Email: "legacy-profile@example.com",
Username: "linuxdo-handle",
Role: service.RoleUser,
Status: service.StatusActive,
AvatarURL: "https://cdn.example.com/linuxdo.png",
AvatarSource: "remote_url",
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-21",
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
},
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21})
handler.GetProfile(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data map[string]any `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, true, resp.Data["email_bound"])
require.Equal(t, true, resp.Data["linuxdo_bound"])
require.Equal(t, false, resp.Data["oidc_bound"])
require.Equal(t, false, resp.Data["wechat_bound"])
require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
require.True(t, ok)
linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, linuxdoBinding["bound"])
require.Equal(t, "linuxdo", linuxdoBinding["provider"])
identityBindings, ok := resp.Data["identity_bindings"].(map[string]any)
require.True(t, ok)
emailBinding, ok := identityBindings["email"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, emailBinding["bound"])
avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
require.True(t, ok)
require.Equal(t, "linuxdo", avatarSource["provider"])
profileSources, ok := resp.Data["profile_sources"].(map[string]any)
require.True(t, ok)
usernameSource, ok := profileSources["username"].(map[string]any)
require.True(t, ok)
require.Equal(t, "linuxdo", usernameSource["provider"])
}
func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
gin.SetMode(gin.TestMode)