1166 lines
44 KiB
Go
1166 lines
44 KiB
Go
//go:build unit
|
|
|
|
package handler
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"database/sql"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
|
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
|
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
|
|
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
|
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
|
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
|
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/Wei-Shaw/sub2api/internal/payment"
|
|
"github.com/Wei-Shaw/sub2api/internal/repository"
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"entgo.io/ent/dialect"
|
|
entsql "entgo.io/ent/dialect/sql"
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) {
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
|
|
c.Request.Host = "api.example.com"
|
|
|
|
handler := &AuthHandler{}
|
|
handler.WeChatOAuthStart(c)
|
|
|
|
require.Equal(t, http.StatusFound, recorder.Code)
|
|
location := recorder.Header().Get("Location")
|
|
require.NotEmpty(t, location)
|
|
require.Contains(t, location, "open.weixin.qq.com")
|
|
require.Contains(t, location, "appid=wx-open-app")
|
|
require.Contains(t, location, "scope=snsapi_login")
|
|
|
|
cookies := recorder.Result().Cookies()
|
|
require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName))
|
|
require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName))
|
|
require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName))
|
|
require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName))
|
|
}
|
|
|
|
func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) {
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
|
|
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
|
|
|
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
|
originalUserInfoURL := wechatOAuthUserInfoURL
|
|
t.Cleanup(func() {
|
|
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
|
wechatOAuthUserInfoURL = originalUserInfoURL
|
|
})
|
|
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
|
|
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer upstream.Close()
|
|
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
|
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
|
|
|
handler, client := newWeChatOAuthTestHandler(t, false)
|
|
defer client.Close()
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
|
|
req.Host = "api.example.com"
|
|
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
|
|
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
|
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
|
|
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
|
c.Request = req
|
|
|
|
handler.WeChatOAuthCallback(c)
|
|
|
|
require.Equal(t, http.StatusFound, recorder.Code)
|
|
require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
|
|
|
|
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
|
|
require.NotNil(t, sessionCookie)
|
|
|
|
ctx := context.Background()
|
|
session, err := client.PendingAuthSession.Query().
|
|
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
|
|
Only(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "wechat", session.ProviderType)
|
|
require.Equal(t, "wechat-main", session.ProviderKey)
|
|
require.Equal(t, "union-456", session.ProviderSubject)
|
|
require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail)
|
|
require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"])
|
|
require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"])
|
|
require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
|
|
require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"])
|
|
}
|
|
|
|
func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) {
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
|
|
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "https://app.example.com/auth/wechat/callback")
|
|
|
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
|
originalUserInfoURL := wechatOAuthUserInfoURL
|
|
t.Cleanup(func() {
|
|
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
|
wechatOAuthUserInfoURL = originalUserInfoURL
|
|
})
|
|
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_login"}`))
|
|
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"openid":"openid-123","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer upstream.Close()
|
|
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
|
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
|
|
|
handler, client := newWeChatOAuthTestHandler(t, false)
|
|
defer client.Close()
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
|
|
req.Host = "api.example.com"
|
|
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
|
|
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
|
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
|
|
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
|
c.Request = req
|
|
|
|
handler.WeChatOAuthCallback(c)
|
|
|
|
require.Equal(t, http.StatusFound, recorder.Code)
|
|
require.Contains(t, recorder.Header().Get("Location"), "#error=provider_error")
|
|
require.Contains(t, recorder.Header().Get("Location"), "error_message=wechat_missing_unionid")
|
|
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
|
|
|
|
count, err := client.PendingAuthSession.Query().Count(context.Background())
|
|
require.NoError(t, err)
|
|
require.Zero(t, count)
|
|
}
|
|
|
|
func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
|
|
t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
|
|
t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret")
|
|
|
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
|
t.Cleanup(func() {
|
|
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
|
})
|
|
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_base"}`))
|
|
return
|
|
}
|
|
http.NotFound(w, r)
|
|
}))
|
|
defer upstream.Close()
|
|
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
|
|
|
handler, client := newWeChatOAuthTestHandler(t, false)
|
|
defer client.Close()
|
|
handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-123", nil)
|
|
req.Host = "api.example.com"
|
|
req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-123"))
|
|
req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat"))
|
|
req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"12.5","order_type":"subscription","plan_id":7}`))
|
|
req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base"))
|
|
c.Request = req
|
|
|
|
handler.WeChatPaymentOAuthCallback(c)
|
|
|
|
require.Equal(t, http.StatusFound, recorder.Code)
|
|
location := recorder.Header().Get("Location")
|
|
parsed, err := url.Parse(location)
|
|
require.NoError(t, err)
|
|
fragment, err := url.ParseQuery(parsed.Fragment)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "/purchase?from=wechat", fragment.Get("redirect"))
|
|
require.NotEmpty(t, fragment.Get("wechat_resume_token"))
|
|
require.Empty(t, fragment.Get("openid"))
|
|
require.Empty(t, fragment.Get("payment_type"))
|
|
require.Empty(t, fragment.Get("amount"))
|
|
require.Empty(t, fragment.Get("order_type"))
|
|
require.Empty(t, fragment.Get("plan_id"))
|
|
|
|
claims, err := handler.wechatPaymentResumeService().ParseWeChatPaymentResumeToken(fragment.Get("wechat_resume_token"))
|
|
require.NoError(t, err)
|
|
require.Equal(t, "openid-123", claims.OpenID)
|
|
require.Equal(t, payment.TypeWxpay, claims.PaymentType)
|
|
require.Equal(t, "12.5", claims.Amount)
|
|
require.Equal(t, payment.OrderTypeSubscription, claims.OrderType)
|
|
require.EqualValues(t, 7, claims.PlanID)
|
|
require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
|
|
}
|
|
|
|
func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
mode string
|
|
appIDEnv string
|
|
appID string
|
|
appSecret string
|
|
openID string
|
|
}{
|
|
{
|
|
name: "open",
|
|
mode: "open",
|
|
appIDEnv: "WECHAT_OAUTH_OPEN_APP_ID",
|
|
appID: "wx-open-app",
|
|
appSecret: "wx-open-secret",
|
|
openID: "openid-open-123",
|
|
},
|
|
{
|
|
name: "mp",
|
|
mode: "mp",
|
|
appIDEnv: "WECHAT_OAUTH_MP_APP_ID",
|
|
appID: "wx-mp-app",
|
|
appSecret: "wx-mp-secret",
|
|
openID: "openid-mp-123",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Setenv(tc.appIDEnv, tc.appID)
|
|
switch tc.mode {
|
|
case "open":
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", tc.appSecret)
|
|
case "mp":
|
|
t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", tc.appSecret)
|
|
}
|
|
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
|
|
|
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
|
originalUserInfoURL := wechatOAuthUserInfoURL
|
|
t.Cleanup(func() {
|
|
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
|
wechatOAuthUserInfoURL = originalUserInfoURL
|
|
})
|
|
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"` + tc.openID + `","unionid":"union-456","scope":"snsapi_login"}`))
|
|
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"openid":"` + tc.openID + `","unionid":"union-456","nickname":"Bind Nick","headimgurl":"https://cdn.example/bind.png"}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer upstream.Close()
|
|
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
|
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
|
|
|
handler, client := newWeChatOAuthTestHandler(t, false)
|
|
defer client.Close()
|
|
|
|
currentUser, err := client.User.Create().
|
|
SetEmail("current@example.com").
|
|
SetUsername("current-user").
|
|
SetPasswordHash("hash").
|
|
SetRole(service.RoleUser).
|
|
SetStatus(service.StatusActive).
|
|
Save(context.Background())
|
|
require.NoError(t, err)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
|
|
req.Host = "api.example.com"
|
|
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
|
|
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
|
req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
|
|
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, tc.mode))
|
|
req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
|
|
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
|
c.Request = req
|
|
|
|
handler.WeChatOAuthCallback(c)
|
|
|
|
require.Equal(t, http.StatusFound, recorder.Code)
|
|
require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
|
|
|
|
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
|
|
require.NotNil(t, sessionCookie)
|
|
|
|
session, err := client.PendingAuthSession.Query().
|
|
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
|
|
Only(context.Background())
|
|
require.NoError(t, err)
|
|
require.Equal(t, wechatOAuthIntentBind, session.Intent)
|
|
require.NotNil(t, session.TargetUserID)
|
|
require.Equal(t, currentUser.ID, *session.TargetUserID)
|
|
require.Equal(t, currentUser.Email, session.ResolvedEmail)
|
|
require.Equal(t, "union-456", session.ProviderSubject)
|
|
require.Equal(t, "union-456", session.UpstreamIdentityClaims["subject"])
|
|
require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
|
|
require.Equal(t, tc.openID, session.UpstreamIdentityClaims["openid"])
|
|
require.Equal(t, tc.mode, session.UpstreamIdentityClaims["channel"])
|
|
require.Equal(t, tc.appID, session.UpstreamIdentityClaims["channel_app_id"])
|
|
require.Equal(t, tc.openID, session.UpstreamIdentityClaims["channel_subject"])
|
|
|
|
completionResponse := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
|
|
require.Equal(t, "/dashboard", completionResponse["redirect"])
|
|
_, hasAccessToken := completionResponse["access_token"]
|
|
require.False(t, hasAccessToken)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) {
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
|
|
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
|
|
|
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
|
originalUserInfoURL := wechatOAuthUserInfoURL
|
|
t.Cleanup(func() {
|
|
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
|
wechatOAuthUserInfoURL = originalUserInfoURL
|
|
})
|
|
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
|
|
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer upstream.Close()
|
|
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
|
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
|
|
|
handler, client := newWeChatOAuthTestHandler(t, false)
|
|
defer client.Close()
|
|
|
|
ctx := context.Background()
|
|
owner, err := client.User.Create().
|
|
SetEmail("owner@example.com").
|
|
SetUsername("owner").
|
|
SetPasswordHash("hash").
|
|
SetRole(service.RoleUser).
|
|
SetStatus(service.StatusActive).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
currentUser, err := client.User.Create().
|
|
SetEmail("current@example.com").
|
|
SetUsername("current").
|
|
SetPasswordHash("hash").
|
|
SetRole(service.RoleUser).
|
|
SetStatus(service.StatusActive).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
_, err = client.AuthIdentity.Create().
|
|
SetUserID(owner.ID).
|
|
SetProviderType("wechat").
|
|
SetProviderKey(wechatOAuthProviderKey).
|
|
SetProviderSubject("union-456").
|
|
SetMetadata(map[string]any{"unionid": "union-456"}).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
|
|
req.Host = "api.example.com"
|
|
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
|
|
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
|
req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
|
|
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
|
|
req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
|
|
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
|
c.Request = req
|
|
|
|
handler.WeChatOAuthCallback(c)
|
|
|
|
require.Equal(t, http.StatusFound, recorder.Code)
|
|
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
|
|
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT")
|
|
|
|
count, err := client.PendingAuthSession.Query().Count(ctx)
|
|
require.NoError(t, err)
|
|
require.Zero(t, count)
|
|
}
|
|
|
|
func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) {
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
|
|
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
|
|
|
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
|
originalUserInfoURL := wechatOAuthUserInfoURL
|
|
t.Cleanup(func() {
|
|
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
|
wechatOAuthUserInfoURL = originalUserInfoURL
|
|
})
|
|
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
|
|
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer upstream.Close()
|
|
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
|
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
|
|
|
handler, client := newWeChatOAuthTestHandler(t, false)
|
|
defer client.Close()
|
|
|
|
ctx := context.Background()
|
|
owner, err := client.User.Create().
|
|
SetEmail("owner@example.com").
|
|
SetUsername("owner").
|
|
SetPasswordHash("hash").
|
|
SetRole(service.RoleUser).
|
|
SetStatus(service.StatusActive).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
currentUser, err := client.User.Create().
|
|
SetEmail("current@example.com").
|
|
SetUsername("current").
|
|
SetPasswordHash("hash").
|
|
SetRole(service.RoleUser).
|
|
SetStatus(service.StatusActive).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
ownerIdentity, err := client.AuthIdentity.Create().
|
|
SetUserID(owner.ID).
|
|
SetProviderType("wechat").
|
|
SetProviderKey(wechatOAuthProviderKey).
|
|
SetProviderSubject("union-owner").
|
|
SetMetadata(map[string]any{"unionid": "union-owner"}).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
_, err = client.AuthIdentityChannel.Create().
|
|
SetIdentityID(ownerIdentity.ID).
|
|
SetProviderType("wechat").
|
|
SetProviderKey(wechatOAuthProviderKey).
|
|
SetChannel("open").
|
|
SetChannelAppID("wx-open-app").
|
|
SetChannelSubject("openid-123").
|
|
SetMetadata(map[string]any{"openid": "openid-123"}).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
|
|
req.Host = "api.example.com"
|
|
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
|
|
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
|
req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
|
|
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
|
|
req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
|
|
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
|
c.Request = req
|
|
|
|
handler.WeChatOAuthCallback(c)
|
|
|
|
require.Equal(t, http.StatusFound, recorder.Code)
|
|
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
|
|
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT")
|
|
|
|
count, err := client.PendingAuthSession.Query().Count(ctx)
|
|
require.NoError(t, err)
|
|
require.Zero(t, count)
|
|
}
|
|
|
|
func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) {
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
|
|
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
|
|
|
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
|
originalUserInfoURL := wechatOAuthUserInfoURL
|
|
t.Cleanup(func() {
|
|
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
|
wechatOAuthUserInfoURL = originalUserInfoURL
|
|
})
|
|
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
|
|
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer upstream.Close()
|
|
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
|
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
|
|
|
handler, client := newWeChatOAuthTestHandler(t, false)
|
|
defer client.Close()
|
|
|
|
ctx := context.Background()
|
|
owner, err := client.User.Create().
|
|
SetEmail("owner@example.com").
|
|
SetUsername("owner").
|
|
SetPasswordHash("hash").
|
|
SetRole(service.RoleUser).
|
|
SetStatus(service.StatusActive).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
currentUser, err := client.User.Create().
|
|
SetEmail("current@example.com").
|
|
SetUsername("current").
|
|
SetPasswordHash("hash").
|
|
SetRole(service.RoleUser).
|
|
SetStatus(service.StatusActive).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
_, err = client.AuthIdentity.Create().
|
|
SetUserID(owner.ID).
|
|
SetProviderType("wechat").
|
|
SetProviderKey(wechatOAuthLegacyProviderKey).
|
|
SetProviderSubject("union-456").
|
|
SetMetadata(map[string]any{"unionid": "union-456"}).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
|
|
req.Host = "api.example.com"
|
|
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
|
|
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
|
req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
|
|
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
|
|
req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
|
|
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
|
c.Request = req
|
|
|
|
handler.WeChatOAuthCallback(c)
|
|
|
|
require.Equal(t, http.StatusFound, recorder.Code)
|
|
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
|
|
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT")
|
|
|
|
count, err := client.PendingAuthSession.Query().Count(ctx)
|
|
require.NoError(t, err)
|
|
require.Zero(t, count)
|
|
}
|
|
|
|
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) {
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
|
|
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
|
|
|
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
|
originalUserInfoURL := wechatOAuthUserInfoURL
|
|
t.Cleanup(func() {
|
|
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
|
wechatOAuthUserInfoURL = originalUserInfoURL
|
|
})
|
|
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
|
|
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer upstream.Close()
|
|
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
|
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
|
|
|
handler, client := newWeChatOAuthTestHandler(t, true)
|
|
defer client.Close()
|
|
|
|
ctx := context.Background()
|
|
redeemRepo := repository.NewRedeemCodeRepository(client)
|
|
require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{
|
|
Code: "invite-1",
|
|
Type: service.RedeemTypeInvitation,
|
|
Status: service.StatusUnused,
|
|
}))
|
|
|
|
callbackRecorder := httptest.NewRecorder()
|
|
callbackCtx, _ := gin.CreateTestContext(callbackRecorder)
|
|
callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
|
|
callbackReq.Host = "api.example.com"
|
|
callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
|
|
callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
|
callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
|
|
callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
|
callbackCtx.Request = callbackReq
|
|
|
|
handler.WeChatOAuthCallback(callbackCtx)
|
|
|
|
require.Equal(t, http.StatusFound, callbackRecorder.Code)
|
|
require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location"))
|
|
|
|
sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName)
|
|
require.NotNil(t, sessionCookie)
|
|
sessionToken := decodeCookieValueForTest(t, sessionCookie.Value)
|
|
|
|
pendingSession, err := client.PendingAuthSession.Query().
|
|
Where(pendingauthsession.SessionTokenEQ(sessionToken)).
|
|
Only(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "invitation_required", pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["error"])
|
|
|
|
body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`)
|
|
completeRecorder := httptest.NewRecorder()
|
|
completeCtx, _ := gin.CreateTestContext(completeRecorder)
|
|
completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
|
|
completeReq.Header.Set("Content-Type", "application/json")
|
|
completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)})
|
|
completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")})
|
|
completeCtx.Request = completeReq
|
|
|
|
handler.CompleteWeChatOAuthRegistration(completeCtx)
|
|
|
|
require.Equal(t, http.StatusOK, completeRecorder.Code)
|
|
responseData := decodeJSONBody(t, completeRecorder)
|
|
require.NotEmpty(t, responseData["access_token"])
|
|
|
|
userEntity, err := client.User.Query().
|
|
Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")).
|
|
Only(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "WeChat Display", userEntity.Username)
|
|
|
|
identity, err := client.AuthIdentity.Query().
|
|
Where(
|
|
authidentity.ProviderTypeEQ("wechat"),
|
|
authidentity.ProviderKeyEQ("wechat-main"),
|
|
authidentity.ProviderSubjectEQ("union-456"),
|
|
).
|
|
Only(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, userEntity.ID, identity.UserID)
|
|
require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
|
|
require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
|
|
|
|
channel, err := client.AuthIdentityChannel.Query().
|
|
Where(
|
|
authidentitychannel.ProviderTypeEQ("wechat"),
|
|
authidentitychannel.ProviderKeyEQ("wechat-main"),
|
|
authidentitychannel.ChannelEQ("open"),
|
|
authidentitychannel.ChannelAppIDEQ("wx-open-app"),
|
|
authidentitychannel.ChannelSubjectEQ("openid-123"),
|
|
).
|
|
Only(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, identity.ID, channel.IdentityID)
|
|
require.Equal(t, "union-456", channel.Metadata["unionid"])
|
|
|
|
decision, err := client.IdentityAdoptionDecision.Query().
|
|
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
|
|
Only(ctx)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, decision.IdentityID)
|
|
require.Equal(t, identity.ID, *decision.IdentityID)
|
|
require.True(t, decision.AdoptDisplayName)
|
|
require.True(t, decision.AdoptAvatar)
|
|
|
|
consumed, err := client.PendingAuthSession.Query().
|
|
Where(pendingauthsession.IDEQ(pendingSession.ID)).
|
|
Only(ctx)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, consumed.ConsumedAt)
|
|
}
|
|
|
|
func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
|
|
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
|
|
|
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
|
originalUserInfoURL := wechatOAuthUserInfoURL
|
|
t.Cleanup(func() {
|
|
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
|
wechatOAuthUserInfoURL = originalUserInfoURL
|
|
})
|
|
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
|
|
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer upstream.Close()
|
|
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
|
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
|
|
|
handler, client := newWeChatOAuthTestHandler(t, false)
|
|
defer client.Close()
|
|
|
|
ctx := context.Background()
|
|
legacyUser, err := client.User.Create().
|
|
SetEmail("legacy@example.com").
|
|
SetUsername("legacy-user").
|
|
SetPasswordHash("hash").
|
|
SetRole(service.RoleUser).
|
|
SetStatus(service.StatusActive).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
legacyIdentity, err := client.AuthIdentity.Create().
|
|
SetUserID(legacyUser.ID).
|
|
SetProviderType("wechat").
|
|
SetProviderKey(wechatOAuthProviderKey).
|
|
SetProviderSubject("openid-123").
|
|
SetMetadata(map[string]any{"openid": "openid-123"}).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
|
|
req.Host = "api.example.com"
|
|
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
|
|
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
|
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
|
|
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
|
c.Request = req
|
|
|
|
handler.WeChatOAuthCallback(c)
|
|
|
|
require.Equal(t, http.StatusFound, recorder.Code)
|
|
require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
|
|
|
|
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
|
|
require.NotNil(t, sessionCookie)
|
|
|
|
session, err := client.PendingAuthSession.Query().
|
|
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
|
|
Only(ctx)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, session.TargetUserID)
|
|
require.Equal(t, legacyUser.ID, *session.TargetUserID)
|
|
require.Equal(t, legacyUser.Email, session.ResolvedEmail)
|
|
|
|
repairedIdentity, err := client.AuthIdentity.Query().
|
|
Where(
|
|
authidentity.ProviderTypeEQ("wechat"),
|
|
authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
|
|
authidentity.ProviderSubjectEQ("union-456"),
|
|
).
|
|
Only(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, legacyIdentity.ID, repairedIdentity.ID)
|
|
require.Equal(t, legacyUser.ID, repairedIdentity.UserID)
|
|
|
|
openIDIdentityCount, err := client.AuthIdentity.Query().
|
|
Where(
|
|
authidentity.ProviderTypeEQ("wechat"),
|
|
authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
|
|
authidentity.ProviderSubjectEQ("openid-123"),
|
|
).
|
|
Count(ctx)
|
|
require.NoError(t, err)
|
|
require.Zero(t, openIDIdentityCount)
|
|
|
|
channel, err := client.AuthIdentityChannel.Query().
|
|
Where(
|
|
authidentitychannel.ProviderTypeEQ("wechat"),
|
|
authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
|
|
authidentitychannel.ChannelEQ("open"),
|
|
authidentitychannel.ChannelAppIDEQ("wx-open-app"),
|
|
authidentitychannel.ChannelSubjectEQ("openid-123"),
|
|
).
|
|
Only(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, repairedIdentity.ID, channel.IdentityID)
|
|
}
|
|
|
|
func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
|
|
handler, client := newWeChatOAuthTestHandler(t, false)
|
|
defer client.Close()
|
|
|
|
ctx := context.Background()
|
|
existingUser, err := client.User.Create().
|
|
SetEmail("owner@example.com").
|
|
SetUsername("owner-user").
|
|
SetPasswordHash("hash").
|
|
SetRole(service.RoleUser).
|
|
SetStatus(service.StatusActive).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
session, err := client.PendingAuthSession.Create().
|
|
SetSessionToken("wechat-complete-invalid-session").
|
|
SetIntent("adopt_existing_user_by_email").
|
|
SetProviderType("wechat").
|
|
SetProviderKey("wechat-main").
|
|
SetProviderSubject("union-invalid-1").
|
|
SetTargetUserID(existingUser.ID).
|
|
SetResolvedEmail(existingUser.Email).
|
|
SetBrowserSessionKey("wechat-invalid-browser").
|
|
SetUpstreamIdentityClaims(map[string]any{
|
|
"username": "wechat_user",
|
|
}).
|
|
SetLocalFlowState(map[string]any{
|
|
oauthCompletionResponseKey: map[string]any{
|
|
"step": "bind_login_required",
|
|
},
|
|
}).
|
|
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
|
|
recorder := httptest.NewRecorder()
|
|
completeCtx, _ := gin.CreateTestContext(recorder)
|
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
|
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-invalid-browser")})
|
|
completeCtx.Request = req
|
|
|
|
handler.CompleteWeChatOAuthRegistration(completeCtx)
|
|
|
|
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
|
|
|
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
|
|
require.NoError(t, err)
|
|
require.Nil(t, storedSession.ConsumedAt)
|
|
}
|
|
|
|
func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
|
|
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
|
|
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
|
|
|
|
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
|
originalUserInfoURL := wechatOAuthUserInfoURL
|
|
t.Cleanup(func() {
|
|
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
|
wechatOAuthUserInfoURL = originalUserInfoURL
|
|
})
|
|
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
|
|
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy Canonical","headimgurl":"https://cdn.example/legacy-canonical.png"}`))
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer upstream.Close()
|
|
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
|
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
|
|
|
handler, client := newWeChatOAuthTestHandler(t, false)
|
|
defer client.Close()
|
|
|
|
ctx := context.Background()
|
|
legacyUser, err := client.User.Create().
|
|
SetEmail("legacy@example.com").
|
|
SetUsername("legacy-user").
|
|
SetPasswordHash("hash").
|
|
SetRole(service.RoleUser).
|
|
SetStatus(service.StatusActive).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
legacyIdentity, err := client.AuthIdentity.Create().
|
|
SetUserID(legacyUser.ID).
|
|
SetProviderType("wechat").
|
|
SetProviderKey(wechatOAuthLegacyProviderKey).
|
|
SetProviderSubject("union-456").
|
|
SetMetadata(map[string]any{"unionid": "union-456"}).
|
|
Save(ctx)
|
|
require.NoError(t, err)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
|
|
req.Host = "api.example.com"
|
|
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
|
|
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
|
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
|
|
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
|
c.Request = req
|
|
|
|
handler.WeChatOAuthCallback(c)
|
|
|
|
require.Equal(t, http.StatusFound, recorder.Code)
|
|
require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
|
|
|
|
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
|
|
require.NotNil(t, sessionCookie)
|
|
|
|
session, err := client.PendingAuthSession.Query().
|
|
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
|
|
Only(ctx)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, session.TargetUserID)
|
|
require.Equal(t, legacyUser.ID, *session.TargetUserID)
|
|
require.Equal(t, legacyUser.Email, session.ResolvedEmail)
|
|
|
|
repairedIdentity, err := client.AuthIdentity.Query().
|
|
Where(
|
|
authidentity.ProviderTypeEQ("wechat"),
|
|
authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
|
|
authidentity.ProviderSubjectEQ("union-456"),
|
|
).
|
|
Only(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, legacyIdentity.ID, repairedIdentity.ID)
|
|
require.Equal(t, legacyUser.ID, repairedIdentity.UserID)
|
|
|
|
legacyIdentityCount, err := client.AuthIdentity.Query().
|
|
Where(
|
|
authidentity.ProviderTypeEQ("wechat"),
|
|
authidentity.ProviderKeyEQ(wechatOAuthLegacyProviderKey),
|
|
authidentity.ProviderSubjectEQ("union-456"),
|
|
).
|
|
Count(ctx)
|
|
require.NoError(t, err)
|
|
require.Zero(t, legacyIdentityCount)
|
|
|
|
channel, err := client.AuthIdentityChannel.Query().
|
|
Where(
|
|
authidentitychannel.ProviderTypeEQ("wechat"),
|
|
authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
|
|
authidentitychannel.ChannelEQ("open"),
|
|
authidentitychannel.ChannelAppIDEQ("wx-open-app"),
|
|
authidentitychannel.ChannelSubjectEQ("openid-123"),
|
|
).
|
|
Only(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, repairedIdentity.ID, channel.IdentityID)
|
|
}
|
|
|
|
func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
|
|
t.Helper()
|
|
|
|
db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared")
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() { _ = db.Close() })
|
|
|
|
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
|
require.NoError(t, err)
|
|
|
|
drv := entsql.OpenDB(dialect.SQLite, db)
|
|
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
|
|
|
userRepo := &oauthPendingFlowUserRepo{client: client}
|
|
redeemRepo := repository.NewRedeemCodeRepository(client)
|
|
cfg := &config.Config{
|
|
JWT: config.JWTConfig{
|
|
Secret: "test-secret",
|
|
ExpireHour: 1,
|
|
AccessTokenExpireMinutes: 60,
|
|
RefreshTokenExpireDays: 7,
|
|
},
|
|
Default: config.DefaultConfig{
|
|
UserBalance: 0,
|
|
UserConcurrency: 1,
|
|
},
|
|
}
|
|
settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{
|
|
values: map[string]string{
|
|
service.SettingKeyRegistrationEnabled: "true",
|
|
service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
|
|
},
|
|
}, cfg)
|
|
|
|
authSvc := service.NewAuthService(
|
|
client,
|
|
userRepo,
|
|
redeemRepo,
|
|
&wechatOAuthRefreshTokenCacheStub{},
|
|
cfg,
|
|
settingSvc,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
)
|
|
|
|
return &AuthHandler{
|
|
authService: authSvc,
|
|
settingSvc: settingSvc,
|
|
cfg: cfg,
|
|
}, client
|
|
}
|
|
|
|
func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) {
|
|
t.Helper()
|
|
|
|
parsed, err := url.Parse(location)
|
|
require.NoError(t, err)
|
|
|
|
fragment, err := url.ParseQuery(parsed.Fragment)
|
|
require.NoError(t, err)
|
|
require.Equal(t, errorCode, fragment.Get("error"))
|
|
require.Equal(t, errorMessage, fragment.Get("error_message"))
|
|
}
|
|
|
|
type wechatOAuthSettingRepoStub struct {
|
|
values map[string]string
|
|
}
|
|
|
|
func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
|
|
return nil, service.ErrSettingNotFound
|
|
}
|
|
|
|
func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
|
|
value, ok := s.values[key]
|
|
if !ok {
|
|
return "", service.ErrSettingNotFound
|
|
}
|
|
return value, nil
|
|
}
|
|
|
|
func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error {
|
|
return nil
|
|
}
|
|
|
|
func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
|
result := make(map[string]string, len(keys))
|
|
for _, key := range keys {
|
|
if value, ok := s.values[key]; ok {
|
|
result[key] = value
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
|
|
return nil
|
|
}
|
|
|
|
func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
|
|
result := make(map[string]string, len(s.values))
|
|
for key, value := range s.values {
|
|
result[key] = value
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error {
|
|
return nil
|
|
}
|
|
|
|
type wechatOAuthRefreshTokenCacheStub struct{}
|
|
|
|
func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
|
|
return nil
|
|
}
|
|
|
|
func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
|
|
return nil, service.ErrRefreshTokenNotFound
|
|
}
|
|
|
|
func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
|
|
return nil
|
|
}
|
|
|
|
func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
|
|
return nil
|
|
}
|
|
|
|
func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
|
|
return nil
|
|
}
|
|
|
|
func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
|
|
return nil
|
|
}
|
|
|
|
func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
|
|
return nil
|
|
}
|
|
|
|
func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
|
|
return false, nil
|
|
}
|