//go:build unit package handler import ( "bytes" "context" "database/sql" "net/http" "net/http/httptest" "net/url" "strings" "testing" "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" "entgo.io/ent/dialect" entsql "entgo.io/ent/dialect/sql" _ "modernc.org/sqlite" ) func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) { gin.SetMode(gin.TestMode) handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{ service.SettingKeyWeChatConnectEnabled: "true", service.SettingKeyWeChatConnectAppID: "wx-open-app", service.SettingKeyWeChatConnectAppSecret: "wx-open-secret", service.SettingKeyWeChatConnectMode: "open", service.SettingKeyWeChatConnectScopes: "snsapi_login", service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", }) defer client.Close() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil) c.Request.Host = "api.example.com" handler.WeChatOAuthStart(c) require.Equal(t, http.StatusFound, recorder.Code) location := recorder.Header().Get("Location") require.NotEmpty(t, location) require.Contains(t, location, "open.weixin.qq.com") require.Contains(t, location, "appid=wx-open-app") require.Contains(t, location, "scope=snsapi_login") cookies := recorder.Result().Cookies() require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName)) require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName)) require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName)) require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName)) } func TestWeChatOAuthStart_AllowsOpenModeWhenBothCapabilitiesEnabled(t *testing.T) { gin.SetMode(gin.TestMode) handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{ service.SettingKeyWeChatConnectEnabled: "true", service.SettingKeyWeChatConnectAppID: "wx-shared-app", service.SettingKeyWeChatConnectAppSecret: "wx-shared-secret", service.SettingKeyWeChatConnectMode: "mp", service.SettingKeyWeChatConnectScopes: "snsapi_base", service.SettingKeyWeChatConnectOpenEnabled: "true", service.SettingKeyWeChatConnectMPEnabled: "true", service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", }) defer client.Close() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil) c.Request.Host = "api.example.com" handler.WeChatOAuthStart(c) require.Equal(t, http.StatusFound, recorder.Code) location := recorder.Header().Get("Location") require.NotEmpty(t, location) require.Contains(t, location, "open.weixin.qq.com") require.Contains(t, location, "connect/qrconnect") require.Contains(t, location, "scope=snsapi_login") } func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { wechatOAuthAccessTokenURL = originalAccessTokenURL wechatOAuthUserInfoURL = originalUserInfoURL }) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) case strings.Contains(r.URL.Path, "/sns/userinfo"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`)) default: http.NotFound(w, r) } })) defer upstream.Close() wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" handler, client := newWeChatOAuthTestHandler(t, false) defer client.Close() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) req.Host = "api.example.com" req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) c.Request = req handler.WeChatOAuthCallback(c) require.Equal(t, http.StatusFound, recorder.Code) require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) require.NotNil(t, sessionCookie) ctx := context.Background() session, err := client.PendingAuthSession.Query(). Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). Only(ctx) require.NoError(t, err) require.Equal(t, "wechat", session.ProviderType) require.Equal(t, "wechat-main", session.ProviderKey) require.Equal(t, "union-456", session.ProviderSubject) require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail) require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"]) require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"]) require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"]) require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"]) } func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMode(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { wechatOAuthAccessTokenURL = originalAccessTokenURL wechatOAuthUserInfoURL = originalUserInfoURL }) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_login"}`)) case strings.Contains(r.URL.Path, "/sns/userinfo"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"openid":"openid-123","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`)) default: http.NotFound(w, r) } })) defer upstream.Close() wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback")) defer client.Close() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) req.Host = "api.example.com" req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) c.Request = req handler.WeChatOAuthCallback(c) require.Equal(t, http.StatusFound, recorder.Code) require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location")) sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) require.NotNil(t, sessionCookie) session, err := client.PendingAuthSession.Query(). Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). Only(context.Background()) require.NoError(t, err) require.Equal(t, oauthIntentLogin, session.Intent) require.Equal(t, "openid-123", session.ProviderSubject) require.Equal(t, wechatSyntheticEmail("openid-123"), session.ResolvedEmail) completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) require.Equal(t, oauthPendingChoiceStep, completion["step"]) require.Equal(t, "third_party_signup", completion["choice_reason"]) } func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { wechatOAuthAccessTokenURL = originalAccessTokenURL wechatOAuthUserInfoURL = originalUserInfoURL }) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) case strings.Contains(r.URL.Path, "/sns/userinfo"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat-login.png"}`)) default: http.NotFound(w, r) } })) defer upstream.Close() wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback")) defer client.Close() ctx := context.Background() existingUser, err := client.User.Create(). SetEmail(wechatSyntheticEmail("union-456")). SetUsername("wechat-existing-user"). SetPasswordHash("hash"). SetRole(service.RoleUser). SetStatus(service.StatusActive). Save(ctx) require.NoError(t, err) _, err = client.AuthIdentity.Create(). SetUserID(existingUser.ID). SetProviderType("wechat"). SetProviderKey(wechatOAuthProviderKey). SetProviderSubject("union-456"). SetMetadata(map[string]any{"username": "wechat-existing-user"}). Save(ctx) require.NoError(t, err) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) req.Host = "api.example.com" req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) c.Request = req handler.WeChatOAuthCallback(c) require.Equal(t, http.StatusFound, recorder.Code) require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location")) sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) require.NotNil(t, sessionCookie) session, err := client.PendingAuthSession.Query(). Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). Only(ctx) require.NoError(t, err) require.Equal(t, oauthIntentLogin, session.Intent) require.NotNil(t, session.TargetUserID) require.Equal(t, existingUser.ID, *session.TargetUserID) require.Equal(t, existingUser.Email, session.ResolvedEmail) completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) require.Equal(t, "/dashboard", completion["redirect"]) _, hasAccessToken := completion["access_token"] require.False(t, hasAccessToken) _, hasRefreshToken := completion["refresh_token"] require.False(t, hasRefreshToken) } func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL t.Cleanup(func() { wechatOAuthAccessTokenURL = originalAccessTokenURL }) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") { w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_base"}`)) return } http.NotFound(w, r) })) defer upstream.Close() wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback")) defer client.Close() handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-123", nil) req.Host = "api.example.com" req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-123")) req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat")) req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"12.5","order_type":"subscription","plan_id":7}`)) req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base")) c.Request = req handler.WeChatPaymentOAuthCallback(c) require.Equal(t, http.StatusFound, recorder.Code) location := recorder.Header().Get("Location") parsed, err := url.Parse(location) require.NoError(t, err) fragment, err := url.ParseQuery(parsed.Fragment) require.NoError(t, err) require.Equal(t, "/purchase?from=wechat", fragment.Get("redirect")) require.NotEmpty(t, fragment.Get("wechat_resume_token")) require.Empty(t, fragment.Get("openid")) require.Empty(t, fragment.Get("payment_type")) require.Empty(t, fragment.Get("amount")) require.Empty(t, fragment.Get("order_type")) require.Empty(t, fragment.Get("plan_id")) claims, err := handler.wechatPaymentResumeService().ParseWeChatPaymentResumeToken(fragment.Get("wechat_resume_token")) require.NoError(t, err) require.Equal(t, "openid-123", claims.OpenID) require.Equal(t, payment.TypeWxpay, claims.PaymentType) require.Equal(t, "12.5", claims.Amount) require.Equal(t, payment.OrderTypeSubscription, claims.OrderType) require.EqualValues(t, 7, claims.PlanID) require.Equal(t, "/purchase?from=wechat", claims.RedirectTo) } func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) { testCases := []struct { name string mode string appID string appSecret string openID string }{ { name: "open", mode: "open", appID: "wx-open-app", appSecret: "wx-open-secret", openID: "openid-open-123", }, { name: "mp", mode: "mp", appID: "wx-mp-app", appSecret: "wx-mp-secret", openID: "openid-mp-123", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { wechatOAuthAccessTokenURL = originalAccessTokenURL wechatOAuthUserInfoURL = originalUserInfoURL }) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"` + tc.openID + `","unionid":"union-456","scope":"snsapi_login"}`)) case strings.Contains(r.URL.Path, "/sns/userinfo"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"openid":"` + tc.openID + `","unionid":"union-456","nickname":"Bind Nick","headimgurl":"https://cdn.example/bind.png"}`)) default: http.NotFound(w, r) } })) defer upstream.Close() wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings(tc.mode, tc.appID, tc.appSecret, "/auth/wechat/callback")) defer client.Close() currentUser, err := client.User.Create(). SetEmail("current@example.com"). SetUsername("current-user"). SetPasswordHash("hash"). SetRole(service.RoleUser). SetStatus(service.StatusActive). Save(context.Background()) require.NoError(t, err) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) req.Host = "api.example.com" req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) req.AddCookie(encodedCookie(wechatOAuthModeCookieName, tc.mode)) req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) c.Request = req handler.WeChatOAuthCallback(c) require.Equal(t, http.StatusFound, recorder.Code) require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) require.NotNil(t, sessionCookie) session, err := client.PendingAuthSession.Query(). Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). Only(context.Background()) require.NoError(t, err) require.Equal(t, wechatOAuthIntentBind, session.Intent) require.NotNil(t, session.TargetUserID) require.Equal(t, currentUser.ID, *session.TargetUserID) require.Equal(t, currentUser.Email, session.ResolvedEmail) require.Equal(t, "union-456", session.ProviderSubject) require.Equal(t, "union-456", session.UpstreamIdentityClaims["subject"]) require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"]) require.Equal(t, tc.openID, session.UpstreamIdentityClaims["openid"]) require.Equal(t, tc.mode, session.UpstreamIdentityClaims["channel"]) require.Equal(t, tc.appID, session.UpstreamIdentityClaims["channel_app_id"]) require.Equal(t, tc.openID, session.UpstreamIdentityClaims["channel_subject"]) completionResponse := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) require.Equal(t, "/dashboard", completionResponse["redirect"]) _, hasAccessToken := completionResponse["access_token"] require.False(t, hasAccessToken) }) } } func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { wechatOAuthAccessTokenURL = originalAccessTokenURL wechatOAuthUserInfoURL = originalUserInfoURL }) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) case strings.Contains(r.URL.Path, "/sns/userinfo"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) default: http.NotFound(w, r) } })) defer upstream.Close() wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" handler, client := newWeChatOAuthTestHandler(t, false) defer client.Close() ctx := context.Background() owner, err := client.User.Create(). SetEmail("owner@example.com"). SetUsername("owner"). SetPasswordHash("hash"). SetRole(service.RoleUser). SetStatus(service.StatusActive). Save(ctx) require.NoError(t, err) currentUser, err := client.User.Create(). SetEmail("current@example.com"). SetUsername("current"). SetPasswordHash("hash"). SetRole(service.RoleUser). SetStatus(service.StatusActive). Save(ctx) require.NoError(t, err) _, err = client.AuthIdentity.Create(). SetUserID(owner.ID). SetProviderType("wechat"). SetProviderKey(wechatOAuthProviderKey). SetProviderSubject("union-456"). SetMetadata(map[string]any{"unionid": "union-456"}). Save(ctx) require.NoError(t, err) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) req.Host = "api.example.com" req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) c.Request = req handler.WeChatOAuthCallback(c) require.Equal(t, http.StatusFound, recorder.Code) require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT") count, err := client.PendingAuthSession.Query().Count(ctx) require.NoError(t, err) require.Zero(t, count) } func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { wechatOAuthAccessTokenURL = originalAccessTokenURL wechatOAuthUserInfoURL = originalUserInfoURL }) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) case strings.Contains(r.URL.Path, "/sns/userinfo"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) default: http.NotFound(w, r) } })) defer upstream.Close() wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" handler, client := newWeChatOAuthTestHandler(t, false) defer client.Close() ctx := context.Background() owner, err := client.User.Create(). SetEmail("owner@example.com"). SetUsername("owner"). SetPasswordHash("hash"). SetRole(service.RoleUser). SetStatus(service.StatusActive). Save(ctx) require.NoError(t, err) currentUser, err := client.User.Create(). SetEmail("current@example.com"). SetUsername("current"). SetPasswordHash("hash"). SetRole(service.RoleUser). SetStatus(service.StatusActive). Save(ctx) require.NoError(t, err) ownerIdentity, err := client.AuthIdentity.Create(). SetUserID(owner.ID). SetProviderType("wechat"). SetProviderKey(wechatOAuthProviderKey). SetProviderSubject("union-owner"). SetMetadata(map[string]any{"unionid": "union-owner"}). Save(ctx) require.NoError(t, err) _, err = client.AuthIdentityChannel.Create(). SetIdentityID(ownerIdentity.ID). SetProviderType("wechat"). SetProviderKey(wechatOAuthProviderKey). SetChannel("open"). SetChannelAppID("wx-open-app"). SetChannelSubject("openid-123"). SetMetadata(map[string]any{"openid": "openid-123"}). Save(ctx) require.NoError(t, err) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) req.Host = "api.example.com" req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) c.Request = req handler.WeChatOAuthCallback(c) require.Equal(t, http.StatusFound, recorder.Code) require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT") count, err := client.PendingAuthSession.Query().Count(ctx) require.NoError(t, err) require.Zero(t, count) } func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { wechatOAuthAccessTokenURL = originalAccessTokenURL wechatOAuthUserInfoURL = originalUserInfoURL }) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) case strings.Contains(r.URL.Path, "/sns/userinfo"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) default: http.NotFound(w, r) } })) defer upstream.Close() wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" handler, client := newWeChatOAuthTestHandler(t, false) defer client.Close() ctx := context.Background() owner, err := client.User.Create(). SetEmail("owner@example.com"). SetUsername("owner"). SetPasswordHash("hash"). SetRole(service.RoleUser). SetStatus(service.StatusActive). Save(ctx) require.NoError(t, err) currentUser, err := client.User.Create(). SetEmail("current@example.com"). SetUsername("current"). SetPasswordHash("hash"). SetRole(service.RoleUser). SetStatus(service.StatusActive). Save(ctx) require.NoError(t, err) _, err = client.AuthIdentity.Create(). SetUserID(owner.ID). SetProviderType("wechat"). SetProviderKey(wechatOAuthLegacyProviderKey). SetProviderSubject("union-456"). SetMetadata(map[string]any{"unionid": "union-456"}). Save(ctx) require.NoError(t, err) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) req.Host = "api.example.com" req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) c.Request = req handler.WeChatOAuthCallback(c) require.Equal(t, http.StatusFound, recorder.Code) require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT") count, err := client.PendingAuthSession.Query().Count(ctx) require.NoError(t, err) require.Zero(t, count) } func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { wechatOAuthAccessTokenURL = originalAccessTokenURL wechatOAuthUserInfoURL = originalUserInfoURL }) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) case strings.Contains(r.URL.Path, "/sns/userinfo"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`)) default: http.NotFound(w, r) } })) defer upstream.Close() wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" handler, client := newWeChatOAuthTestHandler(t, true) defer client.Close() ctx := context.Background() redeemRepo := repository.NewRedeemCodeRepository(client) require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{ Code: "invite-1", Type: service.RedeemTypeInvitation, Status: service.StatusUnused, })) callbackRecorder := httptest.NewRecorder() callbackCtx, _ := gin.CreateTestContext(callbackRecorder) callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) callbackReq.Host = "api.example.com" callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) callbackCtx.Request = callbackReq handler.WeChatOAuthCallback(callbackCtx) require.Equal(t, http.StatusFound, callbackRecorder.Code) require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location")) sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName) require.NotNil(t, sessionCookie) sessionToken := decodeCookieValueForTest(t, sessionCookie.Value) pendingSession, err := client.PendingAuthSession.Query(). Where(pendingauthsession.SessionTokenEQ(sessionToken)). Only(ctx) require.NoError(t, err) require.Equal(t, oauthPendingChoiceStep, pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["step"]) body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`) completeRecorder := httptest.NewRecorder() completeCtx, _ := gin.CreateTestContext(completeRecorder) completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) completeReq.Header.Set("Content-Type", "application/json") completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)}) completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")}) completeCtx.Request = completeReq handler.CompleteWeChatOAuthRegistration(completeCtx) require.Equal(t, http.StatusOK, completeRecorder.Code) responseData := decodeJSONBody(t, completeRecorder) require.Equal(t, "pending_session", responseData["auth_result"]) require.Equal(t, oauthPendingChoiceStep, responseData["step"]) require.Equal(t, true, responseData["adoption_required"]) require.Empty(t, responseData["access_token"]) consumed, err := client.PendingAuthSession.Query(). Where(pendingauthsession.IDEQ(pendingSession.ID)). Only(ctx) require.NoError(t, err) require.Nil(t, consumed.ConsumedAt) userCount, err := client.User.Query().Count(ctx) require.NoError(t, err) require.Zero(t, userCount) identityCount, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("wechat"), authidentity.ProviderKeyEQ("wechat-main"), authidentity.ProviderSubjectEQ("union-456"), ). Count(ctx) require.NoError(t, err) require.Zero(t, identityCount) channelCount, err := client.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ("wechat"), authidentitychannel.ProviderKeyEQ("wechat-main"), authidentitychannel.ChannelEQ("open"), authidentitychannel.ChannelAppIDEQ("wx-open-app"), authidentitychannel.ChannelSubjectEQ("openid-123"), ). Count(ctx) require.NoError(t, err) require.Zero(t, channelCount) decisionCount, err := client.IdentityAdoptionDecision.Query(). Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)). Count(ctx) require.NoError(t, err) require.Zero(t, decisionCount) } func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { wechatOAuthAccessTokenURL = originalAccessTokenURL wechatOAuthUserInfoURL = originalUserInfoURL }) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) case strings.Contains(r.URL.Path, "/sns/userinfo"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`)) default: http.NotFound(w, r) } })) defer upstream.Close() wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" handler, client := newWeChatOAuthTestHandler(t, false) defer client.Close() ctx := context.Background() legacyUser, err := client.User.Create(). SetEmail("legacy@example.com"). SetUsername("legacy-user"). SetPasswordHash("hash"). SetRole(service.RoleUser). SetStatus(service.StatusActive). Save(ctx) require.NoError(t, err) legacyIdentity, err := client.AuthIdentity.Create(). SetUserID(legacyUser.ID). SetProviderType("wechat"). SetProviderKey(wechatOAuthProviderKey). SetProviderSubject("openid-123"). SetMetadata(map[string]any{"openid": "openid-123"}). Save(ctx) require.NoError(t, err) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) req.Host = "api.example.com" req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) c.Request = req handler.WeChatOAuthCallback(c) require.Equal(t, http.StatusFound, recorder.Code) require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) require.NotNil(t, sessionCookie) session, err := client.PendingAuthSession.Query(). Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). Only(ctx) require.NoError(t, err) require.NotNil(t, session.TargetUserID) require.Equal(t, legacyUser.ID, *session.TargetUserID) require.Equal(t, legacyUser.Email, session.ResolvedEmail) repairedIdentity, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("wechat"), authidentity.ProviderKeyEQ(wechatOAuthProviderKey), authidentity.ProviderSubjectEQ("union-456"), ). Only(ctx) require.NoError(t, err) require.Equal(t, legacyIdentity.ID, repairedIdentity.ID) require.Equal(t, legacyUser.ID, repairedIdentity.UserID) openIDIdentityCount, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("wechat"), authidentity.ProviderKeyEQ(wechatOAuthProviderKey), authidentity.ProviderSubjectEQ("openid-123"), ). Count(ctx) require.NoError(t, err) require.Zero(t, openIDIdentityCount) channel, err := client.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ("wechat"), authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), authidentitychannel.ChannelEQ("open"), authidentitychannel.ChannelAppIDEQ("wx-open-app"), authidentitychannel.ChannelSubjectEQ("openid-123"), ). Only(ctx) require.NoError(t, err) require.Equal(t, repairedIdentity.ID, channel.IdentityID) } func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) { handler, client := newWeChatOAuthTestHandler(t, false) defer client.Close() ctx := context.Background() existingUser, err := client.User.Create(). SetEmail("owner@example.com"). SetUsername("owner-user"). SetPasswordHash("hash"). SetRole(service.RoleUser). SetStatus(service.StatusActive). Save(ctx) require.NoError(t, err) session, err := client.PendingAuthSession.Create(). SetSessionToken("wechat-complete-invalid-session"). SetIntent("adopt_existing_user_by_email"). SetProviderType("wechat"). SetProviderKey("wechat-main"). SetProviderSubject("union-invalid-1"). SetTargetUserID(existingUser.ID). SetResolvedEmail(existingUser.Email). SetBrowserSessionKey("wechat-invalid-browser"). SetUpstreamIdentityClaims(map[string]any{ "username": "wechat_user", }). SetLocalFlowState(map[string]any{ oauthCompletionResponseKey: map[string]any{ "step": "bind_login_required", }, }). SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). Save(ctx) require.NoError(t, err) body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) recorder := httptest.NewRecorder() completeCtx, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) req.Header.Set("Content-Type", "application/json") req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-invalid-browser")}) completeCtx.Request = req handler.CompleteWeChatOAuthRegistration(completeCtx) require.Equal(t, http.StatusBadRequest, recorder.Code) storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) require.NoError(t, err) require.Nil(t, storedSession.ConsumedAt) } func TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) { handler, client := newWeChatOAuthTestHandler(t, false) defer client.Close() ctx := context.Background() session, err := client.PendingAuthSession.Create(). SetSessionToken("wechat-complete-choice-session"). SetIntent("login"). SetProviderType("wechat"). SetProviderKey("wechat-main"). SetProviderSubject("wechat-choice-subject-1"). SetResolvedEmail("wechat-choice-subject-1@wechat-connect.invalid"). SetBrowserSessionKey("wechat-choice-browser"). SetUpstreamIdentityClaims(map[string]any{ "username": "wechat_user", }). SetLocalFlowState(map[string]any{ oauthCompletionResponseKey: map[string]any{ "step": oauthPendingChoiceStep, "redirect": "/dashboard", "email": "fresh@example.com", "resolved_email": "fresh@example.com", "force_email_on_signup": true, }, }). SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). Save(ctx) require.NoError(t, err) body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`) recorder := httptest.NewRecorder() completeCtx, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) req.Header.Set("Content-Type", "application/json") req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-choice-browser")}) completeCtx.Request = req handler.CompleteWeChatOAuthRegistration(completeCtx) require.Equal(t, http.StatusOK, recorder.Code) responseData := decodeJSONBody(t, recorder) require.Equal(t, "pending_session", responseData["auth_result"]) require.Equal(t, oauthPendingChoiceStep, responseData["step"]) require.Equal(t, true, responseData["force_email_on_signup"]) require.Empty(t, responseData["access_token"]) userCount, err := client.User.Query().Count(ctx) require.NoError(t, err) require.Zero(t, userCount) storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) require.NoError(t, err) require.Nil(t, storedSession.ConsumedAt) } func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) { originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { wechatOAuthAccessTokenURL = originalAccessTokenURL wechatOAuthUserInfoURL = originalUserInfoURL }) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) case strings.Contains(r.URL.Path, "/sns/userinfo"): w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy Canonical","headimgurl":"https://cdn.example/legacy-canonical.png"}`)) default: http.NotFound(w, r) } })) defer upstream.Close() wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" handler, client := newWeChatOAuthTestHandler(t, false) defer client.Close() ctx := context.Background() legacyUser, err := client.User.Create(). SetEmail("legacy@example.com"). SetUsername("legacy-user"). SetPasswordHash("hash"). SetRole(service.RoleUser). SetStatus(service.StatusActive). Save(ctx) require.NoError(t, err) legacyIdentity, err := client.AuthIdentity.Create(). SetUserID(legacyUser.ID). SetProviderType("wechat"). SetProviderKey(wechatOAuthLegacyProviderKey). SetProviderSubject("union-456"). SetMetadata(map[string]any{"unionid": "union-456"}). Save(ctx) require.NoError(t, err) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) req.Host = "api.example.com" req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) c.Request = req handler.WeChatOAuthCallback(c) require.Equal(t, http.StatusFound, recorder.Code) require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) require.NotNil(t, sessionCookie) session, err := client.PendingAuthSession.Query(). Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). Only(ctx) require.NoError(t, err) require.NotNil(t, session.TargetUserID) require.Equal(t, legacyUser.ID, *session.TargetUserID) require.Equal(t, legacyUser.Email, session.ResolvedEmail) repairedIdentity, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("wechat"), authidentity.ProviderKeyEQ(wechatOAuthProviderKey), authidentity.ProviderSubjectEQ("union-456"), ). Only(ctx) require.NoError(t, err) require.Equal(t, legacyIdentity.ID, repairedIdentity.ID) require.Equal(t, legacyUser.ID, repairedIdentity.UserID) legacyIdentityCount, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("wechat"), authidentity.ProviderKeyEQ(wechatOAuthLegacyProviderKey), authidentity.ProviderSubjectEQ("union-456"), ). Count(ctx) require.NoError(t, err) require.Zero(t, legacyIdentityCount) channel, err := client.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ("wechat"), authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), authidentitychannel.ChannelEQ("open"), authidentitychannel.ChannelAppIDEQ("wx-open-app"), authidentitychannel.ChannelSubjectEQ("openid-123"), ). Only(ctx) require.NoError(t, err) require.Equal(t, repairedIdentity.ID, channel.IdentityID) } func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { return newWeChatOAuthTestHandlerWithSettings(t, invitationEnabled, nil) } func wechatOAuthTestSettings(mode, appID, secret, frontendRedirect string) map[string]string { return map[string]string{ service.SettingKeyWeChatConnectEnabled: "true", service.SettingKeyWeChatConnectAppID: appID, service.SettingKeyWeChatConnectAppSecret: secret, service.SettingKeyWeChatConnectMode: mode, service.SettingKeyWeChatConnectScopes: service.DefaultWeChatConnectScopesForMode(mode), service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", service.SettingKeyWeChatConnectFrontendRedirectURL: frontendRedirect, } } func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, extraSettings map[string]string) (*AuthHandler, *dbent.Client) { t.Helper() db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared") require.NoError(t, err) t.Cleanup(func() { _ = db.Close() }) _, err = db.Exec("PRAGMA foreign_keys = ON") require.NoError(t, err) drv := entsql.OpenDB(dialect.SQLite, db) client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) userRepo := &oauthPendingFlowUserRepo{client: client} redeemRepo := repository.NewRedeemCodeRepository(client) cfg := &config.Config{ JWT: config.JWTConfig{ Secret: "test-secret", ExpireHour: 1, AccessTokenExpireMinutes: 60, RefreshTokenExpireDays: 7, }, Default: config.DefaultConfig{ UserBalance: 0, UserConcurrency: 1, }, } values := map[string]string{ service.SettingKeyRegistrationEnabled: "true", service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), } for key, value := range wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "/auth/wechat/callback") { values[key] = value } for key, value := range extraSettings { values[key] = value } settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{values: values}, cfg) authSvc := service.NewAuthService( client, userRepo, redeemRepo, &wechatOAuthRefreshTokenCacheStub{}, cfg, settingSvc, nil, nil, nil, nil, nil, ) return &AuthHandler{ authService: authSvc, settingSvc: settingSvc, cfg: cfg, }, client } func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) { t.Helper() parsed, err := url.Parse(location) require.NoError(t, err) fragment, err := url.ParseQuery(parsed.Fragment) require.NoError(t, err) require.Equal(t, errorCode, fragment.Get("error")) require.Equal(t, errorMessage, fragment.Get("error_message")) } type wechatOAuthSettingRepoStub struct { values map[string]string } func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { return nil, service.ErrSettingNotFound } func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { value, ok := s.values[key] if !ok { return "", service.ErrSettingNotFound } return value, nil } func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error { return nil } func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { result := make(map[string]string, len(keys)) for _, key := range keys { if value, ok := s.values[key]; ok { result[key] = value } } return result, nil } func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error { return nil } func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) { result := make(map[string]string, len(s.values)) for key, value := range s.values { result[key] = value } return result, nil } func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error { return nil } type wechatOAuthRefreshTokenCacheStub struct{} func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { return nil } func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { return nil, service.ErrRefreshTokenNotFound } func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { return nil } func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { return nil } func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { return nil } func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { return nil } func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { return nil } func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { return nil, nil } func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { return nil, nil } func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { return false, nil }