Fix profile refresh identity compatibility

This commit is contained in:
IanShaw027
2026-04-21 00:42:55 +08:00
parent 030da8c2f6
commit e4fe9fae2a
7 changed files with 195 additions and 92 deletions

View File

@@ -0,0 +1,78 @@
//go:build unit
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T) {
gin.SetMode(gin.TestMode)
verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 31,
Email: "me@example.com",
Username: "linuxdo-handle",
Role: service.RoleUser,
Status: service.StatusActive,
AvatarURL: "https://cdn.example.com/linuxdo.png",
AvatarSource: "remote_url",
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-31",
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
},
},
},
}
handler := &AuthHandler{
userService: service.NewUserService(repo, nil, nil, nil),
}
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 31})
handler.GetCurrentUser(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data map[string]any `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, true, resp.Data["email_bound"])
require.Equal(t, true, resp.Data["linuxdo_bound"])
require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
require.True(t, ok)
linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, linuxdoBinding["bound"])
_, hasAvatarSource := resp.Data["avatar_source"]
require.False(t, hasAvatarSource)
_, hasProfileSources := resp.Data["profile_sources"]
require.False(t, hasProfileSources)
}

View File

@@ -348,8 +348,14 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
return
}
identities, err := h.userService.GetProfileIdentitySummaries(c.Request.Context(), subject.UserID, user)
if err != nil {
response.ErrorFrom(c, err)
return
}
type UserResponse struct {
*dto.User
userProfileResponse
RunMode string `json:"run_mode"`
}
@@ -358,7 +364,10 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
runMode = h.cfg.RunMode
}
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
response.Success(c, UserResponse{
userProfileResponse: userProfileResponseFromService(user, identities),
RunMode: runMode,
})
}
// ValidatePromoCodeRequest 验证优惠码请求

View File

@@ -848,6 +848,12 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision
}
}
func shouldSkipAvatarAdoption(err error) bool {
return errors.Is(err, service.ErrAvatarInvalid) ||
errors.Is(err, service.ErrAvatarTooLarge) ||
errors.Is(err, service.ErrAvatarNotImage)
}
func applyPendingOAuthBinding(
ctx context.Context,
client *dbent.Client,
@@ -885,6 +891,14 @@ func applyPendingOAuthBinding(
if decision != nil && decision.AdoptAvatar {
adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
}
shouldAdoptAvatar := false
if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
if err := service.ValidateUserAvatar(adoptedAvatarURL); err == nil {
shouldAdoptAvatar = true
} else if !shouldSkipAvatarAdoption(err) {
return err
}
}
tx, err := client.Tx(ctx)
if err != nil {
@@ -913,7 +927,7 @@ func applyPendingOAuthBinding(
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
metadata["display_name"] = adoptedDisplayName
}
if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
if shouldAdoptAvatar {
metadata["avatar_url"] = adoptedAvatarURL
}
@@ -939,7 +953,7 @@ func applyPendingOAuthBinding(
}
}
if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" && userService != nil {
if shouldAdoptAvatar && userService != nil {
if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil {
return err
}

View File

@@ -173,6 +173,78 @@ func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecisio
require.NotNil(t, consumed.ConsumedAt)
}
func TestExchangePendingOAuthCompletionSkipsInvalidAvatarAdoptionWithoutBlockingCompletion(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
userEntity, err := client.User.Create().
SetEmail("invalid-avatar@example.com").
SetUsername("legacy-name").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("pending-invalid-avatar-token").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("invalid-avatar-123").
SetTargetUserID(userEntity.ID).
SetResolvedEmail(userEntity.Email).
SetBrowserSessionKey("browser-invalid-avatar-key").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
"suggested_display_name": "Alice Example",
"suggested_avatar_url": "/avatars/alice.png",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"access_token": "access-token",
"redirect": "/dashboard",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-invalid-avatar-key")})
ginCtx.Request = req
handler.ExchangePendingOAuthCompletion(ginCtx)
require.Equal(t, http.StatusOK, recorder.Code)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("linuxdo"),
authidentity.ProviderKeyEQ("linuxdo"),
authidentity.ProviderSubjectEQ("invalid-avatar-123"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "Alice Example", identity.Metadata["display_name"])
_, hasAdoptedAvatar := identity.Metadata["avatar_url"]
require.False(t, hasAdoptedAvatar)
avatar := loadUserAvatarRecord(t, client, userEntity.ID)
require.Nil(t, avatar)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
}
func TestExchangePendingOAuthCompletionBindCurrentUserPreviewThenFinalizeBindsIdentityWithoutAdoption(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()

View File

@@ -2,7 +2,6 @@ package handler
import (
"context"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -353,22 +352,16 @@ func userProfileResponseFromService(user *service.User, identities service.UserI
return userProfileResponse{}
}
bindings := userProfileBindingMap(identities)
profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities)
return userProfileResponse{
User: *base,
AvatarURL: user.AvatarURL,
AvatarSource: avatarSource,
UsernameSource: usernameSource,
DisplayNameSource: usernameSource,
NicknameSource: usernameSource,
ProfileSources: profileSources,
Identities: identities,
AuthBindings: bindings,
IdentityBindings: bindings,
EmailBound: identities.Email.Bound,
LinuxDoBound: identities.LinuxDo.Bound,
OIDCBound: identities.OIDC.Bound,
WeChatBound: identities.WeChat.Bound,
User: *base,
AvatarURL: user.AvatarURL,
Identities: identities,
AuthBindings: bindings,
IdentityBindings: bindings,
EmailBound: identities.Email.Bound,
LinuxDoBound: identities.LinuxDo.Bound,
OIDCBound: identities.OIDC.Bound,
WeChatBound: identities.WeChat.Bound,
}
}
@@ -380,66 +373,3 @@ func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string
"wechat": identities.WeChat,
}
}
func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) (
map[string]*userProfileSourceContext,
*userProfileSourceContext,
*userProfileSourceContext,
) {
if user == nil {
return nil, nil, nil
}
thirdParty := thirdPartyIdentityProviders(identities)
var avatarSource *userProfileSourceContext
if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 {
avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider)
}
usernameValue := strings.TrimSpace(user.Username)
var usernameSource *userProfileSourceContext
for _, summary := range thirdParty {
if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) {
usernameSource = buildUserProfileSourceContext(summary.Provider)
break
}
}
if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 {
usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider)
}
profileSources := map[string]*userProfileSourceContext{}
if avatarSource != nil {
profileSources["avatar"] = avatarSource
}
if usernameSource != nil {
profileSources["username"] = usernameSource
profileSources["display_name"] = usernameSource
profileSources["nickname"] = usernameSource
}
if len(profileSources) == 0 {
return nil, avatarSource, usernameSource
}
return profileSources, avatarSource, usernameSource
}
func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
out := make([]service.UserIdentitySummary, 0, 3)
for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
if summary.Bound {
out = append(out, summary)
}
}
return out
}
func buildUserProfileSourceContext(provider string) *userProfileSourceContext {
provider = strings.TrimSpace(provider)
if provider == "" {
return nil
}
return &userProfileSourceContext{
Provider: provider,
Source: provider,
}
}

View File

@@ -298,15 +298,10 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
require.True(t, ok)
require.Equal(t, true, emailBinding["bound"])
avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
require.True(t, ok)
require.Equal(t, "linuxdo", avatarSource["provider"])
profileSources, ok := resp.Data["profile_sources"].(map[string]any)
require.True(t, ok)
usernameSource, ok := profileSources["username"].(map[string]any)
require.True(t, ok)
require.Equal(t, "linuxdo", usernameSource["provider"])
_, hasAvatarSource := resp.Data["avatar_source"]
require.False(t, hasAvatarSource)
_, hasProfileSources := resp.Data["profile_sources"]
require.False(t, hasProfileSources)
}
func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {