diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 964dbb88..fe48541b 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -389,6 +389,7 @@ func TestValidateOIDCScopesMustContainOpenID(t *testing.T) { cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback" cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback" cfg.OIDC.Scopes = "profile email" + cfg.OIDC.UsePKCE = true err = cfg.Validate() if err == nil { @@ -418,6 +419,7 @@ func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback" cfg.OIDC.Scopes = "openid email profile" cfg.OIDC.ValidateIDToken = true + cfg.OIDC.UsePKCE = true err = cfg.Validate() if err != nil { @@ -840,6 +842,7 @@ func TestValidateConfigWithLinuxDoEnabled(t *testing.T) { cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback" cfg.LinuxDo.TokenAuthMethod = "client_secret_post" + cfg.LinuxDo.UsePKCE = true if err := cfg.Validate(); err != nil { t.Fatalf("Validate() unexpected error: %v", err) @@ -990,6 +993,7 @@ func TestValidateConfigErrors(t *testing.T) { name: "linuxdo client id required", mutate: func(c *Config) { c.LinuxDo.Enabled = true + c.LinuxDo.UsePKCE = true c.LinuxDo.ClientID = "" }, wantErr: "linuxdo_connect.client_id", @@ -998,6 +1002,7 @@ func TestValidateConfigErrors(t *testing.T) { name: "linuxdo token auth method", mutate: func(c *Config) { c.LinuxDo.Enabled = true + c.LinuxDo.UsePKCE = true c.LinuxDo.ClientID = "client" c.LinuxDo.ClientSecret = "secret" c.LinuxDo.AuthorizeURL = "https://example.com/authorize" diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index 7938f3e7..0c760ee9 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -184,7 +184,7 @@ func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) { TokenAuthMethod: "client_secret_post", UsePKCE: true, }) - defer client.Close() + t.Cleanup(func() { _ = client.Close() }) user, err := client.User.Create(). SetEmail("bind-cookie@example.com"). @@ -226,7 +226,7 @@ func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) { require.Equal(t, -1, accessTokenCookie.MaxAge) } -func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testing.T) { +func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/token": @@ -254,7 +254,7 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testin TokenAuthMethod: "client_secret_post", UsePKCE: true, }) - defer client.Close() + t.Cleanup(func() { _ = client.Close() }) ctx := context.Background() existingUser, err := client.User.Create(). @@ -265,6 +265,14 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testin SetStatus(service.StatusActive). Save(ctx) require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("321"). + SetMetadata(map[string]any{"username": "legacy-user"}). + Save(ctx) + require.NoError(t, err) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -294,7 +302,8 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testin require.Equal(t, linuxDoSyntheticEmail("321"), session.ResolvedEmail) require.Equal(t, "LinuxDo Display", session.UpstreamIdentityClaims["suggested_display_name"]) - completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) require.Equal(t, "/dashboard", completion["redirect"]) require.NotEmpty(t, completion["access_token"]) require.Nil(t, completion["error"]) @@ -328,7 +337,7 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test TokenAuthMethod: "client_secret_post", UsePKCE: true, }) - defer client.Close() + t.Cleanup(func() { _ = client.Close() }) ctx := context.Background() existingUser, err := client.User.Create(). @@ -362,21 +371,24 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). Only(ctx) require.NoError(t, err) - require.Equal(t, "adopt_existing_user_by_email", session.Intent) - require.NotNil(t, session.TargetUserID) - require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, oauthIntentLogin, session.Intent) + require.Nil(t, session.TargetUserID) require.Equal(t, existingUser.Email, session.ResolvedEmail) require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) - completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) require.Equal(t, "/dashboard", completion["redirect"]) - require.Equal(t, "bind_login_required", completion["step"]) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) require.Equal(t, existingUser.Email, completion["email"]) + require.Equal(t, existingUser.Email, completion["existing_account_email"]) + require.Equal(t, true, completion["existing_account_bindable"]) + require.Equal(t, "compat_email_match", completion["choice_reason"]) _, hasAccessToken := completion["access_token"] require.False(t, hasAccessToken) } -func TestLinuxDoOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresInvite(t *testing.T) { +func TestLinuxDoOAuthCallbackCreatesChoicePendingSessionWhenSignupRequiresInvite(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/token": @@ -404,7 +416,7 @@ func TestLinuxDoOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresIn TokenAuthMethod: "client_secret_post", UsePKCE: true, }) - defer client.Close() + t.Cleanup(func() { _ = client.Close() }) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -432,9 +444,11 @@ func TestLinuxDoOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresIn require.Equal(t, oauthIntentLogin, session.Intent) require.Nil(t, session.TargetUserID) - completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) - require.Equal(t, "invitation_required", completion["error"]) + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) require.Equal(t, "/dashboard", completion["redirect"]) + require.Equal(t, "third_party_signup", completion["choice_reason"]) } func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) { @@ -465,7 +479,7 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing. TokenAuthMethod: "client_secret_post", UsePKCE: true, }) - defer client.Close() + t.Cleanup(func() { _ = client.Close() }) ctx := context.Background() currentUser, err := client.User.Create(). @@ -505,7 +519,8 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing. require.Equal(t, currentUser.ID, *session.TargetUserID) require.Equal(t, linuxDoSyntheticEmail("999"), session.ResolvedEmail) - completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) require.Equal(t, "/settings/connections", completion["redirect"]) require.Empty(t, completion["access_token"]) require.Equal(t, "Bind Display", session.UpstreamIdentityClaims["suggested_display_name"]) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index f6c826b7..7d7b50f4 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -298,19 +298,6 @@ func (r oauthAdoptionDecisionRequest) hasDecision() bool { return r.AdoptDisplayName != nil || r.AdoptAvatar != nil } -func (r oauthAdoptionDecisionRequest) toServiceInput(sessionID int64) service.PendingIdentityAdoptionDecisionInput { - input := service.PendingIdentityAdoptionDecisionInput{ - PendingAuthSessionID: sessionID, - } - if r.AdoptDisplayName != nil { - input.AdoptDisplayName = *r.AdoptDisplayName - } - if r.AdoptAvatar != nil { - input.AdoptAvatar = *r.AdoptAvatar - } - return input -} - func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) { var req oauthAdoptionDecisionRequest if c == nil || c.Request == nil || c.Request.Body == nil { @@ -325,24 +312,6 @@ func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionReq return req, nil } -func persistPendingOAuthAdoptionDecision( - c *gin.Context, - svc *service.AuthPendingIdentityService, - sessionID int64, - req oauthAdoptionDecisionRequest, -) error { - if !req.hasDecision() { - return nil - } - if svc == nil { - return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") - } - if _, err := svc.UpsertAdoptionDecision(c.Request.Context(), req.toServiceInput(sessionID)); err != nil { - return infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) - } - return nil -} - func cloneOAuthMetadata(values map[string]any) map[string]any { if len(values) == 0 { return map[string]any{} @@ -418,30 +387,6 @@ func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity servic return userEntity, nil } -func (h *AuthHandler) createOAuthEmailRequiredPendingSession( - c *gin.Context, - identity service.PendingAuthIdentityKey, - redirectTo string, - browserSessionKey string, - upstreamClaims map[string]any, -) error { - return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: oauthIntentLogin, - Identity: identity, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: upstreamClaims, - CompletionResponse: map[string]any{ - "redirect": strings.TrimSpace(redirectTo), - "step": oauthPendingChoiceStep, - "adoption_required": true, - "force_email_on_signup": true, - "email_binding_required": true, - "existing_account_bindable": true, - }, - }) -} - func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") } func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") } func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") } diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index dba26f7b..8940e37d 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -965,11 +965,11 @@ func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *tes require.NotNil(t, storedSession.ConsumedAt) } -func TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState(t *testing.T) { +func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *testing.T) { handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") ctx := context.Background() - existingUser, err := client.User.Create(). + _, err := client.User.Create(). SetEmail("owner@example.com"). SetUsername("owner-user"). SetPasswordHash("hash"). @@ -1011,18 +1011,19 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState var payload map[string]any require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) require.Equal(t, "pending_session", payload["auth_result"]) - require.Equal(t, "adopt_existing_user_by_email", payload["intent"]) + require.Equal(t, oauthIntentLogin, payload["intent"]) require.Equal(t, "oidc", payload["provider"]) require.Equal(t, "/dashboard", payload["redirect"]) require.Equal(t, true, payload["adoption_required"]) + require.Equal(t, oauthPendingChoiceStep, payload["step"]) + require.Equal(t, "owner@example.com", payload["email"]) require.Equal(t, "Existing OIDC User", payload["suggested_display_name"]) require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"]) storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) require.NoError(t, err) - require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent) - require.NotNil(t, storedSession.TargetUserID) - require.Equal(t, existingUser.ID, *storedSession.TargetUserID) + require.Equal(t, oauthIntentLogin, storedSession.Intent) + require.Nil(t, storedSession.TargetUserID) require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) require.Nil(t, storedSession.ConsumedAt) @@ -1041,7 +1042,7 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") ctx := context.Background() - existingUser, err := client.User.Create(). + _, err := client.User.Create(). SetEmail(" Owner@Example.com "). SetUsername("owner-user"). SetPasswordHash("hash"). @@ -1082,12 +1083,12 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te var payload map[string]any require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) - require.Equal(t, "adopt_existing_user_by_email", payload["intent"]) + require.Equal(t, oauthIntentLogin, payload["intent"]) + require.Equal(t, oauthPendingChoiceStep, payload["step"]) storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) require.NoError(t, err) - require.NotNil(t, storedSession.TargetUserID) - require.Equal(t, existingUser.ID, *storedSession.TargetUserID) + require.Nil(t, storedSession.TargetUserID) require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) } @@ -1095,7 +1096,7 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") ctx := context.Background() - existingUser, err := client.User.Create(). + _, err := client.User.Create(). SetEmail("owner@example.com"). SetUsername("owner-user"). SetPasswordHash("hash"). @@ -1137,14 +1138,13 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing var payload map[string]any require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) require.Equal(t, "pending_session", payload["auth_result"]) - require.Equal(t, "bind_login_required", payload["step"]) + require.Equal(t, oauthPendingChoiceStep, payload["step"]) require.Equal(t, "owner@example.com", payload["email"]) storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) require.NoError(t, err) - require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent) - require.NotNil(t, storedSession.TargetUserID) - require.Equal(t, existingUser.ID, *storedSession.TargetUserID) + require.Equal(t, oauthIntentLogin, storedSession.Intent) + require.Nil(t, storedSession.TargetUserID) require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) } @@ -1260,7 +1260,7 @@ func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T handler.CreateOIDCOAuthAccount(ginCtx) - require.Equal(t, http.StatusInternalServerError, recorder.Code) + require.Equal(t, http.StatusConflict, recorder.Code) userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx) require.NoError(t, err) @@ -2429,7 +2429,7 @@ func loadUserAvatarRecord(t *testing.T, client *dbent.Client, userID int64) *oau &rows, ) require.NoError(t, err) - defer rows.Close() + defer func() { _ = rows.Close() }() if !rows.Next() { require.NoError(t, rows.Err()) @@ -2459,7 +2459,7 @@ func countProviderGrantRecords( &rows, ) require.NoError(t, err) - defer rows.Close() + defer func() { _ = rows.Close() }() require.True(t, rows.Next()) var count int @@ -2587,7 +2587,7 @@ func (r *oauthPendingFlowUserRepo) GetUserAvatar(ctx context.Context, userID int ); err != nil { return nil, err } - defer rows.Close() + defer func() { _ = rows.Close() }() if !rows.Next() { return nil, rows.Err() diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index 5cd8e0ea..2acca18a 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -186,7 +186,7 @@ func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) { require.Equal(t, int64(84), userID) } -func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testing.T) { +func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) { cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ Subject: "oidc-subject-login", PreferredUsername: "oidc_login", @@ -198,7 +198,7 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testing.T defer cleanup() handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) - defer client.Close() + t.Cleanup(func() { _ = client.Close() }) ctx := context.Background() existingUser, err := client.User.Create(). @@ -209,6 +209,14 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testing.T SetStatus(service.StatusActive). Save(ctx) require.NoError(t, err) + _, err = client.AuthIdentity.Create(). + SetUserID(existingUser.ID). + SetProviderType("oidc"). + SetProviderKey(cfg.IssuerURL). + SetProviderSubject("oidc-subject-login"). + SetMetadata(map[string]any{"username": "legacy-user"}). + Save(ctx) + require.NoError(t, err) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -239,7 +247,8 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testing.T require.Equal(t, cfg.IssuerURL, session.ProviderKey) require.Equal(t, "OIDC Login Display", session.UpstreamIdentityClaims["suggested_display_name"]) - completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) require.Equal(t, "/dashboard", completion["redirect"]) require.NotEmpty(t, completion["access_token"]) require.Nil(t, completion["error"]) @@ -257,7 +266,7 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing defer cleanup() handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) - defer client.Close() + t.Cleanup(func() { _ = client.Close() }) ctx := context.Background() existingUser, err := client.User.Create(). @@ -292,16 +301,19 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). Only(ctx) require.NoError(t, err) - require.Equal(t, "adopt_existing_user_by_email", session.Intent) - require.NotNil(t, session.TargetUserID) - require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, oauthIntentLogin, session.Intent) + require.Nil(t, session.TargetUserID) require.Equal(t, existingUser.Email, session.ResolvedEmail) require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) - completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) require.Equal(t, "/dashboard", completion["redirect"]) - require.Equal(t, "bind_login_required", completion["step"]) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) require.Equal(t, existingUser.Email, completion["email"]) + require.Equal(t, existingUser.Email, completion["existing_account_email"]) + require.Equal(t, true, completion["existing_account_bindable"]) + require.Equal(t, "compat_email_match", completion["choice_reason"]) _, hasAccessToken := completion["access_token"] require.False(t, hasAccessToken) } @@ -319,10 +331,10 @@ func TestOIDCOAuthCallbackAllowsCompatEmailBindWhenUpstreamEmailIsUnverified(t * cfg.RequireEmailVerified = true handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) - defer client.Close() + t.Cleanup(func() { _ = client.Close() }) ctx := context.Background() - existingUser, err := client.User.Create(). + _, err := client.User.Create(). SetEmail("owner@example.com"). SetUsername("owner-user"). SetPasswordHash("hash"). @@ -345,28 +357,15 @@ func TestOIDCOAuthCallbackAllowsCompatEmailBindWhenUpstreamEmailIsUnverified(t * handler.OIDCOAuthCallback(c) require.Equal(t, http.StatusFound, recorder.Code) - require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location")) + require.Equal(t, "/auth/oidc/callback#error=email_not_verified&error_message=email+is+not+verified", recorder.Header().Get("Location")) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) - sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) - require.NotNil(t, sessionCookie) - - session, err := client.PendingAuthSession.Query(). - Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). - Only(ctx) + count, err := client.PendingAuthSession.Query().Count(ctx) require.NoError(t, err) - require.Equal(t, "adopt_existing_user_by_email", session.Intent) - require.NotNil(t, session.TargetUserID) - require.Equal(t, existingUser.ID, *session.TargetUserID) - require.Equal(t, existingUser.Email, session.ResolvedEmail) - require.Equal(t, "owner@example.com", session.UpstreamIdentityClaims["compat_email"]) - - completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) - require.Equal(t, "/settings/connections", completion["redirect"]) - require.Equal(t, "bind_login_required", completion["step"]) - require.Equal(t, existingUser.Email, completion["email"]) + require.Zero(t, count) } -func TestOIDCOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresInvite(t *testing.T) { +func TestOIDCOAuthCallbackCreatesChoicePendingSessionWhenSignupRequiresInvite(t *testing.T) { cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ Subject: "oidc-subject-invite", PreferredUsername: "oidc_invite", @@ -378,7 +377,7 @@ func TestOIDCOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresInvit defer cleanup() handler, client := newOIDCOAuthHandlerAndClient(t, true, cfg) - defer client.Close() + t.Cleanup(func() { _ = client.Close() }) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -407,9 +406,11 @@ func TestOIDCOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresInvit require.Equal(t, oauthIntentLogin, session.Intent) require.Nil(t, session.TargetUserID) - completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) - require.Equal(t, "invitation_required", completion["error"]) + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) require.Equal(t, "/dashboard", completion["redirect"]) + require.Equal(t, "third_party_signup", completion["choice_reason"]) } func TestOIDCOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) { @@ -424,7 +425,7 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) defer cleanup() handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg) - defer client.Close() + t.Cleanup(func() { _ = client.Close() }) ctx := context.Background() currentUser, err := client.User.Create(). @@ -466,7 +467,8 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) require.Equal(t, cfg.IssuerURL, session.ProviderKey) require.Equal(t, "OIDC Bind Display", session.UpstreamIdentityClaims["suggested_display_name"]) - completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.True(t, ok) require.Equal(t, "/settings/connections", completion["redirect"]) require.Empty(t, completion["access_token"]) diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index ad184c46..78f5d7c2 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -1129,7 +1129,7 @@ func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code st if err != nil { return nil, fmt.Errorf("request wechat access token: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { @@ -1177,7 +1177,7 @@ func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenRespons if err != nil { return nil, fmt.Errorf("request wechat userinfo: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 140851de..937daa6d 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -155,7 +155,7 @@ func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"]) } -func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) { +func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMode(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -195,13 +195,22 @@ func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) { handler.WeChatOAuthCallback(c) require.Equal(t, http.StatusFound, recorder.Code) - require.Contains(t, recorder.Header().Get("Location"), "#error=provider_error") - require.Contains(t, recorder.Header().Get("Location"), "error_message=wechat_missing_unionid") - require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location")) - count, err := client.PendingAuthSession.Query().Count(context.Background()) + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(context.Background()) require.NoError(t, err) - require.Zero(t, count) + require.Equal(t, oauthIntentLogin, session.Intent) + require.Equal(t, "openid-123", session.ProviderSubject) + require.Equal(t, wechatSyntheticEmail("openid-123"), session.ResolvedEmail) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, oauthPendingChoiceStep, completion["step"]) + require.Equal(t, "third_party_signup", completion["choice_reason"]) } func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) { @@ -669,7 +678,7 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing Where(pendingauthsession.SessionTokenEQ(sessionToken)). Only(ctx) require.NoError(t, err) - require.Equal(t, "invitation_required", pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["error"]) + require.Equal(t, oauthPendingChoiceStep, pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["step"]) body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`) completeRecorder := httptest.NewRecorder() diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index ee0a2c9a..fc6a3f9e 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -51,22 +51,22 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` - WeChatConnectEnabled bool `json:"wechat_connect_enabled"` - WeChatConnectAppID string `json:"wechat_connect_app_id"` - WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"` - WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"` - WeChatConnectOpenAppSecretConfigured bool `json:"wechat_connect_open_app_secret_configured"` - WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"` - WeChatConnectMPAppSecretConfigured bool `json:"wechat_connect_mp_app_secret_configured"` - WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"` - WeChatConnectMobileAppSecretConfigured bool `json:"wechat_connect_mobile_app_secret_configured"` - WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"` - WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"` - WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"` - WeChatConnectMode string `json:"wechat_connect_mode"` - WeChatConnectScopes string `json:"wechat_connect_scopes"` - WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"` - WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"` + WeChatConnectEnabled bool `json:"wechat_connect_enabled"` + WeChatConnectAppID string `json:"wechat_connect_app_id"` + WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"` + WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"` + WeChatConnectOpenAppSecretConfigured bool `json:"wechat_connect_open_app_secret_configured"` + WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"` + WeChatConnectMPAppSecretConfigured bool `json:"wechat_connect_mp_app_secret_configured"` + WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"` + WeChatConnectMobileAppSecretConfigured bool `json:"wechat_connect_mobile_app_secret_configured"` + WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"` + WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"` + WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"` + WeChatConnectMode string `json:"wechat_connect_mode"` + WeChatConnectScopes string `json:"wechat_connect_scopes"` + WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"` + WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"` OIDCConnectEnabled bool `json:"oidc_connect_enabled"` OIDCConnectProviderName string `json:"oidc_connect_provider_name"` diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go index a6448109..4a260295 100644 --- a/backend/internal/payment/provider/alipay.go +++ b/backend/internal/payment/provider/alipay.go @@ -33,9 +33,6 @@ var ( alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) { return client.TradePagePay(param) } - alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) { - return client.TradePreCreate(ctx, param) - } ) // Alipay implements payment.Provider and payment.CancelableProvider using the smartwalle/alipay SDK. @@ -138,7 +135,7 @@ func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePayment param.NotifyURL = notifyURL param.ReturnURL = returnURL - payURL, err := client.TradeWapPay(param) + payURL, err := alipayTradeWapPay(client, param) if err != nil { return nil, fmt.Errorf("alipay TradeWapPay: %w", err) } diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go index b25c05bd..8b3ff8ce 100644 --- a/backend/internal/payment/provider/alipay_test.go +++ b/backend/internal/payment/provider/alipay_test.go @@ -3,7 +3,6 @@ package provider import ( - "context" "errors" "net/url" "strings" @@ -136,34 +135,24 @@ func TestNewAlipay(t *testing.T) { } } -func TestCreateTradeUsesPreCreateForDesktop(t *testing.T) { - origPreCreate := alipayTradePreCreate +func TestCreateTradeUsesPagePayForDesktop(t *testing.T) { origPagePay := alipayTradePagePay origWapPay := alipayTradeWapPay t.Cleanup(func() { - alipayTradePreCreate = origPreCreate alipayTradePagePay = origPagePay alipayTradeWapPay = origWapPay }) - preCreateCalls := 0 pagePayCalls := 0 wapPayCalls := 0 - alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) { - preCreateCalls++ + alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) { + pagePayCalls++ if param.OutTradeNo != "sub2_100" { t.Fatalf("out_trade_no = %q, want %q", param.OutTradeNo, "sub2_100") } if param.NotifyURL != "https://merchant.example.com/api/v1/payment/webhook/alipay" { t.Fatalf("notify_url = %q", param.NotifyURL) } - return &alipay.TradePreCreateRsp{ - OutTradeNo: "sub2_100", - QRCode: "https://qr.alipay.example.com/precreate-token", - }, nil - } - alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) { - pagePayCalls++ return url.Parse("https://openapi.alipay.com/gateway.do?page-pay") } alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) { @@ -172,45 +161,31 @@ func TestCreateTradeUsesPreCreateForDesktop(t *testing.T) { } provider := &Alipay{} - resp, err := provider.createTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{ + resp, err := provider.createPagePayTrade(&alipay.Client{}, payment.CreatePaymentRequest{ OrderID: "sub2_100", Amount: "88.00", Subject: "Balance recharge", - }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result", false) + }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result") if err != nil { t.Fatalf("unexpected error: %v", err) } - if preCreateCalls != 1 { - t.Fatalf("precreate calls = %d, want 1", preCreateCalls) - } - if pagePayCalls != 0 { - t.Fatalf("page pay calls = %d, want 0", pagePayCalls) + if pagePayCalls != 1 { + t.Fatalf("page pay calls = %d, want 1", pagePayCalls) } if wapPayCalls != 0 { t.Fatalf("wap pay calls = %d, want 0", wapPayCalls) } - if resp.QRCode != "https://qr.alipay.example.com/precreate-token" { - t.Fatalf("qr_code = %q", resp.QRCode) - } - if resp.PayURL != "" { - t.Fatalf("pay_url = %q, want empty", resp.PayURL) + if resp.PayURL == "" { + t.Fatal("expected pay_url for desktop page pay") } } func TestCreateTradeUsesWapPayForMobile(t *testing.T) { - origPreCreate := alipayTradePreCreate origWapPay := alipayTradeWapPay t.Cleanup(func() { - alipayTradePreCreate = origPreCreate alipayTradeWapPay = origWapPay }) - preCreateCalls := 0 - alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) { - preCreateCalls++ - return &alipay.TradePreCreateRsp{}, nil - } - wapPayCalls := 0 alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) { wapPayCalls++ @@ -221,27 +196,21 @@ func TestCreateTradeUsesWapPayForMobile(t *testing.T) { } provider := &Alipay{} - resp, err := provider.createTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{ + resp, err := provider.createWapTrade(&alipay.Client{}, payment.CreatePaymentRequest{ OrderID: "sub2_101", Amount: "18.00", Subject: "Balance recharge", IsMobile: true, - }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result", true) + }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result") if err != nil { t.Fatalf("unexpected error: %v", err) } - if preCreateCalls != 0 { - t.Fatalf("precreate calls = %d, want 0", preCreateCalls) - } if wapPayCalls != 1 { t.Fatalf("wap pay calls = %d, want 1", wapPayCalls) } if resp.PayURL == "" { t.Fatal("expected pay_url for mobile wap pay") } - if resp.QRCode != "" { - t.Fatalf("qr_code = %q, want empty", resp.QRCode) - } } func TestAlipayMerchantIdentityMetadata(t *testing.T) { diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go index 6533f24c..2d812394 100644 --- a/backend/internal/repository/user_profile_identity_repo.go +++ b/backend/internal/repository/user_profile_identity_repo.go @@ -573,27 +573,6 @@ func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) err return err } -func (r *userRepository) attachUserAvatar(ctx context.Context, user *service.User) error { - if user == nil { - return nil - } - - avatar, err := r.GetUserAvatar(ctx, user.ID) - if err != nil { - return err - } - if avatar == nil { - return nil - } - - user.AvatarURL = avatar.URL - user.AvatarSource = avatar.StorageProvider - user.AvatarMIME = avatar.ContentType - user.AvatarByteSize = avatar.ByteSize - user.AvatarSHA256 = avatar.SHA256 - return nil -} - func copyMetadata(in map[string]any) map[string]any { if len(in) == 0 { return map[string]any{} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index f91fb393..ed7764cf 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -50,6 +50,7 @@ func TestAPIContracts(t *testing.T) { "data": { "id": 1, "email": "alice@example.com", + "email_bound": true, "username": "alice", "role": "user", "balance": 12.5, @@ -63,6 +64,120 @@ func TestAPIContracts(t *testing.T) { "balance_notify_threshold": null, "balance_notify_extra_emails": null, "total_recharged": 0, + "linuxdo_bound": false, + "oidc_bound": false, + "wechat_bound": false, + "identities": { + "email": { + "provider": "email", + "provider_key": "email", + "bound": true, + "bound_count": 1, + "can_bind": false, + "can_unbind": false, + "display_name": "alice@example.com", + "subject_hint": "a***e@example.com", + "note": "Primary account email is managed from the profile form." + }, + "linuxdo": { + "provider": "linuxdo", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "oidc": { + "provider": "oidc", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "wechat": { + "provider": "wechat", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + } + }, + "identity_bindings": { + "email": { + "provider": "email", + "provider_key": "email", + "bound": true, + "bound_count": 1, + "can_bind": false, + "can_unbind": false, + "display_name": "alice@example.com", + "subject_hint": "a***e@example.com", + "note": "Primary account email is managed from the profile form." + }, + "linuxdo": { + "provider": "linuxdo", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "oidc": { + "provider": "oidc", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "wechat": { + "provider": "wechat", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + } + }, + "auth_bindings": { + "email": { + "provider": "email", + "provider_key": "email", + "bound": true, + "bound_count": 1, + "can_bind": false, + "can_unbind": false, + "display_name": "alice@example.com", + "subject_hint": "a***e@example.com", + "note": "Primary account email is managed from the profile form." + }, + "linuxdo": { + "provider": "linuxdo", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "oidc": { + "provider": "oidc", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "wechat": { + "provider": "wechat", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + } + }, "run_mode": "standard" } }`, @@ -649,7 +764,23 @@ func TestAPIContracts(t *testing.T) { "account_quota_notify_enabled": false, "balance_low_notify_threshold": 0, "balance_low_notify_recharge_url": "", - "account_quota_notify_emails": [] + "account_quota_notify_emails": [], + "wechat_connect_enabled": false, + "wechat_connect_app_id": "", + "wechat_connect_app_secret_configured": false, + "wechat_connect_mode": "open", + "wechat_connect_open_enabled": false, + "wechat_connect_open_app_id": "", + "wechat_connect_open_app_secret_configured": false, + "wechat_connect_mp_enabled": false, + "wechat_connect_mp_app_id": "", + "wechat_connect_mp_app_secret_configured": false, + "wechat_connect_mobile_enabled": false, + "wechat_connect_mobile_app_id": "", + "wechat_connect_mobile_app_secret_configured": false, + "wechat_connect_redirect_url": "", + "wechat_connect_frontend_redirect_url": "/auth/wechat/callback", + "wechat_connect_scopes": "snsapi_login" } }`, }, @@ -938,7 +1069,7 @@ func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64 } func (r *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { - return nil, errors.New("not implemented") + return nil, nil } func (r *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error { @@ -953,6 +1084,10 @@ func (r *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64 return nil, nil } +func (r *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + return nil +} + func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { return errors.New("not implemented") } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index a485af16..110c9008 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -2,7 +2,6 @@ package service import ( "context" - "database/sql" "encoding/json" "errors" "fmt" @@ -20,8 +19,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/util/httputil" - - entsql "entgo.io/ent/dialect/sql" ) // AdminService interface defines admin management operations @@ -999,17 +996,6 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 return buildAdminBoundAuthIdentity(identity, channel), nil } -func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) { - if s == nil || s.entClient == nil { - return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready") - } - driver, ok := s.entClient.Driver().(*entsql.Driver) - if !ok || driver.DB() == nil { - return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready") - } - return driver.DB(), nil -} - func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput { if input == nil { return nil diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go index 4f3d5f53..7001ee18 100644 --- a/backend/internal/service/auth_pending_identity_service.go +++ b/backend/internal/service/auth_pending_identity_service.go @@ -11,8 +11,8 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" - dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" entsql "entgo.io/ent/dialect/sql" diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 4b9b8313..6d61894b 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -679,13 +679,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return tokenPair, user, nil } -func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { - if s.settingService == nil { - return - } - s.assignSubscriptions(ctx, userID, s.settingService.GetDefaultSubscriptions(ctx), "auto assigned by default user subscriptions setting") -} - func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) { if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { return @@ -863,7 +856,7 @@ func (s *AuthService) hasProviderGrantRecord( if err != nil { return false, err } - defer rows.Close() + defer func() { _ = rows.Close() }() return rows.Next(), rows.Err() } @@ -917,7 +910,7 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, s DoNothing(). Exec(ctx); err != nil { if isSQLNoRowsError(err) { - err = nil + return nil, false } } if err != nil { diff --git a/backend/internal/service/payment_config_limits.go b/backend/internal/service/payment_config_limits.go index 67979d5b..57a4108f 100644 --- a/backend/internal/service/payment_config_limits.go +++ b/backend/internal/service/payment_config_limits.go @@ -53,40 +53,6 @@ func pcApplyEnabledVisibleMethodInstances(typeInstances map[string][]*dbent.Paym return filtered } -func pcApplyVisibleMethodRouting(typeInstances map[string][]*dbent.PaymentProviderInstance, vals map[string]string, available map[string]bool) map[string][]*dbent.PaymentProviderInstance { - if len(typeInstances) == 0 { - return typeInstances - } - - filtered := make(map[string][]*dbent.PaymentProviderInstance, len(typeInstances)) - for paymentType, instances := range typeInstances { - visibleMethod := NormalizeVisibleMethod(paymentType) - switch visibleMethod { - case payment.TypeAlipay, payment.TypeWxpay: - if !visibleMethodShouldBeExposed(visibleMethod, vals, available) { - continue - } - targetProviderKey, ok := VisibleMethodProviderKeyForSource(visibleMethod, vals[visibleMethodSourceSettingKey(visibleMethod)]) - if !ok { - continue - } - matching := make([]*dbent.PaymentProviderInstance, 0, len(instances)) - for _, inst := range instances { - if inst.ProviderKey == targetProviderKey { - matching = append(matching, inst) - } - } - if len(matching) == 0 { - continue - } - filtered[paymentType] = matching - default: - filtered[paymentType] = instances - } - } - return filtered -} - // GetMethodLimits returns per-payment-type limits from enabled provider instances. func (s *PaymentConfigService) GetMethodLimits(ctx context.Context, types []string) ([]MethodLimits, error) { instances, err := s.entClient.PaymentProviderInstance.Query(). diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go index b8c1b15b..2c0f8206 100644 --- a/backend/internal/service/payment_config_providers_test.go +++ b/backend/internal/service/payment_config_providers_test.go @@ -210,9 +210,15 @@ func TestCreateProviderInstanceRejectsConflictingVisibleMethodEnablement(t *test } _, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ - ProviderKey: "easypay", - Name: "EasyPay Alipay", - Config: map[string]string{"pid": "1001"}, + ProviderKey: "easypay", + Name: "EasyPay Alipay", + Config: map[string]string{ + "pid": "1001", + "pkey": "pkey-1001", + "apiBase": "https://pay.example.com", + "notifyUrl": "https://merchant.example.com/notify", + "returnUrl": "https://merchant.example.com/return", + }, SupportedTypes: []string{"alipay"}, Enabled: true, }) @@ -240,9 +246,15 @@ func TestUpdateProviderInstanceRejectsEnablingConflictingVisibleMethodProvider(t } existing, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ - ProviderKey: "easypay", - Name: "EasyPay WeChat", - Config: map[string]string{"pid": "2001"}, + ProviderKey: "easypay", + Name: "EasyPay WeChat", + Config: map[string]string{ + "pid": "2001", + "pkey": "pkey-2001", + "apiBase": "https://pay.example.com", + "notifyUrl": "https://merchant.example.com/notify", + "returnUrl": "https://merchant.example.com/return", + }, SupportedTypes: []string{"wxpay"}, Enabled: true, }) @@ -276,9 +288,15 @@ func TestUpdateProviderInstancePersistsEnabledAndSupportedTypes(t *testing.T) { } instance, err := svc.CreateProviderInstance(ctx, CreateProviderInstanceRequest{ - ProviderKey: "easypay", - Name: "EasyPay", - Config: map[string]string{"pid": "3001"}, + ProviderKey: "easypay", + Name: "EasyPay", + Config: map[string]string{ + "pid": "3001", + "pkey": "pkey-3001", + "apiBase": "https://pay.example.com", + "notifyUrl": "https://merchant.example.com/notify", + "returnUrl": "https://merchant.example.com/return", + }, SupportedTypes: []string{"alipay"}, Enabled: false, }) diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go index 1538ecbf..6e8acccb 100644 --- a/backend/internal/service/payment_resume_service.go +++ b/backend/internal/service/payment_resume_service.go @@ -23,8 +23,6 @@ const ( PaymentSourceHostedRedirect = "hosted_redirect" PaymentSourceWechatInAppResume = "wechat_in_app_resume" - paymentResumeFallbackSigningKey = "sub2api-payment-resume" - SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source" SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source" SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled" diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go index 275b4a94..78b6bba3 100644 --- a/backend/internal/service/payment_resume_service_test.go +++ b/backend/internal/service/payment_resume_service_test.go @@ -413,7 +413,7 @@ func mustCreateFallbackSignedToken(t *testing.T, claims any) string { t.Fatalf("marshal claims: %v", err) } encodedPayload := base64.RawURLEncoding.EncodeToString(payload) - mac := hmac.New(sha256.New, []byte(paymentResumeFallbackSigningKey)) + mac := hmac.New(sha256.New, []byte("sub2api-payment-resume")) _, _ = mac.Write([]byte(encodedPayload)) signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) return encodedPayload + "." + signature diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go index f12cf691..0f3efa1f 100644 --- a/backend/internal/service/payment_webhook_provider_test.go +++ b/backend/internal/service/payment_webhook_provider_test.go @@ -4,7 +4,11 @@ package service import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" "encoding/json" + "encoding/pem" "strconv" "testing" "time" @@ -52,6 +56,28 @@ func newWebhookProviderTestLoadBalancer(client *dbent.Client) payment.LoadBalanc return payment.NewDefaultLoadBalancer(client, []byte(webhookProviderTestEncryptionKey)) } +func encryptValidWebhookWxpayConfig(t *testing.T, suffix string) string { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + privDER, err := x509.MarshalPKCS8PrivateKey(key) + require.NoError(t, err) + pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey) + require.NoError(t, err) + + return encryptWebhookProviderConfig(t, map[string]string{ + "appId": "wx-app-" + suffix, + "mchId": "mch-" + suffix, + "privateKey": string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})), + "apiV3Key": webhookProviderTestEncryptionKey, + "publicKey": string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER})), + "publicKeyId": "public-key-id-" + suffix, + "certSerial": "cert-serial-" + suffix, + }) +} + func TestGetOrderProviderInstanceResolvesUniqueLegacyProviderKey(t *testing.T) { ctx := context.Background() client := newPaymentConfigServiceTestClient(t) @@ -275,24 +301,8 @@ func TestGetOrderProviderInstanceRejectsMissingSnapshotInstanceWithoutLegacyFall func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { ctx := context.Background() client := newPaymentConfigServiceTestClient(t) - wxpayConfigA := encryptWebhookProviderConfig(t, map[string]string{ - "appId": "wx-app-a", - "mchId": "mch-a", - "privateKey": "private-key-a", - "apiV3Key": webhookProviderTestEncryptionKey, - "publicKey": "public-key-a", - "publicKeyId": "public-key-id-a", - "certSerial": "cert-serial-a", - }) - wxpayConfigB := encryptWebhookProviderConfig(t, map[string]string{ - "appId": "wx-app-b", - "mchId": "mch-b", - "privateKey": "private-key-b", - "apiV3Key": webhookProviderTestEncryptionKey, - "publicKey": "public-key-b", - "publicKeyId": "public-key-id-b", - "certSerial": "cert-serial-b", - }) + wxpayConfigA := encryptValidWebhookWxpayConfig(t, "a") + wxpayConfigB := encryptValidWebhookWxpayConfig(t, "b") _, err := client.PaymentProviderInstance.Create(). SetProviderKey(payment.TypeWxpay). SetName("wxpay-a"). @@ -442,24 +452,8 @@ func TestGetWebhookProviderUsesProviderSnapshotBeforeWxpayFallback(t *testing.T) Save(ctx) require.NoError(t, err) - wxpayConfigA := encryptWebhookProviderConfig(t, map[string]string{ - "appId": "wx-app-snapshot-a", - "mchId": "mch-snapshot-a", - "privateKey": "private-key-snapshot-a", - "apiV3Key": webhookProviderTestEncryptionKey, - "publicKey": "public-key-snapshot-a", - "publicKeyId": "public-key-id-snapshot-a", - "certSerial": "cert-serial-snapshot-a", - }) - wxpayConfigB := encryptWebhookProviderConfig(t, map[string]string{ - "appId": "wx-app-snapshot-b", - "mchId": "mch-snapshot-b", - "privateKey": "private-key-snapshot-b", - "apiV3Key": webhookProviderTestEncryptionKey, - "publicKey": "public-key-snapshot-b", - "publicKeyId": "public-key-id-snapshot-b", - "certSerial": "cert-serial-snapshot-b", - }) + wxpayConfigA := encryptValidWebhookWxpayConfig(t, "snapshot-a") + wxpayConfigB := encryptValidWebhookWxpayConfig(t, "snapshot-b") instA, err := client.PaymentProviderInstance.Create(). SetProviderKey(payment.TypeWxpay). SetName("wxpay-snapshot-a"). diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index d0a57311..bc444af5 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -183,10 +183,6 @@ type UpsertUserAvatarInput struct { SHA256 string } -type userAuthIdentityReader interface { - ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) -} - type userProfileIdentityTxRunner interface { WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error } @@ -812,17 +808,6 @@ func maskOpaqueIdentity(value string) string { } } -func cloneAnyMap(values map[string]any) map[string]any { - if len(values) == 0 { - return map[string]any{} - } - cloned := make(map[string]any, len(values)) - for key, value := range values { - cloned[key] = value - } - return cloned -} - // ChangePassword 修改密码 // Security: Increments TokenVersion to invalidate all existing JWT tokens func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {