diff --git a/backend/internal/handler/auth_current_user_test.go b/backend/internal/handler/auth_current_user_test.go new file mode 100644 index 00000000..dab95e29 --- /dev/null +++ b/backend/internal/handler/auth_current_user_test.go @@ -0,0 +1,78 @@ +//go:build unit + +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC) + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 31, + Email: "me@example.com", + Username: "linuxdo-handle", + Role: service.RoleUser, + Status: service.StatusActive, + AvatarURL: "https://cdn.example.com/linuxdo.png", + AvatarSource: "remote_url", + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-31", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + }, + }, + }, + } + + handler := &AuthHandler{ + userService: service.NewUserService(repo, nil, nil, nil), + } + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 31}) + + handler.GetCurrentUser(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, true, resp.Data["email_bound"]) + require.Equal(t, true, resp.Data["linuxdo_bound"]) + require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"]) + + authBindings, ok := resp.Data["auth_bindings"].(map[string]any) + require.True(t, ok) + linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, linuxdoBinding["bound"]) + + _, hasAvatarSource := resp.Data["avatar_source"] + require.False(t, hasAvatarSource) + _, hasProfileSources := resp.Data["profile_sources"] + require.False(t, hasProfileSources) +} diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index b984a436..76ca153d 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -348,8 +348,14 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) { return } + identities, err := h.userService.GetProfileIdentitySummaries(c.Request.Context(), subject.UserID, user) + if err != nil { + response.ErrorFrom(c, err) + return + } + type UserResponse struct { - *dto.User + userProfileResponse RunMode string `json:"run_mode"` } @@ -358,7 +364,10 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) { runMode = h.cfg.RunMode } - response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode}) + response.Success(c, UserResponse{ + userProfileResponse: userProfileResponseFromService(user, identities), + RunMode: runMode, + }) } // ValidatePromoCodeRequest 验证优惠码请求 diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 94186858..461810f1 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -848,6 +848,12 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision } } +func shouldSkipAvatarAdoption(err error) bool { + return errors.Is(err, service.ErrAvatarInvalid) || + errors.Is(err, service.ErrAvatarTooLarge) || + errors.Is(err, service.ErrAvatarNotImage) +} + func applyPendingOAuthBinding( ctx context.Context, client *dbent.Client, @@ -885,6 +891,14 @@ func applyPendingOAuthBinding( if decision != nil && decision.AdoptAvatar { adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") } + shouldAdoptAvatar := false + if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" { + if err := service.ValidateUserAvatar(adoptedAvatarURL); err == nil { + shouldAdoptAvatar = true + } else if !shouldSkipAvatarAdoption(err) { + return err + } + } tx, err := client.Tx(ctx) if err != nil { @@ -913,7 +927,7 @@ func applyPendingOAuthBinding( if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { metadata["display_name"] = adoptedDisplayName } - if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" { + if shouldAdoptAvatar { metadata["avatar_url"] = adoptedAvatarURL } @@ -939,7 +953,7 @@ func applyPendingOAuthBinding( } } - if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" && userService != nil { + if shouldAdoptAvatar && userService != nil { if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil { return err } diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 2521186e..d29e4b88 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -173,6 +173,78 @@ func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecisio require.NotNil(t, consumed.ConsumedAt) } +func TestExchangePendingOAuthCompletionSkipsInvalidAvatarAdoptionWithoutBlockingCompletion(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("invalid-avatar@example.com"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("pending-invalid-avatar-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("invalid-avatar-123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("browser-invalid-avatar-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "Alice Example", + "suggested_avatar_url": "/avatars/alice.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-invalid-avatar-key")}) + ginCtx.Request = req + + handler.ExchangePendingOAuthCompletion(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("invalid-avatar-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "Alice Example", identity.Metadata["display_name"]) + _, hasAdoptedAvatar := identity.Metadata["avatar_url"] + require.False(t, hasAdoptedAvatar) + + avatar := loadUserAvatarRecord(t, client, userEntity.ID) + require.Nil(t, avatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + func TestExchangePendingOAuthCompletionBindCurrentUserPreviewThenFinalizeBindsIdentityWithoutAdoption(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, false) ctx := context.Background() diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index b1ade5c0..843b0bd9 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -2,7 +2,6 @@ package handler import ( "context" - "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -353,22 +352,16 @@ func userProfileResponseFromService(user *service.User, identities service.UserI return userProfileResponse{} } bindings := userProfileBindingMap(identities) - profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities) return userProfileResponse{ - User: *base, - AvatarURL: user.AvatarURL, - AvatarSource: avatarSource, - UsernameSource: usernameSource, - DisplayNameSource: usernameSource, - NicknameSource: usernameSource, - ProfileSources: profileSources, - Identities: identities, - AuthBindings: bindings, - IdentityBindings: bindings, - EmailBound: identities.Email.Bound, - LinuxDoBound: identities.LinuxDo.Bound, - OIDCBound: identities.OIDC.Bound, - WeChatBound: identities.WeChat.Bound, + User: *base, + AvatarURL: user.AvatarURL, + Identities: identities, + AuthBindings: bindings, + IdentityBindings: bindings, + EmailBound: identities.Email.Bound, + LinuxDoBound: identities.LinuxDo.Bound, + OIDCBound: identities.OIDC.Bound, + WeChatBound: identities.WeChat.Bound, } } @@ -380,66 +373,3 @@ func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string "wechat": identities.WeChat, } } - -func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) ( - map[string]*userProfileSourceContext, - *userProfileSourceContext, - *userProfileSourceContext, -) { - if user == nil { - return nil, nil, nil - } - - thirdParty := thirdPartyIdentityProviders(identities) - var avatarSource *userProfileSourceContext - if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 { - avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider) - } - - usernameValue := strings.TrimSpace(user.Username) - var usernameSource *userProfileSourceContext - for _, summary := range thirdParty { - if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) { - usernameSource = buildUserProfileSourceContext(summary.Provider) - break - } - } - if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 { - usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider) - } - - profileSources := map[string]*userProfileSourceContext{} - if avatarSource != nil { - profileSources["avatar"] = avatarSource - } - if usernameSource != nil { - profileSources["username"] = usernameSource - profileSources["display_name"] = usernameSource - profileSources["nickname"] = usernameSource - } - if len(profileSources) == 0 { - return nil, avatarSource, usernameSource - } - return profileSources, avatarSource, usernameSource -} - -func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary { - out := make([]service.UserIdentitySummary, 0, 3) - for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} { - if summary.Bound { - out = append(out, summary) - } - } - return out -} - -func buildUserProfileSourceContext(provider string) *userProfileSourceContext { - provider = strings.TrimSpace(provider) - if provider == "" { - return nil - } - return &userProfileSourceContext{ - Provider: provider, - Source: provider, - } -} diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index 1216f9c4..7c6460e8 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -298,15 +298,10 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { require.True(t, ok) require.Equal(t, true, emailBinding["bound"]) - avatarSource, ok := resp.Data["avatar_source"].(map[string]any) - require.True(t, ok) - require.Equal(t, "linuxdo", avatarSource["provider"]) - - profileSources, ok := resp.Data["profile_sources"].(map[string]any) - require.True(t, ok) - usernameSource, ok := profileSources["username"].(map[string]any) - require.True(t, ok) - require.Equal(t, "linuxdo", usernameSource["provider"]) + _, hasAvatarSource := resp.Data["avatar_source"] + require.False(t, hasAvatarSource) + _, hasProfileSources := resp.Data["profile_sources"] + require.False(t, hasProfileSources) } func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index c106a3f5..cd1bc2bb 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -403,6 +403,11 @@ func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) { }, nil } +func ValidateUserAvatar(raw string) error { + _, err := normalizeUserAvatarInput(raw) + return err +} + func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) { body := strings.TrimPrefix(raw, "data:") meta, encoded, ok := strings.Cut(body, ",")