Files
sub2api/backend/internal/handler/auth_wechat_oauth_test.go

1315 lines
51 KiB
Go

//go:build unit
package handler
import (
"bytes"
"context"
"database/sql"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
// TestWeChatOAuthStartRedirectsAndSetsPendingCookies checks that starting the
// WeChat OAuth flow with mode=open answers with a 302 whose Location points at
// open.weixin.qq.com with the configured app ID and the snsapi_login scope,
// and that the state, redirect, mode and pending-browser cookies are present
// on the response.
func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) {
	gin.SetMode(gin.TestMode)

	settings := map[string]string{
		service.SettingKeyWeChatConnectEnabled:             "true",
		service.SettingKeyWeChatConnectAppID:               "wx-open-app",
		service.SettingKeyWeChatConnectAppSecret:           "wx-open-secret",
		service.SettingKeyWeChatConnectMode:                "open",
		service.SettingKeyWeChatConnectScopes:              "snsapi_login",
		service.SettingKeyWeChatConnectRedirectURL:         "https://api.example.com/api/v1/auth/oauth/wechat/callback",
		service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
	}
	handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, settings)
	defer client.Close()

	// Build the inbound request first, then attach it to a fresh gin context.
	startReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
	startReq.Host = "api.example.com"
	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)
	ginCtx.Request = startReq

	handler.WeChatOAuthStart(ginCtx)

	require.Equal(t, http.StatusFound, rec.Code)
	redirectTarget := rec.Header().Get("Location")
	require.NotEmpty(t, redirectTarget)
	for _, expected := range []string{"open.weixin.qq.com", "appid=wx-open-app", "scope=snsapi_login"} {
		require.Contains(t, redirectTarget, expected)
	}

	// Every cookie the start endpoint is expected to issue must be present.
	responseCookies := rec.Result().Cookies()
	require.NotEmpty(t, findCookie(responseCookies, wechatOAuthStateCookieName))
	require.NotEmpty(t, findCookie(responseCookies, wechatOAuthRedirectCookieName))
	require.NotEmpty(t, findCookie(responseCookies, wechatOAuthModeCookieName))
	require.NotEmpty(t, findCookie(responseCookies, oauthPendingBrowserCookieName))
}
// TestWeChatOAuthStart_AllowsOpenModeWhenBothCapabilitiesEnabled configures
// the stored mode as "mp" (scope snsapi_base) with both the open and MP
// capabilities switched on, then requests mode=open explicitly. The handler
// must still redirect to the open-platform QR login endpoint using the
// snsapi_login scope.
func TestWeChatOAuthStart_AllowsOpenModeWhenBothCapabilitiesEnabled(t *testing.T) {
	gin.SetMode(gin.TestMode)

	settings := map[string]string{
		service.SettingKeyWeChatConnectEnabled:             "true",
		service.SettingKeyWeChatConnectAppID:               "wx-shared-app",
		service.SettingKeyWeChatConnectAppSecret:           "wx-shared-secret",
		service.SettingKeyWeChatConnectMode:                "mp",
		service.SettingKeyWeChatConnectScopes:              "snsapi_base",
		service.SettingKeyWeChatConnectOpenEnabled:         "true",
		service.SettingKeyWeChatConnectMPEnabled:           "true",
		service.SettingKeyWeChatConnectRedirectURL:         "https://api.example.com/api/v1/auth/oauth/wechat/callback",
		service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
	}
	handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, settings)
	defer client.Close()

	startReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
	startReq.Host = "api.example.com"
	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)
	ginCtx.Request = startReq

	handler.WeChatOAuthStart(ginCtx)

	require.Equal(t, http.StatusFound, rec.Code)
	redirectTarget := rec.Header().Get("Location")
	require.NotEmpty(t, redirectTarget)
	for _, expected := range []string{"open.weixin.qq.com", "connect/qrconnect", "scope=snsapi_login"} {
		require.Contains(t, redirectTarget, expected)
	}
}
// TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow drives the WeChat
// OAuth callback against a stub upstream that returns both an openid and a
// unionid. It expects a 302 to the frontend callback, a pending-session
// cookie, and a stored pending session keyed on the unionid that carries the
// upstream nickname, avatar, unionid and openid in its identity claims.
func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) {
	// The endpoints are package-level variables; put them back after the test.
	savedTokenURL, savedUserInfoURL := wechatOAuthAccessTokenURL, wechatOAuthUserInfoURL
	t.Cleanup(func() {
		wechatOAuthAccessTokenURL = savedTokenURL
		wechatOAuthUserInfoURL = savedUserInfoURL
	})

	const (
		tokenPayload    = `{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`
		userInfoPayload = `{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`
	)
	fakeWeChat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		if strings.Contains(path, "/sns/oauth2/access_token") {
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(tokenPayload))
			return
		}
		if strings.Contains(path, "/sns/userinfo") {
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(userInfoPayload))
			return
		}
		http.NotFound(w, r)
	}))
	defer fakeWeChat.Close()
	wechatOAuthAccessTokenURL = fakeWeChat.URL + "/sns/oauth2/access_token"
	wechatOAuthUserInfoURL = fakeWeChat.URL + "/sns/userinfo"

	handler, client := newWeChatOAuthTestHandler(t, false)
	defer client.Close()

	callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
	callbackReq.Host = "api.example.com"
	callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
	callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
	callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
	callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))

	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)
	ginCtx.Request = callbackReq

	handler.WeChatOAuthCallback(ginCtx)

	require.Equal(t, http.StatusFound, rec.Code)
	require.Equal(t, "/auth/wechat/callback", rec.Header().Get("Location"))

	pendingCookie := findCookie(rec.Result().Cookies(), oauthPendingSessionCookieName)
	require.NotNil(t, pendingCookie)

	// Look the session up by the token carried in the cookie.
	sessionToken := decodeCookieValueForTest(t, pendingCookie.Value)
	stored, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.SessionTokenEQ(sessionToken)).
		Only(context.Background())
	require.NoError(t, err)

	require.Equal(t, "wechat", stored.ProviderType)
	require.Equal(t, "wechat-main", stored.ProviderKey)
	require.Equal(t, "union-456", stored.ProviderSubject)
	require.Equal(t, "wechat-union-456@wechat-connect.invalid", stored.ResolvedEmail)

	claims := stored.UpstreamIdentityClaims
	require.Equal(t, "WeChat Nick", claims["suggested_display_name"])
	require.Equal(t, "https://cdn.example/avatar.png", claims["suggested_avatar_url"])
	require.Equal(t, "union-456", claims["unionid"])
	require.Equal(t, "openid-123", claims["openid"])
}
// TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMode
// runs the callback against a stub upstream whose responses contain no
// unionid. With settings built for the "open" channel only, the pending
// session must use the openid as provider subject and synthetic email source,
// and the stored completion response must be the choice step with reason
// "third_party_signup".
func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMode(t *testing.T) {
	// The endpoints are package-level variables; restore them when the test ends.
	originalAccessTokenURL := wechatOAuthAccessTokenURL
	originalUserInfoURL := wechatOAuthUserInfoURL
	t.Cleanup(func() {
		wechatOAuthAccessTokenURL = originalAccessTokenURL
		wechatOAuthUserInfoURL = originalUserInfoURL
	})

	// Stub WeChat: neither payload carries a unionid.
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_login"}`))
		case strings.Contains(r.URL.Path, "/sns/userinfo"):
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"openid":"openid-123","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
		default:
			http.NotFound(w, r)
		}
	}))
	defer upstream.Close()
	wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
	wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"

	handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback"))
	defer client.Close()

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
	req.Host = "api.example.com"
	req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
	req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
	req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
	req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
	c.Request = req

	handler.WeChatOAuthCallback(c)

	require.Equal(t, http.StatusFound, recorder.Code)
	require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location"))

	sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
	require.NotNil(t, sessionCookie)

	session, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
		Only(context.Background())
	require.NoError(t, err)
	require.Equal(t, oauthIntentLogin, session.Intent)
	require.Equal(t, "openid-123", session.ProviderSubject)
	require.Equal(t, wechatSyntheticEmail("openid-123"), session.ResolvedEmail)

	// Checked assertion: a missing or mistyped completion payload should fail
	// this test with a message instead of panicking the whole test binary.
	completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
	require.True(t, ok, "pending session local flow state must hold a completion response object")
	require.Equal(t, oauthPendingChoiceStep, completion["step"])
	require.Equal(t, "third_party_signup", completion["choice_reason"])
}
// TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens
// seeds a user that already owns the WeChat identity "union-456" and then runs
// the callback for that same unionid. The resulting pending session must be a
// login session targeting the existing user, must remember the "/dashboard"
// redirect, and must not persist access or refresh tokens in its completion
// response.
func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens(t *testing.T) {
	// The endpoints are package-level variables; restore them when the test ends.
	originalAccessTokenURL := wechatOAuthAccessTokenURL
	originalUserInfoURL := wechatOAuthUserInfoURL
	t.Cleanup(func() {
		wechatOAuthAccessTokenURL = originalAccessTokenURL
		wechatOAuthUserInfoURL = originalUserInfoURL
	})

	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
		case strings.Contains(r.URL.Path, "/sns/userinfo"):
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat-login.png"}`))
		default:
			http.NotFound(w, r)
		}
	}))
	defer upstream.Close()
	wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
	wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"

	handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback"))
	defer client.Close()

	// Seed the user and the identity the callback is expected to resolve to.
	ctx := context.Background()
	existingUser, err := client.User.Create().
		SetEmail(wechatSyntheticEmail("union-456")).
		SetUsername("wechat-existing-user").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)
	_, err = client.AuthIdentity.Create().
		SetUserID(existingUser.ID).
		SetProviderType("wechat").
		SetProviderKey(wechatOAuthProviderKey).
		SetProviderSubject("union-456").
		SetMetadata(map[string]any{"username": "wechat-existing-user"}).
		Save(ctx)
	require.NoError(t, err)

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
	req.Host = "api.example.com"
	req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
	req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
	req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
	req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
	c.Request = req

	handler.WeChatOAuthCallback(c)

	require.Equal(t, http.StatusFound, recorder.Code)
	require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location"))

	sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
	require.NotNil(t, sessionCookie)

	session, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
		Only(ctx)
	require.NoError(t, err)
	require.Equal(t, oauthIntentLogin, session.Intent)
	require.NotNil(t, session.TargetUserID)
	require.Equal(t, existingUser.ID, *session.TargetUserID)
	require.Equal(t, existingUser.Email, session.ResolvedEmail)

	// Checked assertion: a missing or mistyped completion payload should fail
	// this test with a message instead of panicking the whole test binary.
	completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
	require.True(t, ok, "pending session local flow state must hold a completion response object")
	require.Equal(t, "/dashboard", completion["redirect"])

	// Tokens must not be written into the stored completion response.
	_, hasAccessToken := completion["access_token"]
	require.False(t, hasAccessToken)
	_, hasRefreshToken := completion["refresh_token"]
	require.False(t, hasRefreshToken)
}
// TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken exercises the
// payment OAuth callback and inspects the URL fragment of the redirect. The
// fragment must expose only the redirect path and a wechat_resume_token; the
// openid and payment context must not appear in clear text. Parsing the token
// must yield the openid and the payment context supplied through cookies.
func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
	// The endpoint is a package-level variable; put it back after the test.
	savedTokenURL := wechatOAuthAccessTokenURL
	t.Cleanup(func() { wechatOAuthAccessTokenURL = savedTokenURL })

	fakeWeChat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !strings.Contains(r.URL.Path, "/sns/oauth2/access_token") {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_base"}`))
	}))
	defer fakeWeChat.Close()
	wechatOAuthAccessTokenURL = fakeWeChat.URL + "/sns/oauth2/access_token"

	settings := wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback")
	handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, settings)
	defer client.Close()
	handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"

	const (
		redirectPath   = "/purchase?from=wechat"
		paymentContext = `{"payment_type":"wxpay","amount":"12.5","order_type":"subscription","plan_id":7}`
	)
	callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-123", nil)
	callbackReq.Host = "api.example.com"
	callbackReq.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-123"))
	callbackReq.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, redirectPath))
	callbackReq.AddCookie(encodedCookie(wechatPaymentOAuthContextName, paymentContext))
	callbackReq.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base"))

	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)
	ginCtx.Request = callbackReq

	handler.WeChatPaymentOAuthCallback(ginCtx)

	require.Equal(t, http.StatusFound, rec.Code)
	target, err := url.Parse(rec.Header().Get("Location"))
	require.NoError(t, err)
	fragment, err := url.ParseQuery(target.Fragment)
	require.NoError(t, err)

	require.Equal(t, redirectPath, fragment.Get("redirect"))
	resumeToken := fragment.Get("wechat_resume_token")
	require.NotEmpty(t, resumeToken)

	// None of the sensitive values may be exposed as plain fragment parameters.
	for _, key := range []string{"openid", "payment_type", "amount", "order_type", "plan_id"} {
		require.Empty(t, fragment.Get(key))
	}

	claims, err := handler.wechatPaymentResumeService().ParseWeChatPaymentResumeToken(resumeToken)
	require.NoError(t, err)
	require.Equal(t, "openid-123", claims.OpenID)
	require.Equal(t, payment.TypeWxpay, claims.PaymentType)
	require.Equal(t, "12.5", claims.Amount)
	require.Equal(t, payment.OrderTypeSubscription, claims.OrderType)
	require.EqualValues(t, 7, claims.PlanID)
	require.Equal(t, redirectPath, claims.RedirectTo)
}
// TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels runs the
// bind-intent callback once per channel ("open" and "mp"), each with its own
// app credentials and openid but the same unionid. In both cases the pending
// session must target the current user, use the unionid as provider subject,
// and record the channel, channel app ID and channel-specific openid in the
// upstream identity claims.
func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) {
	testCases := []struct {
		name      string
		mode      string
		appID     string
		appSecret string
		openID    string
	}{
		{
			name:      "open",
			mode:      "open",
			appID:     "wx-open-app",
			appSecret: "wx-open-secret",
			openID:    "openid-open-123",
		},
		{
			name:      "mp",
			mode:      "mp",
			appID:     "wx-mp-app",
			appSecret: "wx-mp-secret",
			openID:    "openid-mp-123",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// The endpoints are package-level variables; restore them when
			// this subtest ends.
			originalAccessTokenURL := wechatOAuthAccessTokenURL
			originalUserInfoURL := wechatOAuthUserInfoURL
			t.Cleanup(func() {
				wechatOAuthAccessTokenURL = originalAccessTokenURL
				wechatOAuthUserInfoURL = originalUserInfoURL
			})

			// Stub WeChat: the openid differs per channel, the unionid does not.
			upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				switch {
				case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
					w.Header().Set("Content-Type", "application/json")
					_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"` + tc.openID + `","unionid":"union-456","scope":"snsapi_login"}`))
				case strings.Contains(r.URL.Path, "/sns/userinfo"):
					w.Header().Set("Content-Type", "application/json")
					_, _ = w.Write([]byte(`{"openid":"` + tc.openID + `","unionid":"union-456","nickname":"Bind Nick","headimgurl":"https://cdn.example/bind.png"}`))
				default:
					http.NotFound(w, r)
				}
			}))
			defer upstream.Close()
			wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
			wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"

			handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings(tc.mode, tc.appID, tc.appSecret, "/auth/wechat/callback"))
			defer client.Close()

			currentUser, err := client.User.Create().
				SetEmail("current@example.com").
				SetUsername("current-user").
				SetPasswordHash("hash").
				SetRole(service.RoleUser).
				SetStatus(service.StatusActive).
				Save(context.Background())
			require.NoError(t, err)

			recorder := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(recorder)
			req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
			req.Host = "api.example.com"
			req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
			req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
			req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
			req.AddCookie(encodedCookie(wechatOAuthModeCookieName, tc.mode))
			req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
			req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
			c.Request = req

			handler.WeChatOAuthCallback(c)

			require.Equal(t, http.StatusFound, recorder.Code)
			require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))

			sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
			require.NotNil(t, sessionCookie)

			session, err := client.PendingAuthSession.Query().
				Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
				Only(context.Background())
			require.NoError(t, err)
			require.Equal(t, wechatOAuthIntentBind, session.Intent)
			require.NotNil(t, session.TargetUserID)
			require.Equal(t, currentUser.ID, *session.TargetUserID)
			require.Equal(t, currentUser.Email, session.ResolvedEmail)
			require.Equal(t, "union-456", session.ProviderSubject)
			require.Equal(t, "union-456", session.UpstreamIdentityClaims["subject"])
			require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
			require.Equal(t, tc.openID, session.UpstreamIdentityClaims["openid"])
			require.Equal(t, tc.mode, session.UpstreamIdentityClaims["channel"])
			require.Equal(t, tc.appID, session.UpstreamIdentityClaims["channel_app_id"])
			require.Equal(t, tc.openID, session.UpstreamIdentityClaims["channel_subject"])

			// Checked assertion: a missing or mistyped completion payload
			// should fail this subtest with a message instead of panicking
			// the whole test binary.
			completionResponse, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
			require.True(t, ok, "pending session local flow state must hold a completion response object")
			require.Equal(t, "/dashboard", completionResponse["redirect"])
			_, hasAccessToken := completionResponse["access_token"]
			require.False(t, hasAccessToken)
		})
	}
}
// TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict seeds a WeChat
// identity for "union-456" owned by one user and then lets a different user
// attempt to bind the same unionid. The callback must redirect with the
// ownership_conflict / AUTH_IDENTITY_OWNERSHIP_CONFLICT error, set no
// pending-session cookie and store no pending session.
func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) {
	// The endpoints are package-level variables; put them back after the test.
	savedTokenURL, savedUserInfoURL := wechatOAuthAccessTokenURL, wechatOAuthUserInfoURL
	t.Cleanup(func() {
		wechatOAuthAccessTokenURL = savedTokenURL
		wechatOAuthUserInfoURL = savedUserInfoURL
	})

	const (
		tokenPayload    = `{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`
		userInfoPayload = `{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`
	)
	fakeWeChat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var payload string
		switch path := r.URL.Path; {
		case strings.Contains(path, "/sns/oauth2/access_token"):
			payload = tokenPayload
		case strings.Contains(path, "/sns/userinfo"):
			payload = userInfoPayload
		default:
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(payload))
	}))
	defer fakeWeChat.Close()
	wechatOAuthAccessTokenURL = fakeWeChat.URL + "/sns/oauth2/access_token"
	wechatOAuthUserInfoURL = fakeWeChat.URL + "/sns/userinfo"

	handler, client := newWeChatOAuthTestHandler(t, false)
	defer client.Close()

	ctx := context.Background()

	// The user who already owns the canonical WeChat identity.
	identityOwner, err := client.User.Create().
		SetEmail("owner@example.com").
		SetUsername("owner").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	// The user who attempts the bind.
	bindingUser, err := client.User.Create().
		SetEmail("current@example.com").
		SetUsername("current").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	_, err = client.AuthIdentity.Create().
		SetUserID(identityOwner.ID).
		SetProviderType("wechat").
		SetProviderKey(wechatOAuthProviderKey).
		SetProviderSubject("union-456").
		SetMetadata(map[string]any{"unionid": "union-456"}).
		Save(ctx)
	require.NoError(t, err)

	callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
	callbackReq.Host = "api.example.com"
	callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
	callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
	callbackReq.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
	callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
	bindCookieValue := buildEncodedOAuthBindUserCookie(t, bindingUser.ID, "test-secret")
	callbackReq.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, bindCookieValue))
	callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))

	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)
	ginCtx.Request = callbackReq

	handler.WeChatOAuthCallback(ginCtx)

	require.Equal(t, http.StatusFound, rec.Code)
	require.Nil(t, findCookie(rec.Result().Cookies(), oauthPendingSessionCookieName))
	assertOAuthRedirectError(t, rec.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT")

	// A rejected bind must not leave a pending session behind.
	pendingSessions, err := client.PendingAuthSession.Query().Count(ctx)
	require.NoError(t, err)
	require.Zero(t, pendingSessions)
}
// TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict seeds an owner
// whose identity ("union-owner") already has the open-channel record for app
// "wx-open-app" and openid "openid-123". A second user then tries to bind
// while upstream reports that same openid under a different unionid. The
// callback must redirect with ownership_conflict /
// AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT, set no pending-session cookie and
// store no pending session.
func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) {
	// The endpoints are package-level variables; put them back after the test.
	prevTokenURL, prevUserInfoURL := wechatOAuthAccessTokenURL, wechatOAuthUserInfoURL
	t.Cleanup(func() {
		wechatOAuthAccessTokenURL = prevTokenURL
		wechatOAuthUserInfoURL = prevUserInfoURL
	})

	writeJSON := func(w http.ResponseWriter, payload string) {
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(payload))
	}
	stub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requestPath := r.URL.Path
		if strings.Contains(requestPath, "/sns/oauth2/access_token") {
			writeJSON(w, `{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)
		} else if strings.Contains(requestPath, "/sns/userinfo") {
			writeJSON(w, `{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)
		} else {
			http.NotFound(w, r)
		}
	}))
	defer stub.Close()
	wechatOAuthAccessTokenURL = stub.URL + "/sns/oauth2/access_token"
	wechatOAuthUserInfoURL = stub.URL + "/sns/userinfo"

	handler, client := newWeChatOAuthTestHandler(t, false)
	defer client.Close()

	ctx := context.Background()

	// The user who already owns the conflicting channel record.
	channelOwner, err := client.User.Create().
		SetEmail("owner@example.com").
		SetUsername("owner").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	// The user who attempts the bind.
	bindingUser, err := client.User.Create().
		SetEmail("current@example.com").
		SetUsername("current").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	seededIdentity, err := client.AuthIdentity.Create().
		SetUserID(channelOwner.ID).
		SetProviderType("wechat").
		SetProviderKey(wechatOAuthProviderKey).
		SetProviderSubject("union-owner").
		SetMetadata(map[string]any{"unionid": "union-owner"}).
		Save(ctx)
	require.NoError(t, err)

	// The channel row reuses the openid that upstream will report.
	_, err = client.AuthIdentityChannel.Create().
		SetIdentityID(seededIdentity.ID).
		SetProviderType("wechat").
		SetProviderKey(wechatOAuthProviderKey).
		SetChannel("open").
		SetChannelAppID("wx-open-app").
		SetChannelSubject("openid-123").
		SetMetadata(map[string]any{"openid": "openid-123"}).
		Save(ctx)
	require.NoError(t, err)

	callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
	callbackReq.Host = "api.example.com"
	callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
	callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
	callbackReq.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
	callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
	bindCookieValue := buildEncodedOAuthBindUserCookie(t, bindingUser.ID, "test-secret")
	callbackReq.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, bindCookieValue))
	callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))

	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)
	ginCtx.Request = callbackReq

	handler.WeChatOAuthCallback(ginCtx)

	require.Equal(t, http.StatusFound, rec.Code)
	require.Nil(t, findCookie(rec.Result().Cookies(), oauthPendingSessionCookieName))
	assertOAuthRedirectError(t, rec.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT")

	// A rejected bind must not leave a pending session behind.
	pendingSessions, err := client.PendingAuthSession.Query().Count(ctx)
	require.NoError(t, err)
	require.Zero(t, pendingSessions)
}
// TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict seeds a
// WeChat identity for "union-456" stored under wechatOAuthLegacyProviderKey
// and owned by one user, then lets another user attempt to bind the same
// unionid. The callback must redirect with ownership_conflict /
// AUTH_IDENTITY_OWNERSHIP_CONFLICT, set no pending-session cookie and store no
// pending session.
func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) {
	// The endpoints are package-level variables; put them back after the test.
	tokenURLBefore, userInfoURLBefore := wechatOAuthAccessTokenURL, wechatOAuthUserInfoURL
	t.Cleanup(func() {
		wechatOAuthAccessTokenURL = tokenURLBefore
		wechatOAuthUserInfoURL = userInfoURLBefore
	})

	serveUpstream := func(w http.ResponseWriter, r *http.Request) {
		tokenRequest := strings.Contains(r.URL.Path, "/sns/oauth2/access_token")
		if !tokenRequest && !strings.Contains(r.URL.Path, "/sns/userinfo") {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		if tokenRequest {
			_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
			return
		}
		_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
	}
	stub := httptest.NewServer(http.HandlerFunc(serveUpstream))
	defer stub.Close()
	wechatOAuthAccessTokenURL = stub.URL + "/sns/oauth2/access_token"
	wechatOAuthUserInfoURL = stub.URL + "/sns/userinfo"

	handler, client := newWeChatOAuthTestHandler(t, false)
	defer client.Close()

	ctx := context.Background()

	// The user who owns the identity stored under the legacy provider key.
	legacyOwner, err := client.User.Create().
		SetEmail("owner@example.com").
		SetUsername("owner").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	// The user who attempts the bind.
	bindingUser, err := client.User.Create().
		SetEmail("current@example.com").
		SetUsername("current").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	_, err = client.AuthIdentity.Create().
		SetUserID(legacyOwner.ID).
		SetProviderType("wechat").
		SetProviderKey(wechatOAuthLegacyProviderKey).
		SetProviderSubject("union-456").
		SetMetadata(map[string]any{"unionid": "union-456"}).
		Save(ctx)
	require.NoError(t, err)

	callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
	callbackReq.Host = "api.example.com"
	callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
	callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
	callbackReq.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
	callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
	bindCookieValue := buildEncodedOAuthBindUserCookie(t, bindingUser.ID, "test-secret")
	callbackReq.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, bindCookieValue))
	callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))

	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)
	ginCtx.Request = callbackReq

	handler.WeChatOAuthCallback(ginCtx)

	require.Equal(t, http.StatusFound, rec.Code)
	require.Nil(t, findCookie(rec.Result().Cookies(), oauthPendingSessionCookieName))
	assertOAuthRedirectError(t, rec.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT")

	// A rejected bind must not leave a pending session behind.
	pendingSessions, err := client.PendingAuthSession.Query().Count(ctx)
	require.NoError(t, err)
	require.Zero(t, pendingSessions)
}
// TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession
// runs the OAuth callback to obtain a pending session and then posts to the
// complete-registration endpoint with an invitation code and adoption flags.
// The endpoint must answer 200 with a "pending_session" result still at the
// choice step and with adoption required, without issuing an access token.
// Nothing may have been persisted yet: the session stays unconsumed and no
// user, identity, identity channel or adoption decision exists.
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession(t *testing.T) {
	// The endpoints are package-level variables; restore them when the test ends.
	originalAccessTokenURL := wechatOAuthAccessTokenURL
	originalUserInfoURL := wechatOAuthUserInfoURL
	t.Cleanup(func() {
		wechatOAuthAccessTokenURL = originalAccessTokenURL
		wechatOAuthUserInfoURL = originalUserInfoURL
	})

	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
		case strings.Contains(r.URL.Path, "/sns/userinfo"):
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`))
		default:
			http.NotFound(w, r)
		}
	}))
	defer upstream.Close()
	wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
	wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"

	// NOTE(review): the `true` flag presumably turns on the invitation-code
	// requirement in the test handler — confirm against newWeChatOAuthTestHandler.
	handler, client := newWeChatOAuthTestHandler(t, true)
	defer client.Close()

	ctx := context.Background()
	redeemRepo := repository.NewRedeemCodeRepository(client)
	require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{
		Code:   "invite-1",
		Type:   service.RedeemTypeInvitation,
		Status: service.StatusUnused,
	}))

	// Step 1: OAuth callback creates the pending session.
	callbackRecorder := httptest.NewRecorder()
	callbackCtx, _ := gin.CreateTestContext(callbackRecorder)
	callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
	callbackReq.Host = "api.example.com"
	callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
	callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
	callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
	callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
	callbackCtx.Request = callbackReq

	handler.WeChatOAuthCallback(callbackCtx)

	require.Equal(t, http.StatusFound, callbackRecorder.Code)
	require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location"))

	sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName)
	require.NotNil(t, sessionCookie)
	sessionToken := decodeCookieValueForTest(t, sessionCookie.Value)

	pendingSession, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.SessionTokenEQ(sessionToken)).
		Only(ctx)
	require.NoError(t, err)

	// Checked assertion: a missing or mistyped completion payload should fail
	// this test with a message instead of panicking the whole test binary.
	callbackCompletion, ok := pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
	require.True(t, ok, "pending session local flow state must hold a completion response object")
	require.Equal(t, oauthPendingChoiceStep, callbackCompletion["step"])

	// Step 2: complete-registration request carrying the session and browser cookies.
	body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`)
	completeRecorder := httptest.NewRecorder()
	completeCtx, _ := gin.CreateTestContext(completeRecorder)
	completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
	completeReq.Header.Set("Content-Type", "application/json")
	completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)})
	completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")})
	completeCtx.Request = completeReq

	handler.CompleteWeChatOAuthRegistration(completeCtx)

	require.Equal(t, http.StatusOK, completeRecorder.Code)
	responseData := decodeJSONBody(t, completeRecorder)
	require.Equal(t, "pending_session", responseData["auth_result"])
	require.Equal(t, oauthPendingChoiceStep, responseData["step"])
	require.Equal(t, true, responseData["adoption_required"])
	require.Empty(t, responseData["access_token"])

	// The pending session must still be open.
	consumed, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.IDEQ(pendingSession.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.Nil(t, consumed.ConsumedAt)

	// No account-side records may have been created yet.
	userCount, err := client.User.Query().Count(ctx)
	require.NoError(t, err)
	require.Zero(t, userCount)

	identityCount, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("wechat"),
			authidentity.ProviderKeyEQ("wechat-main"),
			authidentity.ProviderSubjectEQ("union-456"),
		).
		Count(ctx)
	require.NoError(t, err)
	require.Zero(t, identityCount)

	channelCount, err := client.AuthIdentityChannel.Query().
		Where(
			authidentitychannel.ProviderTypeEQ("wechat"),
			authidentitychannel.ProviderKeyEQ("wechat-main"),
			authidentitychannel.ChannelEQ("open"),
			authidentitychannel.ChannelAppIDEQ("wx-open-app"),
			authidentitychannel.ChannelSubjectEQ("openid-123"),
		).
		Count(ctx)
	require.NoError(t, err)
	require.Zero(t, channelCount)

	decisionCount, err := client.IdentityAdoptionDecision.Query().
		Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
		Count(ctx)
	require.NoError(t, err)
	require.Zero(t, decisionCount)
}
// TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity seeds an identity whose
// provider subject is the bare openid, runs the OAuth callback against a fake
// WeChat upstream that reports both an openid and a unionid, and verifies that
// the same identity row is re-keyed to the unionid and linked to an "open"
// channel record for the openid.
func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
	// The WeChat endpoints are package-level variables; swap them for the fake
	// upstream and restore the previous values when the test finishes.
	savedTokenURL, savedUserInfoURL := wechatOAuthAccessTokenURL, wechatOAuthUserInfoURL
	t.Cleanup(func() {
		wechatOAuthAccessTokenURL = savedTokenURL
		wechatOAuthUserInfoURL = savedUserInfoURL
	})

	writeJSON := func(w http.ResponseWriter, payload string) {
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(payload))
	}
	fakeWeChat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		if strings.Contains(path, "/sns/oauth2/access_token") {
			writeJSON(w, `{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)
			return
		}
		if strings.Contains(path, "/sns/userinfo") {
			writeJSON(w, `{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`)
			return
		}
		http.NotFound(w, r)
	}))
	defer fakeWeChat.Close()
	wechatOAuthAccessTokenURL = fakeWeChat.URL + "/sns/oauth2/access_token"
	wechatOAuthUserInfoURL = fakeWeChat.URL + "/sns/userinfo"

	authHandler, client := newWeChatOAuthTestHandler(t, false)
	defer client.Close()
	ctx := context.Background()

	// Seed: an existing user bound to WeChat through an openid-keyed identity.
	legacyUser, err := client.User.Create().
		SetEmail("legacy@example.com").
		SetUsername("legacy-user").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)
	legacyIdentity, err := client.AuthIdentity.Create().
		SetUserID(legacyUser.ID).
		SetProviderType("wechat").
		SetProviderKey(wechatOAuthProviderKey).
		SetProviderSubject("openid-123").
		SetMetadata(map[string]any{"openid": "openid-123"}).
		Save(ctx)
	require.NoError(t, err)

	// Act: replay the callback with the state/redirect/mode/browser cookies set.
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
	callbackReq.Host = "api.example.com"
	for _, cookie := range []*http.Cookie{
		encodedCookie(wechatOAuthStateCookieName, "state-123"),
		encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"),
		encodedCookie(wechatOAuthModeCookieName, "open"),
		encodedCookie(oauthPendingBrowserCookieName, "browser-123"),
	} {
		callbackReq.AddCookie(cookie)
	}
	ginCtx.Request = callbackReq
	authHandler.WeChatOAuthCallback(ginCtx)

	require.Equal(t, http.StatusFound, recorder.Code)
	require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))

	// The pending session referenced by the response cookie must target the legacy user.
	sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
	require.NotNil(t, sessionCookie)
	sessionToken := decodeCookieValueForTest(t, sessionCookie.Value)
	pending, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.SessionTokenEQ(sessionToken)).
		Only(ctx)
	require.NoError(t, err)
	require.NotNil(t, pending.TargetUserID)
	require.Equal(t, legacyUser.ID, *pending.TargetUserID)
	require.Equal(t, legacyUser.Email, pending.ResolvedEmail)

	// The original row must now be keyed by the unionid...
	repaired, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("wechat"),
			authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
			authidentity.ProviderSubjectEQ("union-456"),
		).
		Only(ctx)
	require.NoError(t, err)
	require.Equal(t, legacyIdentity.ID, repaired.ID)
	require.Equal(t, legacyUser.ID, repaired.UserID)

	// ...and no identity keyed by the bare openid may remain.
	openIDOnlyCount, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("wechat"),
			authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
			authidentity.ProviderSubjectEQ("openid-123"),
		).
		Count(ctx)
	require.NoError(t, err)
	require.Zero(t, openIDOnlyCount)

	// The openid is expected on an "open" channel record attached to the repaired identity.
	openChannel, err := client.AuthIdentityChannel.Query().
		Where(
			authidentitychannel.ProviderTypeEQ("wechat"),
			authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
			authidentitychannel.ChannelEQ("open"),
			authidentitychannel.ChannelAppIDEQ("wx-open-app"),
			authidentitychannel.ChannelSubjectEQ("openid-123"),
		).
		Only(ctx)
	require.NoError(t, err)
	require.Equal(t, repaired.ID, openChannel.IdentityID)
}
// TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession posts to
// the complete-registration endpoint with a pending session whose intent is
// "adopt_existing_user_by_email" and expects a 400 response that leaves the
// session unconsumed.
func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
	authHandler, client := newWeChatOAuthTestHandler(t, false)
	defer client.Close()
	ctx := context.Background()

	// Seed: an existing account plus a pending session that targets it.
	owner, err := client.User.Create().
		SetEmail("owner@example.com").
		SetUsername("owner-user").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)

	claims := map[string]any{"username": "wechat_user"}
	flowState := map[string]any{
		oauthCompletionResponseKey: map[string]any{
			"step": "bind_login_required",
		},
	}
	pending, err := client.PendingAuthSession.Create().
		SetSessionToken("wechat-complete-invalid-session").
		SetIntent("adopt_existing_user_by_email").
		SetProviderType("wechat").
		SetProviderKey("wechat-main").
		SetProviderSubject("union-invalid-1").
		SetTargetUserID(owner.ID).
		SetResolvedEmail(owner.Email).
		SetBrowserSessionKey("wechat-invalid-browser").
		SetUpstreamIdentityClaims(claims).
		SetLocalFlowState(flowState).
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	// Act: submit the completion request with matching session and browser cookies.
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	completeReq := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/auth/oauth/wechat/complete-registration",
		bytes.NewBufferString(`{"invitation_code":"invite-1"}`),
	)
	completeReq.Header.Set("Content-Type", "application/json")
	completeReq.AddCookie(&http.Cookie{
		Name:  oauthPendingSessionCookieName,
		Value: encodeCookieValue(pending.SessionToken),
	})
	completeReq.AddCookie(&http.Cookie{
		Name:  oauthPendingBrowserCookieName,
		Value: encodeCookieValue("wechat-invalid-browser"),
	})
	ginCtx.Request = completeReq
	authHandler.CompleteWeChatOAuthRegistration(ginCtx)

	// Assert: rejected, and the stored session was not marked consumed.
	require.Equal(t, http.StatusBadRequest, recorder.Code)
	reloaded, err := client.PendingAuthSession.Get(ctx, pending.ID)
	require.NoError(t, err)
	require.Nil(t, reloaded.ConsumedAt)
}
// TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired
// posts to the complete-registration endpoint while the stored completion
// response is still at the pending-choice step. It expects a 200
// "pending_session" payload without an access token, no user created, and the
// session left unconsumed.
func TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
	authHandler, client := newWeChatOAuthTestHandler(t, false)
	defer client.Close()
	ctx := context.Background()

	// Seed: a login-intent session whose stored completion response is the choice step.
	claims := map[string]any{"username": "wechat_user"}
	completion := map[string]any{
		"step":                  oauthPendingChoiceStep,
		"redirect":              "/dashboard",
		"email":                 "fresh@example.com",
		"resolved_email":        "fresh@example.com",
		"force_email_on_signup": true,
	}
	pending, err := client.PendingAuthSession.Create().
		SetSessionToken("wechat-complete-choice-session").
		SetIntent("login").
		SetProviderType("wechat").
		SetProviderKey("wechat-main").
		SetProviderSubject("wechat-choice-subject-1").
		SetResolvedEmail("wechat-choice-subject-1@wechat-connect.invalid").
		SetBrowserSessionKey("wechat-choice-browser").
		SetUpstreamIdentityClaims(claims).
		SetLocalFlowState(map[string]any{oauthCompletionResponseKey: completion}).
		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
		Save(ctx)
	require.NoError(t, err)

	// Act: submit the completion request with matching session and browser cookies.
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	completeReq := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/auth/oauth/wechat/complete-registration",
		bytes.NewBufferString(`{"invitation_code":"invite-1"}`),
	)
	completeReq.Header.Set("Content-Type", "application/json")
	completeReq.AddCookie(&http.Cookie{
		Name:  oauthPendingSessionCookieName,
		Value: encodeCookieValue(pending.SessionToken),
	})
	completeReq.AddCookie(&http.Cookie{
		Name:  oauthPendingBrowserCookieName,
		Value: encodeCookieValue("wechat-choice-browser"),
	})
	ginCtx.Request = completeReq
	authHandler.CompleteWeChatOAuthRegistration(ginCtx)

	// Assert: the response echoes the pending choice step and issues no token.
	require.Equal(t, http.StatusOK, recorder.Code)
	payload := decodeJSONBody(t, recorder)
	require.Equal(t, "pending_session", payload["auth_result"])
	require.Equal(t, oauthPendingChoiceStep, payload["step"])
	require.Equal(t, true, payload["force_email_on_signup"])
	require.Empty(t, payload["access_token"])

	// Assert: nothing was persisted as a side effect.
	users, err := client.User.Query().Count(ctx)
	require.NoError(t, err)
	require.Zero(t, users)
	reloaded, err := client.PendingAuthSession.Get(ctx, pending.ID)
	require.NoError(t, err)
	require.Nil(t, reloaded.ConsumedAt)
}
// TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity seeds an
// identity stored under wechatOAuthLegacyProviderKey with the unionid as its
// subject, runs the OAuth callback against a fake WeChat upstream, and verifies
// that the same identity row is moved to wechatOAuthProviderKey and linked to
// an "open" channel record for the openid.
func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
	// The WeChat endpoints are package-level variables; swap them for the fake
	// upstream and restore the previous values when the test finishes.
	savedTokenURL, savedUserInfoURL := wechatOAuthAccessTokenURL, wechatOAuthUserInfoURL
	t.Cleanup(func() {
		wechatOAuthAccessTokenURL = savedTokenURL
		wechatOAuthUserInfoURL = savedUserInfoURL
	})

	writeJSON := func(w http.ResponseWriter, payload string) {
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(payload))
	}
	fakeWeChat := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		if strings.Contains(path, "/sns/oauth2/access_token") {
			writeJSON(w, `{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)
			return
		}
		if strings.Contains(path, "/sns/userinfo") {
			writeJSON(w, `{"openid":"openid-123","unionid":"union-456","nickname":"Legacy Canonical","headimgurl":"https://cdn.example/legacy-canonical.png"}`)
			return
		}
		http.NotFound(w, r)
	}))
	defer fakeWeChat.Close()
	wechatOAuthAccessTokenURL = fakeWeChat.URL + "/sns/oauth2/access_token"
	wechatOAuthUserInfoURL = fakeWeChat.URL + "/sns/userinfo"

	authHandler, client := newWeChatOAuthTestHandler(t, false)
	defer client.Close()
	ctx := context.Background()

	// Seed: an existing user whose unionid identity sits under the legacy provider key.
	legacyUser, err := client.User.Create().
		SetEmail("legacy@example.com").
		SetUsername("legacy-user").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)
	legacyIdentity, err := client.AuthIdentity.Create().
		SetUserID(legacyUser.ID).
		SetProviderType("wechat").
		SetProviderKey(wechatOAuthLegacyProviderKey).
		SetProviderSubject("union-456").
		SetMetadata(map[string]any{"unionid": "union-456"}).
		Save(ctx)
	require.NoError(t, err)

	// Act: replay the callback with the state/redirect/mode/browser cookies set.
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
	callbackReq.Host = "api.example.com"
	for _, cookie := range []*http.Cookie{
		encodedCookie(wechatOAuthStateCookieName, "state-123"),
		encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"),
		encodedCookie(wechatOAuthModeCookieName, "open"),
		encodedCookie(oauthPendingBrowserCookieName, "browser-123"),
	} {
		callbackReq.AddCookie(cookie)
	}
	ginCtx.Request = callbackReq
	authHandler.WeChatOAuthCallback(ginCtx)

	require.Equal(t, http.StatusFound, recorder.Code)
	require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))

	// The pending session referenced by the response cookie must target the legacy user.
	sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
	require.NotNil(t, sessionCookie)
	sessionToken := decodeCookieValueForTest(t, sessionCookie.Value)
	pending, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.SessionTokenEQ(sessionToken)).
		Only(ctx)
	require.NoError(t, err)
	require.NotNil(t, pending.TargetUserID)
	require.Equal(t, legacyUser.ID, *pending.TargetUserID)
	require.Equal(t, legacyUser.Email, pending.ResolvedEmail)

	// The original row must now live under the current provider key...
	repaired, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("wechat"),
			authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
			authidentity.ProviderSubjectEQ("union-456"),
		).
		Only(ctx)
	require.NoError(t, err)
	require.Equal(t, legacyIdentity.ID, repaired.ID)
	require.Equal(t, legacyUser.ID, repaired.UserID)

	// ...and nothing may remain under the legacy provider key.
	legacyKeyCount, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("wechat"),
			authidentity.ProviderKeyEQ(wechatOAuthLegacyProviderKey),
			authidentity.ProviderSubjectEQ("union-456"),
		).
		Count(ctx)
	require.NoError(t, err)
	require.Zero(t, legacyKeyCount)

	// The openid is expected on an "open" channel record attached to the repaired identity.
	openChannel, err := client.AuthIdentityChannel.Query().
		Where(
			authidentitychannel.ProviderTypeEQ("wechat"),
			authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
			authidentitychannel.ChannelEQ("open"),
			authidentitychannel.ChannelAppIDEQ("wx-open-app"),
			authidentitychannel.ChannelSubjectEQ("openid-123"),
		).
		Only(ctx)
	require.NoError(t, err)
	require.Equal(t, repaired.ID, openChannel.IdentityID)
}
// newWeChatOAuthTestHandler returns an AuthHandler and its backing ent client
// built by newWeChatOAuthTestHandlerWithSettings with no extra setting
// overrides. invitationEnabled is forwarded unchanged.
func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
	// Mark this wrapper as a helper so that setup failures are reported at the
	// calling test's line instead of inside this function.
	t.Helper()
	return newWeChatOAuthTestHandlerWithSettings(t, invitationEnabled, nil)
}
// wechatOAuthTestSettings builds the WeChat Connect setting values used by the
// tests: the feature is enabled, the scopes are the service defaults for mode,
// and the backend callback URL is fixed. The remaining values come straight
// from the arguments.
func wechatOAuthTestSettings(mode, appID, secret, frontendRedirect string) map[string]string {
	const callbackURL = "https://api.example.com/api/v1/auth/oauth/wechat/callback"

	settings := make(map[string]string, 7)
	settings[service.SettingKeyWeChatConnectEnabled] = "true"
	settings[service.SettingKeyWeChatConnectAppID] = appID
	settings[service.SettingKeyWeChatConnectAppSecret] = secret
	settings[service.SettingKeyWeChatConnectMode] = mode
	settings[service.SettingKeyWeChatConnectScopes] = service.DefaultWeChatConnectScopesForMode(mode)
	settings[service.SettingKeyWeChatConnectRedirectURL] = callbackURL
	settings[service.SettingKeyWeChatConnectFrontendRedirectURL] = frontendRedirect
	return settings
}
// newWeChatOAuthTestHandlerWithSettings wires an AuthHandler to an ent client
// backed by a shared-cache in-memory SQLite database and to a stubbed settings
// repository.
//
// Settings are layered in this order, later layers overriding earlier ones:
// registration/invitation flags, the default "open"-mode WeChat settings, then
// extraSettings (which may be nil). The caller is responsible for closing the
// returned client; the underlying *sql.DB is closed via t.Cleanup.
func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, extraSettings map[string]string) (*AuthHandler, *dbent.Client) {
	t.Helper()

	sqlDB, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared")
	require.NoError(t, err)
	t.Cleanup(func() { _ = sqlDB.Close() })

	// NOTE(review): the pragma is issued once through the pool; confirm it is
	// in effect on every connection the tests end up using.
	_, err = sqlDB.Exec("PRAGMA foreign_keys = ON")
	require.NoError(t, err)

	entDriver := entsql.OpenDB(dialect.SQLite, sqlDB)
	client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(entDriver)))

	userRepo := &oauthPendingFlowUserRepo{client: client}
	redeemRepo := repository.NewRedeemCodeRepository(client)

	cfg := &config.Config{
		JWT: config.JWTConfig{
			Secret:                   "test-secret",
			ExpireHour:               1,
			AccessTokenExpireMinutes: 60,
			RefreshTokenExpireDays:   7,
		},
		Default: config.DefaultConfig{
			UserBalance:     0,
			UserConcurrency: 1,
		},
	}

	// Start from the base flags, then apply each overlay in order.
	settings := map[string]string{
		service.SettingKeyRegistrationEnabled:   "true",
		service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
	}
	overlays := []map[string]string{
		wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "/auth/wechat/callback"),
		extraSettings,
	}
	for _, overlay := range overlays {
		for name, val := range overlay {
			settings[name] = val
		}
	}

	settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{values: settings}, cfg)
	authSvc := service.NewAuthService(
		client,
		userRepo,
		redeemRepo,
		&wechatOAuthRefreshTokenCacheStub{},
		cfg,
		settingSvc,
		nil,
		nil,
		nil,
		nil,
		nil,
	)

	authHandler := &AuthHandler{
		authService: authSvc,
		settingSvc:  settingSvc,
		cfg:         cfg,
	}
	return authHandler, client
}
// assertOAuthRedirectError parses location as a URL, decodes its fragment as a
// query string, and requires that the "error" and "error_message" fragment
// parameters equal errorCode and errorMessage respectively.
func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) {
	t.Helper()

	redirectURL, err := url.Parse(location)
	require.NoError(t, err)

	// The error details travel in the URL fragment, encoded like a query string.
	params, err := url.ParseQuery(redirectURL.Fragment)
	require.NoError(t, err)

	gotCode, gotMessage := params.Get("error"), params.Get("error_message")
	require.Equal(t, errorCode, gotCode)
	require.Equal(t, errorMessage, gotMessage)
}
// wechatOAuthSettingRepoStub is an in-memory settings repository for tests.
// Reads are served from values; writes and deletes are accepted and discarded.
type wechatOAuthSettingRepoStub struct {
	// values maps a setting key to its string value.
	values map[string]string
}

// Get always reports service.ErrSettingNotFound.
func (r *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
	return nil, service.ErrSettingNotFound
}

// GetValue returns the stored value for key, or service.ErrSettingNotFound
// when the key is absent.
func (r *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
	if stored, found := r.values[key]; found {
		return stored, nil
	}
	return "", service.ErrSettingNotFound
}

// Set is a no-op that reports success.
func (r *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error {
	return nil
}

// GetMultiple returns the subset of keys that have a stored value; absent keys
// are omitted from the result rather than reported as errors.
func (r *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
	found := make(map[string]string, len(keys))
	for _, name := range keys {
		stored, ok := r.values[name]
		if !ok {
			continue
		}
		found[name] = stored
	}
	return found, nil
}

// SetMultiple is a no-op that reports success.
func (r *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
	return nil
}

// GetAll returns a copy of every stored setting, so callers cannot mutate the
// stub's own map through the result.
func (r *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
	snapshot := make(map[string]string, len(r.values))
	for name, stored := range r.values {
		snapshot[name] = stored
	}
	return snapshot, nil
}

// Delete is a no-op that reports success.
func (r *wechatOAuthSettingRepoStub) Delete(context.Context, string) error {
	return nil
}
// wechatOAuthRefreshTokenCacheStub is a stateless refresh-token cache for
// tests: every write succeeds without storing anything, and every read reports
// that nothing is cached.
type wechatOAuthRefreshTokenCacheStub struct{}

// StoreRefreshToken discards its arguments and reports success.
func (*wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
	return nil
}

// GetRefreshToken always reports service.ErrRefreshTokenNotFound.
func (*wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
	return nil, service.ErrRefreshTokenNotFound
}

// DeleteRefreshToken is a no-op that reports success.
func (*wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
	return nil
}

// DeleteUserRefreshTokens is a no-op that reports success.
func (*wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
	return nil
}

// DeleteTokenFamily is a no-op that reports success.
func (*wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
	return nil
}

// AddToUserTokenSet is a no-op that reports success.
func (*wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
	return nil
}

// AddToFamilyTokenSet is a no-op that reports success.
func (*wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
	return nil
}

// GetUserTokenHashes always returns a nil slice and no error.
func (*wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
	return nil, nil
}

// GetFamilyTokenHashes always returns a nil slice and no error.
func (*wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
	return nil, nil
}

// IsTokenInFamily always reports false with no error.
func (*wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
	return false, nil
}