feat: rebuild auth identity foundation flow
This commit is contained in:
@@ -73,6 +73,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if ops monitoring is enabled (respects config.ops.enabled)
|
||||
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
|
||||
@@ -93,7 +98,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
paymentCfg = &service.PaymentConfig{}
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
payload := dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
|
||||
@@ -200,7 +205,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
|
||||
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
|
||||
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
|
||||
})
|
||||
}
|
||||
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
|
||||
}
|
||||
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
@@ -276,9 +282,30 @@ type UpdateSettingsRequest struct {
|
||||
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
|
||||
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
|
||||
AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
|
||||
AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
|
||||
AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
|
||||
AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
|
||||
AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
|
||||
AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
|
||||
AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
|
||||
AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
|
||||
AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
|
||||
AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
|
||||
AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
|
||||
AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
|
||||
AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
|
||||
AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
|
||||
AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
|
||||
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
|
||||
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
|
||||
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
|
||||
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
@@ -357,6 +384,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证参数
|
||||
if req.DefaultConcurrency < 1 {
|
||||
@@ -381,6 +413,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
req.SMTPPort = 587
|
||||
}
|
||||
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
|
||||
req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions)
|
||||
req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
|
||||
req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
|
||||
req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
|
||||
|
||||
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
|
||||
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
|
||||
@@ -538,25 +574,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.BadRequest(c, "OIDC scopes must contain openid")
|
||||
return
|
||||
}
|
||||
if !req.OIDCConnectUsePKCE {
|
||||
response.BadRequest(c, "OIDC PKCE must be enabled")
|
||||
return
|
||||
}
|
||||
if !req.OIDCConnectValidateIDToken {
|
||||
response.BadRequest(c, "OIDC ID Token validation must be enabled")
|
||||
return
|
||||
}
|
||||
switch req.OIDCConnectTokenAuthMethod {
|
||||
case "", "client_secret_post", "client_secret_basic", "none":
|
||||
default:
|
||||
response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none")
|
||||
return
|
||||
}
|
||||
if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE {
|
||||
response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none")
|
||||
return
|
||||
}
|
||||
if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 {
|
||||
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
|
||||
return
|
||||
}
|
||||
if req.OIDCConnectValidateIDToken {
|
||||
if req.OIDCConnectAllowedSigningAlgs == "" {
|
||||
response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
|
||||
return
|
||||
}
|
||||
if req.OIDCConnectAllowedSigningAlgs == "" {
|
||||
response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
|
||||
return
|
||||
}
|
||||
if req.OIDCConnectJWKSURL != "" {
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil {
|
||||
@@ -933,6 +971,41 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
authSourceDefaults := &service.AuthSourceDefaultSettings{
|
||||
Email: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
|
||||
Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency),
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind),
|
||||
},
|
||||
LinuxDo: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance),
|
||||
Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency),
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind),
|
||||
},
|
||||
OIDC: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance),
|
||||
Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency),
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind),
|
||||
},
|
||||
WeChat: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance),
|
||||
Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency),
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
|
||||
},
|
||||
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
|
||||
}
|
||||
if err := h.settingService.UpdateAuthSourceDefaultSettings(c.Request.Context(), authSourceDefaults); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update payment configuration (integrated into system settings).
|
||||
// Skip if no payment fields were provided (prevents accidental wipe).
|
||||
@@ -977,6 +1050,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
|
||||
for _, sub := range updatedSettings.DefaultSubscriptions {
|
||||
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
|
||||
@@ -994,7 +1072,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
updatedPaymentCfg = &service.PaymentConfig{}
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
payload := dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
|
||||
@@ -1100,7 +1178,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
|
||||
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
|
||||
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
|
||||
})
|
||||
}
|
||||
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
|
||||
}
|
||||
|
||||
// hasPaymentFields returns true if any payment-related field was explicitly provided.
|
||||
@@ -1412,6 +1491,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
|
||||
return normalized
|
||||
}
|
||||
|
||||
func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting {
|
||||
if input == nil {
|
||||
return nil
|
||||
}
|
||||
normalized := normalizeDefaultSubscriptions(*input)
|
||||
return &normalized
|
||||
}
|
||||
|
||||
func float64ValueOrDefault(value *float64, fallback float64) float64 {
|
||||
if value == nil {
|
||||
return fallback
|
||||
}
|
||||
return *value
|
||||
}
|
||||
|
||||
func intValueOrDefault(value *int, fallback int) int {
|
||||
if value == nil {
|
||||
return fallback
|
||||
}
|
||||
return *value
|
||||
}
|
||||
|
||||
func boolValueOrDefault(value *bool, fallback bool) bool {
|
||||
if value == nil {
|
||||
return fallback
|
||||
}
|
||||
return *value
|
||||
}
|
||||
|
||||
func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting {
|
||||
if input == nil {
|
||||
return fallback
|
||||
}
|
||||
result := make([]service.DefaultSubscriptionSetting, 0, len(*input))
|
||||
for _, item := range *input {
|
||||
result = append(result, service.DefaultSubscriptionSetting{
|
||||
GroupID: item.GroupID,
|
||||
ValidityDays: item.ValidityDays,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any {
|
||||
data := make(map[string]any)
|
||||
raw, err := json.Marshal(settings)
|
||||
if err == nil {
|
||||
_ = json.Unmarshal(raw, &data)
|
||||
}
|
||||
if authSourceDefaults == nil {
|
||||
authSourceDefaults = &service.AuthSourceDefaultSettings{}
|
||||
}
|
||||
|
||||
data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance
|
||||
data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency
|
||||
data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions
|
||||
data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup
|
||||
data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind
|
||||
data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance
|
||||
data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency
|
||||
data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
|
||||
data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
|
||||
data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
|
||||
data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
|
||||
data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
|
||||
data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
|
||||
data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup
|
||||
data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind
|
||||
data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance
|
||||
data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency
|
||||
data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
|
||||
data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
|
||||
data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
|
||||
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func equalStringSlice(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type settingHandlerRepoStub struct {
|
||||
values map[string]string
|
||||
lastUpdates map[string]string
|
||||
}
|
||||
|
||||
func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
panic("unexpected GetValue call")
|
||||
}
|
||||
|
||||
func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
if value, ok := s.values[key]; ok {
|
||||
out[key] = value
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
s.lastUpdates = make(map[string]string, len(settings))
|
||||
for key, value := range settings {
|
||||
s.lastUpdates[key] = value
|
||||
if s.values == nil {
|
||||
s.values = map[string]string{}
|
||||
}
|
||||
s.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
out := make(map[string]string, len(s.values))
|
||||
for key, value := range s.values {
|
||||
out[key] = value
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &settingHandlerRepoStub{
|
||||
values: map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyPromoCodeEnabled: "true",
|
||||
service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
|
||||
service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
|
||||
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
|
||||
service.SettingKeyForceEmailOnThirdPartySignup: "true",
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
|
||||
|
||||
handler.GetSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 9.5, data["auth_source_default_email_balance"])
|
||||
require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
|
||||
require.Equal(t, true, data["force_email_on_third_party_signup"])
|
||||
|
||||
subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, subscriptions, 1)
|
||||
}
|
||||
|
||||
func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := &settingHandlerRepoStub{
|
||||
values: map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "false",
|
||||
service.SettingKeyPromoCodeEnabled: "true",
|
||||
service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
|
||||
service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
|
||||
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
|
||||
service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
|
||||
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
|
||||
service.SettingKeyForceEmailOnThirdPartySignup: "true",
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"registration_enabled": true,
|
||||
"promo_code_enabled": true,
|
||||
"auth_source_default_email_balance": 12.75,
|
||||
}
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
|
||||
require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency])
|
||||
require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions])
|
||||
require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup])
|
||||
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 12.75, data["auth_source_default_email_balance"])
|
||||
require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
|
||||
require.Equal(t, true, data["force_email_on_third_party_signup"])
|
||||
}
|
||||
@@ -219,7 +219,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
@@ -262,6 +262,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: subject,
|
||||
},
|
||||
TargetUserID: &user.ID,
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
@@ -287,7 +288,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
|
||||
}
|
||||
|
||||
type completeLinuxDoOAuthRequest struct {
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
|
||||
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
|
||||
@@ -335,11 +338,23 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
|
||||
AdoptDisplayName: req.AdoptDisplayName,
|
||||
AdoptAvatar: req.AdoptAvatar,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
||||
return
|
||||
}
|
||||
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
|
||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -110,3 +121,79 @@ func TestSingleLineStripsWhitespace(t *testing.T) {
|
||||
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
|
||||
require.Equal(t, "", singleLine("\n\t\r"))
|
||||
}
|
||||
|
||||
func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("linuxdo-complete-session").
|
||||
SetIntent("login").
|
||||
SetProviderType("linuxdo").
|
||||
SetProviderKey("linuxdo").
|
||||
SetProviderSubject("linuxdo-subject-1").
|
||||
SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid").
|
||||
SetBrowserSessionKey("linuxdo-browser").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "linuxdo_user",
|
||||
"suggested_display_name": "LinuxDo Display",
|
||||
"suggested_avatar_url": "https://cdn.example/linuxdo.png",
|
||||
}).
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: session.ID,
|
||||
AdoptAvatar: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")})
|
||||
c.Request = req
|
||||
|
||||
handler.CompleteLinuxDoOAuthRegistration(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
responseData := decodeJSONBody(t, recorder)
|
||||
require.NotEmpty(t, responseData["access_token"])
|
||||
|
||||
userEntity, err := client.User.Query().
|
||||
Where(dbuser.EmailEQ(session.ResolvedEmail)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "LinuxDo Display", userEntity.Username)
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("linuxdo"),
|
||||
authidentity.ProviderKeyEQ("linuxdo"),
|
||||
authidentity.ProviderSubjectEQ("linuxdo-subject-1"),
|
||||
).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, userEntity.ID, identity.UserID)
|
||||
require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"])
|
||||
require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"])
|
||||
|
||||
decision, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, decision.IdentityID)
|
||||
require.Equal(t, identity.ID, *decision.IdentityID)
|
||||
require.True(t, decision.AdoptDisplayName)
|
||||
require.True(t, decision.AdoptAvatar)
|
||||
|
||||
consumed, err := client.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.IDEQ(session.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, consumed.ConsumedAt)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -26,6 +33,7 @@ const (
|
||||
type oauthPendingSessionPayload struct {
|
||||
Intent string
|
||||
Identity service.PendingAuthIdentityKey
|
||||
TargetUserID *int64
|
||||
ResolvedEmail string
|
||||
RedirectTo string
|
||||
BrowserSessionKey string
|
||||
@@ -33,6 +41,11 @@ type oauthPendingSessionPayload struct {
|
||||
CompletionResponse map[string]any
|
||||
}
|
||||
|
||||
type oauthAdoptionDecisionRequest struct {
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
|
||||
func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
|
||||
if h == nil || h.authService == nil || h.authService.EntClient() == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||
@@ -125,6 +138,7 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen
|
||||
session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{
|
||||
Intent: strings.TrimSpace(payload.Intent),
|
||||
Identity: payload.Identity,
|
||||
TargetUserID: payload.TargetUserID,
|
||||
ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail),
|
||||
RedirectTo: strings.TrimSpace(payload.RedirectTo),
|
||||
BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey),
|
||||
@@ -175,6 +189,291 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
|
||||
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
|
||||
}
|
||||
|
||||
func (r oauthAdoptionDecisionRequest) hasDecision() bool {
|
||||
return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
|
||||
}
|
||||
|
||||
func (r oauthAdoptionDecisionRequest) toServiceInput(sessionID int64) service.PendingIdentityAdoptionDecisionInput {
|
||||
input := service.PendingIdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: sessionID,
|
||||
}
|
||||
if r.AdoptDisplayName != nil {
|
||||
input.AdoptDisplayName = *r.AdoptDisplayName
|
||||
}
|
||||
if r.AdoptAvatar != nil {
|
||||
input.AdoptAvatar = *r.AdoptAvatar
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) {
|
||||
var req oauthAdoptionDecisionRequest
|
||||
if c == nil || c.Request == nil || c.Request.Body == nil {
|
||||
return req, nil
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return req, nil
|
||||
}
|
||||
return req, err
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func persistPendingOAuthAdoptionDecision(
|
||||
c *gin.Context,
|
||||
svc *service.AuthPendingIdentityService,
|
||||
sessionID int64,
|
||||
req oauthAdoptionDecisionRequest,
|
||||
) error {
|
||||
if !req.hasDecision() {
|
||||
return nil
|
||||
}
|
||||
if svc == nil {
|
||||
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||
}
|
||||
if _, err := svc.UpsertAdoptionDecision(c.Request.Context(), req.toServiceInput(sessionID)); err != nil {
|
||||
return infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cloneOAuthMetadata(values map[string]any) map[string]any {
|
||||
if len(values) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
cloned := make(map[string]any, len(values))
|
||||
for key, value := range values {
|
||||
cloned[key] = value
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func normalizeAdoptedOAuthDisplayName(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if len([]rune(value)) > 100 {
|
||||
value = string([]rune(value)[:100])
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (h *AuthHandler) entClient() *dbent.Client {
|
||||
if h == nil || h.authService == nil {
|
||||
return nil
|
||||
}
|
||||
return h.authService.EntClient()
|
||||
}
|
||||
|
||||
func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
|
||||
c *gin.Context,
|
||||
sessionID int64,
|
||||
req oauthAdoptionDecisionRequest,
|
||||
) (*dbent.IdentityAdoptionDecision, error) {
|
||||
client := h.entClient()
|
||||
if client == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
|
||||
}
|
||||
|
||||
existing, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)).
|
||||
Only(c.Request.Context())
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err)
|
||||
}
|
||||
if existing != nil && !req.hasDecision() {
|
||||
return existing, nil
|
||||
}
|
||||
if existing == nil && !req.hasDecision() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
input := service.PendingIdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: sessionID,
|
||||
}
|
||||
if existing != nil {
|
||||
input.AdoptDisplayName = existing.AdoptDisplayName
|
||||
input.AdoptAvatar = existing.AdoptAvatar
|
||||
input.IdentityID = existing.IdentityID
|
||||
}
|
||||
if req.AdoptDisplayName != nil {
|
||||
input.AdoptDisplayName = *req.AdoptDisplayName
|
||||
}
|
||||
if req.AdoptAvatar != nil {
|
||||
input.AdoptAvatar = *req.AdoptAvatar
|
||||
}
|
||||
|
||||
svc, err := h.pendingIdentityService()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
|
||||
}
|
||||
return decision, nil
|
||||
}
|
||||
|
||||
func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
|
||||
if session == nil {
|
||||
return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
|
||||
}
|
||||
if session.TargetUserID != nil && *session.TargetUserID > 0 {
|
||||
return *session.TargetUserID, nil
|
||||
}
|
||||
email := strings.TrimSpace(session.ResolvedEmail)
|
||||
if email == "" {
|
||||
return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing")
|
||||
}
|
||||
|
||||
userEntity, err := client.User.Query().
|
||||
Where(dbuser.EmailEQ(email)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found")
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return userEntity.ID, nil
|
||||
}
|
||||
|
||||
func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
|
||||
if session == nil {
|
||||
return nil
|
||||
}
|
||||
switch strings.TrimSpace(session.ProviderType) {
|
||||
case "oidc":
|
||||
issuer := strings.TrimSpace(session.ProviderKey)
|
||||
if issuer == "" {
|
||||
issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
|
||||
}
|
||||
if issuer == "" {
|
||||
return nil
|
||||
}
|
||||
return &issuer
|
||||
default:
|
||||
issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
|
||||
if issuer == "" {
|
||||
return nil
|
||||
}
|
||||
return &issuer
|
||||
}
|
||||
}
|
||||
|
||||
func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
|
||||
client := tx.Client()
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
|
||||
authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
|
||||
authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
if identity != nil {
|
||||
if identity.UserID != userID {
|
||||
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||
}
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
create := client.AuthIdentity.Create().
|
||||
SetUserID(userID).
|
||||
SetProviderType(strings.TrimSpace(session.ProviderType)).
|
||||
SetProviderKey(strings.TrimSpace(session.ProviderKey)).
|
||||
SetProviderSubject(strings.TrimSpace(session.ProviderSubject)).
|
||||
SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims))
|
||||
if issuer := oauthIdentityIssuer(session); issuer != nil {
|
||||
create = create.SetIssuer(strings.TrimSpace(*issuer))
|
||||
}
|
||||
return create.Save(ctx)
|
||||
}
|
||||
|
||||
func applyPendingOAuthAdoption(
|
||||
ctx context.Context,
|
||||
client *dbent.Client,
|
||||
session *dbent.PendingAuthSession,
|
||||
decision *dbent.IdentityAdoptionDecision,
|
||||
overrideUserID *int64,
|
||||
) error {
|
||||
if client == nil || session == nil || decision == nil {
|
||||
return nil
|
||||
}
|
||||
if !decision.AdoptDisplayName && !decision.AdoptAvatar {
|
||||
return nil
|
||||
}
|
||||
|
||||
targetUserID := int64(0)
|
||||
if overrideUserID != nil && *overrideUserID > 0 {
|
||||
targetUserID = *overrideUserID
|
||||
} else {
|
||||
resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
targetUserID = resolvedUserID
|
||||
}
|
||||
|
||||
adoptedDisplayName := ""
|
||||
if decision.AdoptDisplayName {
|
||||
adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
|
||||
}
|
||||
adoptedAvatarURL := ""
|
||||
if decision.AdoptAvatar {
|
||||
adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
|
||||
}
|
||||
|
||||
tx, err := client.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
if decision.AdoptDisplayName && adoptedDisplayName != "" {
|
||||
if err := tx.Client().User.UpdateOneID(targetUserID).
|
||||
SetUsername(adoptedDisplayName).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
metadata := cloneOAuthMetadata(identity.Metadata)
|
||||
for key, value := range session.UpstreamIdentityClaims {
|
||||
metadata[key] = value
|
||||
}
|
||||
if decision.AdoptDisplayName && adoptedDisplayName != "" {
|
||||
metadata["display_name"] = adoptedDisplayName
|
||||
}
|
||||
if decision.AdoptAvatar && adoptedAvatarURL != "" {
|
||||
metadata["avatar_url"] = adoptedAvatarURL
|
||||
}
|
||||
|
||||
updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata)
|
||||
if issuer := oauthIdentityIssuer(session); issuer != nil {
|
||||
updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
|
||||
}
|
||||
if _, err := updateIdentity.Save(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if decision.IdentityID == nil || *decision.IdentityID != identity.ID {
|
||||
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
|
||||
SetIdentityID(identity.ID).
|
||||
Save(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
|
||||
if len(payload) == 0 || len(upstream) == 0 {
|
||||
return
|
||||
@@ -206,6 +505,11 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
|
||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
||||
}
|
||||
adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken, err := readOAuthPendingSessionCookie(c)
|
||||
if err != nil || strings.TrimSpace(sessionToken) == "" {
|
||||
@@ -248,9 +552,30 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
|
||||
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
|
||||
|
||||
if pendingSessionWantsInvitation(payload) {
|
||||
if adoptionDecision.hasDecision() {
|
||||
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
_ = decision
|
||||
}
|
||||
response.Success(c, payload)
|
||||
return
|
||||
}
|
||||
if !adoptionDecision.hasDecision() {
|
||||
response.Success(c, payload)
|
||||
return
|
||||
}
|
||||
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, session.TargetUserID); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
|
||||
clearCookies()
|
||||
|
||||
@@ -1,9 +1,30 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestApplySuggestedProfileToCompletionResponse(t *testing.T) {
|
||||
@@ -38,3 +59,439 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *
|
||||
require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
|
||||
require.Equal(t, true, payload["adoption_required"])
|
||||
}
|
||||
|
||||
func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
userEntity, err := client.User.Create().
|
||||
SetEmail("linuxdo-123@linuxdo-connect.invalid").
|
||||
SetUsername("legacy-name").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("pending-session-token").
|
||||
SetIntent("login").
|
||||
SetProviderType("linuxdo").
|
||||
SetProviderKey("linuxdo").
|
||||
SetProviderSubject("123").
|
||||
SetTargetUserID(userEntity.ID).
|
||||
SetResolvedEmail(userEntity.Email).
|
||||
SetBrowserSessionKey("browser-session-key").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "linuxdo_user",
|
||||
"suggested_display_name": "Alice Example",
|
||||
"suggested_avatar_url": "https://cdn.example/alice.png",
|
||||
}).
|
||||
SetLocalFlowState(map[string]any{
|
||||
oauthCompletionResponseKey: map[string]any{
|
||||
"access_token": "access-token",
|
||||
"redirect": "/dashboard",
|
||||
},
|
||||
}).
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
previewRecorder := httptest.NewRecorder()
|
||||
previewCtx, _ := gin.CreateTestContext(previewRecorder)
|
||||
previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
|
||||
previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
|
||||
previewCtx.Request = previewReq
|
||||
|
||||
handler.ExchangePendingOAuthCompletion(previewCtx)
|
||||
|
||||
require.Equal(t, http.StatusOK, previewRecorder.Code)
|
||||
previewData := decodeJSONResponseData(t, previewRecorder)
|
||||
require.Equal(t, "Alice Example", previewData["suggested_display_name"])
|
||||
require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"])
|
||||
require.Equal(t, true, previewData["adoption_required"])
|
||||
|
||||
storedUser, err := client.User.Get(ctx, userEntity.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "legacy-name", storedUser.Username)
|
||||
|
||||
previewSession, err := client.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.IDEQ(session.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, previewSession.ConsumedAt)
|
||||
|
||||
body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
|
||||
finalizeRecorder := httptest.NewRecorder()
|
||||
finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder)
|
||||
finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
|
||||
finalizeReq.Header.Set("Content-Type", "application/json")
|
||||
finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
|
||||
finalizeCtx.Request = finalizeReq
|
||||
|
||||
handler.ExchangePendingOAuthCompletion(finalizeCtx)
|
||||
|
||||
require.Equal(t, http.StatusOK, finalizeRecorder.Code)
|
||||
|
||||
storedUser, err = client.User.Get(ctx, userEntity.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Alice Example", storedUser.Username)
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("linuxdo"),
|
||||
authidentity.ProviderKeyEQ("linuxdo"),
|
||||
authidentity.ProviderSubjectEQ("123"),
|
||||
).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, userEntity.ID, identity.UserID)
|
||||
require.Equal(t, "Alice Example", identity.Metadata["display_name"])
|
||||
require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"])
|
||||
|
||||
decision, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, decision.IdentityID)
|
||||
require.Equal(t, identity.ID, *decision.IdentityID)
|
||||
require.True(t, decision.AdoptDisplayName)
|
||||
require.True(t, decision.AdoptAvatar)
|
||||
|
||||
consumed, err := client.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.IDEQ(session.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, consumed.ConsumedAt)
|
||||
}
|
||||
|
||||
func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
ExpireHour: 1,
|
||||
AccessTokenExpireMinutes: 60,
|
||||
RefreshTokenExpireDays: 7,
|
||||
},
|
||||
Default: config.DefaultConfig{
|
||||
UserBalance: 0,
|
||||
UserConcurrency: 1,
|
||||
},
|
||||
}
|
||||
settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{
|
||||
values: map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
|
||||
},
|
||||
}, cfg)
|
||||
authSvc := service.NewAuthService(
|
||||
client,
|
||||
&oauthPendingFlowUserRepo{client: client},
|
||||
nil,
|
||||
&oauthPendingFlowRefreshTokenCacheStub{},
|
||||
cfg,
|
||||
settingSvc,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
return &AuthHandler{
|
||||
authService: authSvc,
|
||||
settingSvc: settingSvc,
|
||||
}, client
|
||||
}
|
||||
|
||||
func boolSettingValue(v bool) string {
|
||||
if v {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
|
||||
func boolPtr(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
|
||||
type oauthPendingFlowSettingRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
|
||||
return nil, service.ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
|
||||
value, ok := s.values[key]
|
||||
if !ok {
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
result := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
if value, ok := s.values[key]; ok {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
|
||||
result := make(map[string]string, len(s.values))
|
||||
for key, value := range s.values {
|
||||
result[key] = value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type oauthPendingFlowRefreshTokenCacheStub struct{}
|
||||
|
||||
func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
|
||||
return nil, service.ErrRefreshTokenNotFound
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
var envelope struct {
|
||||
Data map[string]any `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope))
|
||||
return envelope.Data
|
||||
}
|
||||
|
||||
func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
var payload map[string]any
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
|
||||
return payload
|
||||
}
|
||||
|
||||
type oauthPendingFlowUserRepo struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error {
|
||||
entity, err := r.client.User.Create().
|
||||
SetEmail(user.Email).
|
||||
SetUsername(user.Username).
|
||||
SetNotes(user.Notes).
|
||||
SetPasswordHash(user.PasswordHash).
|
||||
SetRole(user.Role).
|
||||
SetBalance(user.Balance).
|
||||
SetConcurrency(user.Concurrency).
|
||||
SetStatus(user.Status).
|
||||
SetSignupSource(user.SignupSource).
|
||||
SetNillableLastLoginAt(user.LastLoginAt).
|
||||
SetNillableLastActiveAt(user.LastActiveAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.ID = entity.ID
|
||||
user.CreatedAt = entity.CreatedAt
|
||||
user.UpdatedAt = entity.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
|
||||
entity, err := r.client.User.Get(ctx, id)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return oauthPendingFlowServiceUser(entity), nil
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
|
||||
entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return oauthPendingFlowServiceUser(entity), nil
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) {
|
||||
panic("unexpected GetFirstAdmin call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error {
|
||||
entity, err := r.client.User.UpdateOneID(user.ID).
|
||||
SetEmail(user.Email).
|
||||
SetUsername(user.Username).
|
||||
SetNotes(user.Notes).
|
||||
SetPasswordHash(user.PasswordHash).
|
||||
SetRole(user.Role).
|
||||
SetBalance(user.Balance).
|
||||
SetConcurrency(user.Concurrency).
|
||||
SetStatus(user.Status).
|
||||
SetSignupSource(user.SignupSource).
|
||||
SetNillableLastLoginAt(user.LastLoginAt).
|
||||
SetNillableLastActiveAt(user.LastActiveAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.UpdatedAt = entity.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
|
||||
return r.client.User.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
|
||||
panic("unexpected UpsertUserAvatar call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error {
|
||||
panic("unexpected UpdateBalance call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error {
|
||||
panic("unexpected DeductBalance call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error {
|
||||
panic("unexpected UpdateConcurrency call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error {
|
||||
panic("unexpected AddGroupToAllowedGroups call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
||||
panic("unexpected RemoveGroupFromUserAllowedGroups call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(context.Context, int64, *string) error {
|
||||
panic("unexpected UpdateTotpSecret call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) EnableTotp(context.Context, int64) error {
|
||||
panic("unexpected EnableTotp call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowUserRepo) DisableTotp(context.Context, int64) error {
|
||||
panic("unexpected DisableTotp call")
|
||||
}
|
||||
|
||||
func oauthPendingFlowServiceUser(entity *dbent.User) *service.User {
|
||||
if entity == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.User{
|
||||
ID: entity.ID,
|
||||
Email: entity.Email,
|
||||
Username: entity.Username,
|
||||
Notes: entity.Notes,
|
||||
PasswordHash: entity.PasswordHash,
|
||||
Role: entity.Role,
|
||||
Balance: entity.Balance,
|
||||
Concurrency: entity.Concurrency,
|
||||
Status: entity.Status,
|
||||
SignupSource: entity.SignupSource,
|
||||
LastLoginAt: entity.LastLoginAt,
|
||||
LastActiveAt: entity.LastActiveAt,
|
||||
CreatedAt: entity.CreatedAt,
|
||||
UpdatedAt: entity.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,7 +326,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
)
|
||||
|
||||
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrOAuthInvitationRequired) {
|
||||
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
@@ -371,6 +371,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
ProviderKey: issuer,
|
||||
ProviderSubject: subject,
|
||||
},
|
||||
TargetUserID: &user.ID,
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
@@ -399,7 +400,9 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
}
|
||||
|
||||
type completeOIDCOAuthRequest struct {
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
|
||||
// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating
|
||||
@@ -447,11 +450,23 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
|
||||
AdoptDisplayName: req.AdoptDisplayName,
|
||||
AdoptAvatar: req.AdoptAvatar,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
||||
return
|
||||
}
|
||||
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
|
||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -12,7 +13,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -123,3 +130,80 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
|
||||
E: e,
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
session, err := client.PendingAuthSession.Create().
|
||||
SetSessionToken("oidc-complete-session").
|
||||
SetIntent("login").
|
||||
SetProviderType("oidc").
|
||||
SetProviderKey("https://issuer.example.com").
|
||||
SetProviderSubject("oidc-subject-1").
|
||||
SetResolvedEmail("93a310f4c1944c5bbd2e246df1f76485@oidc-connect.invalid").
|
||||
SetBrowserSessionKey("oidc-browser").
|
||||
SetUpstreamIdentityClaims(map[string]any{
|
||||
"username": "oidc_user",
|
||||
"issuer": "https://issuer.example.com",
|
||||
"suggested_display_name": "OIDC Display",
|
||||
"suggested_avatar_url": "https://cdn.example/oidc.png",
|
||||
}).
|
||||
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: session.ID,
|
||||
AdoptAvatar: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
|
||||
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser")})
|
||||
c.Request = req
|
||||
|
||||
handler.CompleteOIDCOAuthRegistration(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
responseData := decodeJSONBody(t, recorder)
|
||||
require.NotEmpty(t, responseData["access_token"])
|
||||
|
||||
userEntity, err := client.User.Query().
|
||||
Where(dbuser.EmailEQ(session.ResolvedEmail)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "OIDC Display", userEntity.Username)
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("oidc"),
|
||||
authidentity.ProviderKeyEQ("https://issuer.example.com"),
|
||||
authidentity.ProviderSubjectEQ("oidc-subject-1"),
|
||||
).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, userEntity.ID, identity.UserID)
|
||||
require.Equal(t, "OIDC Display", identity.Metadata["display_name"])
|
||||
require.Equal(t, "https://cdn.example/oidc.png", identity.Metadata["avatar_url"])
|
||||
|
||||
decision, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, decision.IdentityID)
|
||||
require.Equal(t, identity.ID, *decision.IdentityID)
|
||||
require.True(t, decision.AdoptDisplayName)
|
||||
require.True(t, decision.AdoptAvatar)
|
||||
|
||||
consumed, err := client.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.IDEQ(session.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, consumed.ConsumedAt)
|
||||
}
|
||||
|
||||
618
backend/internal/handler/auth_wechat_oauth.go
Normal file
618
backend/internal/handler/auth_wechat_oauth.go
Normal file
@@ -0,0 +1,618 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
wechatOAuthCookiePath = "/api/v1/auth/oauth/wechat"
|
||||
wechatOAuthCookieMaxAgeSec = 10 * 60
|
||||
wechatOAuthStateCookieName = "wechat_oauth_state"
|
||||
wechatOAuthRedirectCookieName = "wechat_oauth_redirect"
|
||||
wechatOAuthIntentCookieName = "wechat_oauth_intent"
|
||||
wechatOAuthModeCookieName = "wechat_oauth_mode"
|
||||
wechatOAuthDefaultRedirectTo = "/dashboard"
|
||||
wechatOAuthDefaultFrontendCB = "/auth/wechat/callback"
|
||||
wechatOAuthProviderKey = "wechat-main"
|
||||
|
||||
wechatOAuthIntentLogin = "login"
|
||||
wechatOAuthIntentBind = "bind_current_user"
|
||||
wechatOAuthIntentAdoptEmail = "adopt_existing_user_by_email"
|
||||
)
|
||||
|
||||
var (
|
||||
wechatOAuthAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token"
|
||||
wechatOAuthUserInfoURL = "https://api.weixin.qq.com/sns/userinfo"
|
||||
)
|
||||
|
||||
type wechatOAuthConfig struct {
|
||||
mode string
|
||||
appID string
|
||||
appSecret string
|
||||
authorizeURL string
|
||||
scope string
|
||||
redirectURI string
|
||||
frontendCallback string
|
||||
}
|
||||
|
||||
type wechatOAuthTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
OpenID string `json:"openid"`
|
||||
Scope string `json:"scope"`
|
||||
UnionID string `json:"unionid"`
|
||||
ErrCode int64 `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
|
||||
type wechatOAuthUserInfoResponse struct {
|
||||
OpenID string `json:"openid"`
|
||||
Nickname string `json:"nickname"`
|
||||
HeadImgURL string `json:"headimgurl"`
|
||||
UnionID string `json:"unionid"`
|
||||
ErrCode int64 `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
|
||||
// WeChatOAuthStart starts the WeChat OAuth login flow and stores the short-lived
|
||||
// browser cookies required by the rebuild pending-auth bridge.
|
||||
func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) {
|
||||
cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), c.Query("mode"), c)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
state, err := oauth.GenerateState()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
|
||||
if redirectTo == "" {
|
||||
redirectTo = wechatOAuthDefaultRedirectTo
|
||||
}
|
||||
|
||||
browserSessionKey, err := generateOAuthPendingBrowserSession()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
intent := normalizeWeChatOAuthIntent(c.Query("intent"))
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
wechatSetCookie(c, wechatOAuthStateCookieName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie)
|
||||
wechatSetCookie(c, wechatOAuthRedirectCookieName, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie)
|
||||
wechatSetCookie(c, wechatOAuthIntentCookieName, encodeCookieValue(intent), wechatOAuthCookieMaxAgeSec, secureCookie)
|
||||
wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie)
|
||||
setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
|
||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||
|
||||
authURL, err := buildWeChatAuthorizeURL(cfg, state)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusFound, authURL)
|
||||
}
|
||||
|
||||
// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid,
|
||||
// and stores the result in the unified pending-auth flow.
|
||||
func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
|
||||
frontendCallback := wechatOAuthFrontendCallback()
|
||||
|
||||
if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
|
||||
redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(c.Query("code"))
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
if code == "" || state == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
|
||||
return
|
||||
}
|
||||
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
defer func() {
|
||||
wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
|
||||
wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
|
||||
wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
|
||||
wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
|
||||
}()
|
||||
|
||||
expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName)
|
||||
if err != nil || expectedState == "" || state != expectedState {
|
||||
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
|
||||
return
|
||||
}
|
||||
|
||||
redirectTo, _ := readCookieDecoded(c, wechatOAuthRedirectCookieName)
|
||||
redirectTo = sanitizeFrontendRedirectPath(redirectTo)
|
||||
if redirectTo == "" {
|
||||
redirectTo = wechatOAuthDefaultRedirectTo
|
||||
}
|
||||
browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
|
||||
if strings.TrimSpace(browserSessionKey) == "" {
|
||||
redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
|
||||
return
|
||||
}
|
||||
|
||||
intent, _ := readCookieDecoded(c, wechatOAuthIntentCookieName)
|
||||
mode, err := readCookieDecoded(c, wechatOAuthModeCookieName)
|
||||
if err != nil || strings.TrimSpace(mode) == "" {
|
||||
redirectOAuthError(c, frontendCallback, "invalid_state", "missing oauth mode", "")
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), mode, c)
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return
|
||||
}
|
||||
|
||||
tokenResp, userInfo, err := fetchWeChatOAuthIdentity(c.Request.Context(), cfg, code)
|
||||
if err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "provider_error", "wechat_identity_fetch_failed", singleLine(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID))
|
||||
openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID))
|
||||
providerSubject := firstNonEmpty(unionid, openid)
|
||||
if providerSubject == "" {
|
||||
redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_subject", "")
|
||||
return
|
||||
}
|
||||
|
||||
username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject))
|
||||
email := wechatSyntheticEmail(providerSubject)
|
||||
upstreamClaims := map[string]any{
|
||||
"email": email,
|
||||
"username": username,
|
||||
"subject": providerSubject,
|
||||
"openid": openid,
|
||||
"unionid": unionid,
|
||||
"mode": cfg.mode,
|
||||
"suggested_display_name": strings.TrimSpace(userInfo.Nickname),
|
||||
"suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL),
|
||||
}
|
||||
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
|
||||
if err != nil {
|
||||
if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
|
||||
return
|
||||
}
|
||||
redirectToFrontendCallback(c, frontendCallback)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil); err != nil {
|
||||
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
|
||||
return
|
||||
}
|
||||
redirectToFrontendCallback(c, frontendCallback)
|
||||
}
|
||||
|
||||
type completeWeChatOAuthRequest struct {
|
||||
InvitationCode string `json:"invitation_code" binding:"required"`
|
||||
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
|
||||
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
|
||||
}
|
||||
|
||||
// CompleteWeChatOAuthRegistration completes a pending WeChat OAuth registration by
|
||||
// validating the invitation code and consuming the current pending browser session.
|
||||
// POST /api/v1/auth/oauth/wechat/complete-registration
|
||||
func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
|
||||
var req completeWeChatOAuthRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
secureCookie := isRequestHTTPS(c)
|
||||
sessionToken, err := readOAuthPendingSessionCookie(c)
|
||||
if err != nil {
|
||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
||||
response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
|
||||
return
|
||||
}
|
||||
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
|
||||
if err != nil {
|
||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
||||
response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
|
||||
return
|
||||
}
|
||||
pendingSvc, err := h.pendingIdentityService()
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
|
||||
if err != nil {
|
||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(session.ResolvedEmail)
|
||||
username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
|
||||
if email == "" || username == "" {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
|
||||
AdoptDisplayName: req.AdoptDisplayName,
|
||||
AdoptAvatar: req.AdoptAvatar,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
|
||||
return
|
||||
}
|
||||
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
|
||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
clearOAuthPendingSessionCookie(c, secureCookie)
|
||||
clearOAuthPendingBrowserCookie(c, secureCookie)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": tokenPair.AccessToken,
|
||||
"refresh_token": tokenPair.RefreshToken,
|
||||
"expires_in": tokenPair.ExpiresIn,
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) createWeChatPendingSession(
|
||||
c *gin.Context,
|
||||
intent string,
|
||||
providerSubject string,
|
||||
email string,
|
||||
redirectTo string,
|
||||
browserSessionKey string,
|
||||
upstreamClaims map[string]any,
|
||||
tokenPair *service.TokenPair,
|
||||
authErr error,
|
||||
) error {
|
||||
completionResponse := map[string]any{
|
||||
"redirect": redirectTo,
|
||||
}
|
||||
if authErr != nil {
|
||||
if errors.Is(authErr, service.ErrOAuthInvitationRequired) {
|
||||
completionResponse["error"] = "invitation_required"
|
||||
} else {
|
||||
return authErr
|
||||
}
|
||||
} else if tokenPair != nil {
|
||||
completionResponse["access_token"] = tokenPair.AccessToken
|
||||
completionResponse["refresh_token"] = tokenPair.RefreshToken
|
||||
completionResponse["expires_in"] = tokenPair.ExpiresIn
|
||||
completionResponse["token_type"] = "Bearer"
|
||||
}
|
||||
|
||||
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
|
||||
Intent: intent,
|
||||
Identity: service.PendingAuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: wechatOAuthProviderKey,
|
||||
ProviderSubject: providerSubject,
|
||||
},
|
||||
ResolvedEmail: email,
|
||||
RedirectTo: redirectTo,
|
||||
BrowserSessionKey: browserSessionKey,
|
||||
UpstreamIdentityClaims: upstreamClaims,
|
||||
CompletionResponse: completionResponse,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) {
|
||||
mode, err := resolveWeChatOAuthMode(rawMode, c)
|
||||
if err != nil {
|
||||
return wechatOAuthConfig{}, err
|
||||
}
|
||||
|
||||
apiBaseURL := ""
|
||||
if h != nil && h.settingSvc != nil {
|
||||
settings, err := h.settingSvc.GetAllSettings(ctx)
|
||||
if err == nil && settings != nil {
|
||||
apiBaseURL = strings.TrimSpace(settings.APIBaseURL)
|
||||
}
|
||||
}
|
||||
|
||||
cfg := wechatOAuthConfig{
|
||||
mode: mode,
|
||||
redirectURI: resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback"),
|
||||
frontendCallback: wechatOAuthFrontendCallback(),
|
||||
}
|
||||
|
||||
switch mode {
|
||||
case "mp":
|
||||
cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID"))
|
||||
cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET"))
|
||||
cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize"
|
||||
cfg.scope = "snsapi_userinfo"
|
||||
default:
|
||||
cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID"))
|
||||
cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET"))
|
||||
cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect"
|
||||
cfg.scope = "snsapi_login"
|
||||
}
|
||||
|
||||
if cfg.appID == "" || cfg.appSecret == "" {
|
||||
return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
|
||||
}
|
||||
if strings.TrimSpace(cfg.redirectURI) == "" {
|
||||
return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func wechatOAuthFrontendCallback() string {
|
||||
return firstNonEmpty(strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")), wechatOAuthDefaultFrontendCB)
|
||||
}
|
||||
|
||||
func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) {
|
||||
mode := strings.ToLower(strings.TrimSpace(rawMode))
|
||||
if mode == "" {
|
||||
if isWeChatBrowserRequest(c) {
|
||||
return "mp", nil
|
||||
}
|
||||
return "open", nil
|
||||
}
|
||||
if mode != "open" && mode != "mp" {
|
||||
return "", infraerrors.BadRequest("INVALID_MODE", "wechat oauth mode must be open or mp")
|
||||
}
|
||||
return mode, nil
|
||||
}
|
||||
|
||||
func isWeChatBrowserRequest(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.ToLower(strings.TrimSpace(c.GetHeader("User-Agent"))), "micromessenger")
|
||||
}
|
||||
|
||||
func normalizeWeChatOAuthIntent(raw string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "", "login":
|
||||
return wechatOAuthIntentLogin
|
||||
case "bind", "bind_current_user":
|
||||
return wechatOAuthIntentBind
|
||||
case "adopt", "adopt_existing_user_by_email":
|
||||
return wechatOAuthIntentAdoptEmail
|
||||
default:
|
||||
return wechatOAuthIntentLogin
|
||||
}
|
||||
}
|
||||
|
||||
func buildWeChatAuthorizeURL(cfg wechatOAuthConfig, state string) (string, error) {
|
||||
u, err := url.Parse(cfg.authorizeURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse authorize url: %w", err)
|
||||
}
|
||||
query := u.Query()
|
||||
query.Set("appid", cfg.appID)
|
||||
query.Set("redirect_uri", cfg.redirectURI)
|
||||
query.Set("response_type", "code")
|
||||
query.Set("scope", cfg.scope)
|
||||
query.Set("state", state)
|
||||
u.RawQuery = query.Encode()
|
||||
u.Fragment = "wechat_redirect"
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func resolveWeChatOAuthAbsoluteURL(apiBaseURL string, c *gin.Context, callbackPath string) string {
|
||||
callbackPath = strings.TrimSpace(callbackPath)
|
||||
if callbackPath == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if raw := strings.TrimSpace(apiBaseURL); raw != "" {
|
||||
if parsed, err := url.Parse(raw); err == nil && parsed.Scheme != "" && parsed.Host != "" {
|
||||
basePath := strings.TrimRight(parsed.EscapedPath(), "/")
|
||||
targetPath := callbackPath
|
||||
if basePath != "" && strings.HasSuffix(basePath, "/api/v1") && strings.HasPrefix(callbackPath, "/api/v1") {
|
||||
targetPath = basePath + strings.TrimPrefix(callbackPath, "/api/v1")
|
||||
} else if basePath != "" {
|
||||
targetPath = basePath + callbackPath
|
||||
}
|
||||
return parsed.Scheme + "://" + parsed.Host + targetPath
|
||||
}
|
||||
}
|
||||
|
||||
if c == nil || c.Request == nil {
|
||||
return ""
|
||||
}
|
||||
scheme := "http"
|
||||
if isRequestHTTPS(c) {
|
||||
scheme = "https"
|
||||
}
|
||||
host := strings.TrimSpace(c.Request.Host)
|
||||
if forwardedHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" {
|
||||
host = forwardedHost
|
||||
}
|
||||
if host == "" {
|
||||
return ""
|
||||
}
|
||||
return scheme + "://" + host + callbackPath
|
||||
}
|
||||
|
||||
func fetchWeChatOAuthIdentity(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, *wechatOAuthUserInfoResponse, error) {
|
||||
tokenResp, err := exchangeWeChatOAuthCode(ctx, cfg, code)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
userInfo, err := fetchWeChatUserInfo(ctx, tokenResp)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return tokenResp, userInfo, nil
|
||||
}
|
||||
|
||||
func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, error) {
|
||||
endpoint, err := url.Parse(wechatOAuthAccessTokenURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse wechat access token url: %w", err)
|
||||
}
|
||||
|
||||
query := endpoint.Query()
|
||||
query.Set("appid", cfg.appID)
|
||||
query.Set("secret", cfg.appSecret)
|
||||
query.Set("code", strings.TrimSpace(code))
|
||||
query.Set("grant_type", "authorization_code")
|
||||
endpoint.RawQuery = query.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build wechat access token request: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request wechat access token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read wechat access token response: %w", err)
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return nil, fmt.Errorf("wechat access token status=%d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var tokenResp wechatOAuthTokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("decode wechat access token response: %w", err)
|
||||
}
|
||||
if tokenResp.ErrCode != 0 {
|
||||
return nil, fmt.Errorf("wechat access token error=%d %s", tokenResp.ErrCode, strings.TrimSpace(tokenResp.ErrMsg))
|
||||
}
|
||||
if strings.TrimSpace(tokenResp.AccessToken) == "" {
|
||||
return nil, fmt.Errorf("wechat access token missing access_token")
|
||||
}
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenResponse) (*wechatOAuthUserInfoResponse, error) {
|
||||
if tokenResp == nil {
|
||||
return nil, fmt.Errorf("wechat token response is nil")
|
||||
}
|
||||
|
||||
endpoint, err := url.Parse(wechatOAuthUserInfoURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse wechat userinfo url: %w", err)
|
||||
}
|
||||
query := endpoint.Query()
|
||||
query.Set("access_token", strings.TrimSpace(tokenResp.AccessToken))
|
||||
query.Set("openid", strings.TrimSpace(tokenResp.OpenID))
|
||||
query.Set("lang", "zh_CN")
|
||||
endpoint.RawQuery = query.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build wechat userinfo request: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request wechat userinfo: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read wechat userinfo response: %w", err)
|
||||
}
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
return nil, fmt.Errorf("wechat userinfo status=%d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var userInfo wechatOAuthUserInfoResponse
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return nil, fmt.Errorf("decode wechat userinfo response: %w", err)
|
||||
}
|
||||
if userInfo.ErrCode != 0 {
|
||||
return nil, fmt.Errorf("wechat userinfo error=%d %s", userInfo.ErrCode, strings.TrimSpace(userInfo.ErrMsg))
|
||||
}
|
||||
return &userInfo, nil
|
||||
}
|
||||
|
||||
func wechatSyntheticEmail(subject string) string {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" {
|
||||
return ""
|
||||
}
|
||||
return "wechat-" + subject + service.WeChatConnectSyntheticEmailDomain
|
||||
}
|
||||
|
||||
func wechatFallbackUsername(subject string) string {
|
||||
subject = strings.TrimSpace(subject)
|
||||
if subject == "" {
|
||||
return "wechat_user"
|
||||
}
|
||||
return "wechat_" + truncateFragmentValue(subject)
|
||||
}
|
||||
|
||||
func wechatSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: wechatOAuthCookiePath,
|
||||
MaxAge: maxAgeSec,
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
func wechatClearCookie(c *gin.Context, name string, secure bool) {
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: wechatOAuthCookiePath,
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
411
backend/internal/handler/auth_wechat_oauth_test.go
Normal file
411
backend/internal/handler/auth_wechat_oauth_test.go
Normal file
@@ -0,0 +1,411 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) {
|
||||
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
|
||||
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
|
||||
c.Request.Host = "api.example.com"
|
||||
|
||||
handler := &AuthHandler{}
|
||||
handler.WeChatOAuthStart(c)
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
require.NotEmpty(t, location)
|
||||
require.Contains(t, location, "open.weixin.qq.com")
|
||||
require.Contains(t, location, "appid=wx-open-app")
|
||||
require.Contains(t, location, "scope=snsapi_login")
|
||||
|
||||
cookies := recorder.Result().Cookies()
|
||||
require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName))
|
||||
require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName))
|
||||
require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName))
|
||||
require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName))
|
||||
}
|
||||
|
||||
func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) {
|
||||
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
|
||||
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
|
||||
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
|
||||
|
||||
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
||||
originalUserInfoURL := wechatOAuthUserInfoURL
|
||||
t.Cleanup(func() {
|
||||
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
||||
wechatOAuthUserInfoURL = originalUserInfoURL
|
||||
})
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
|
||||
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer upstream.Close()
|
||||
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
||||
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
||||
|
||||
handler, client := newWeChatOAuthTestHandler(t, false)
|
||||
defer client.Close()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
|
||||
req.Host = "api.example.com"
|
||||
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
|
||||
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
||||
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
|
||||
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
||||
c.Request = req
|
||||
|
||||
handler.WeChatOAuthCallback(c)
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
|
||||
|
||||
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
|
||||
require.NotNil(t, sessionCookie)
|
||||
|
||||
ctx := context.Background()
|
||||
session, err := client.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "wechat", session.ProviderType)
|
||||
require.Equal(t, "wechat-main", session.ProviderKey)
|
||||
require.Equal(t, "union-456", session.ProviderSubject)
|
||||
require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail)
|
||||
require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"])
|
||||
require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"])
|
||||
require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
|
||||
require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"])
|
||||
}
|
||||
|
||||
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) {
|
||||
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
|
||||
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
|
||||
t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback")
|
||||
|
||||
originalAccessTokenURL := wechatOAuthAccessTokenURL
|
||||
originalUserInfoURL := wechatOAuthUserInfoURL
|
||||
t.Cleanup(func() {
|
||||
wechatOAuthAccessTokenURL = originalAccessTokenURL
|
||||
wechatOAuthUserInfoURL = originalUserInfoURL
|
||||
})
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
|
||||
case strings.Contains(r.URL.Path, "/sns/userinfo"):
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer upstream.Close()
|
||||
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
|
||||
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
|
||||
|
||||
handler, client := newWeChatOAuthTestHandler(t, true)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
redeemRepo := repository.NewRedeemCodeRepository(client)
|
||||
require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{
|
||||
Code: "invite-1",
|
||||
Type: service.RedeemTypeInvitation,
|
||||
Status: service.StatusUnused,
|
||||
}))
|
||||
|
||||
callbackRecorder := httptest.NewRecorder()
|
||||
callbackCtx, _ := gin.CreateTestContext(callbackRecorder)
|
||||
callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
|
||||
callbackReq.Host = "api.example.com"
|
||||
callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
|
||||
callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
|
||||
callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
|
||||
callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
|
||||
callbackCtx.Request = callbackReq
|
||||
|
||||
handler.WeChatOAuthCallback(callbackCtx)
|
||||
|
||||
require.Equal(t, http.StatusFound, callbackRecorder.Code)
|
||||
require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location"))
|
||||
|
||||
sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName)
|
||||
require.NotNil(t, sessionCookie)
|
||||
sessionToken := decodeCookieValueForTest(t, sessionCookie.Value)
|
||||
|
||||
pendingSession, err := client.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.SessionTokenEQ(sessionToken)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "invitation_required", pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["error"])
|
||||
|
||||
body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`)
|
||||
completeRecorder := httptest.NewRecorder()
|
||||
completeCtx, _ := gin.CreateTestContext(completeRecorder)
|
||||
completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
|
||||
completeReq.Header.Set("Content-Type", "application/json")
|
||||
completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)})
|
||||
completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")})
|
||||
completeCtx.Request = completeReq
|
||||
|
||||
handler.CompleteWeChatOAuthRegistration(completeCtx)
|
||||
|
||||
require.Equal(t, http.StatusOK, completeRecorder.Code)
|
||||
responseData := decodeJSONBody(t, completeRecorder)
|
||||
require.NotEmpty(t, responseData["access_token"])
|
||||
|
||||
userEntity, err := client.User.Query().
|
||||
Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "WeChat Display", userEntity.Username)
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("wechat"),
|
||||
authidentity.ProviderKeyEQ("wechat-main"),
|
||||
authidentity.ProviderSubjectEQ("union-456"),
|
||||
).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, userEntity.ID, identity.UserID)
|
||||
require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
|
||||
require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
|
||||
|
||||
decision, err := client.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, decision.IdentityID)
|
||||
require.Equal(t, identity.ID, *decision.IdentityID)
|
||||
require.True(t, decision.AdoptDisplayName)
|
||||
require.True(t, decision.AdoptAvatar)
|
||||
|
||||
consumed, err := client.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.IDEQ(pendingSession.ID)).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, consumed.ConsumedAt)
|
||||
}
|
||||
|
||||
func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
|
||||
userRepo := &oauthPendingFlowUserRepo{client: client}
|
||||
redeemRepo := repository.NewRedeemCodeRepository(client)
|
||||
settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{
|
||||
values: map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
|
||||
},
|
||||
}, &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
ExpireHour: 1,
|
||||
AccessTokenExpireMinutes: 60,
|
||||
RefreshTokenExpireDays: 7,
|
||||
},
|
||||
Default: config.DefaultConfig{
|
||||
UserBalance: 0,
|
||||
UserConcurrency: 1,
|
||||
},
|
||||
})
|
||||
|
||||
authSvc := service.NewAuthService(
|
||||
client,
|
||||
userRepo,
|
||||
redeemRepo,
|
||||
&wechatOAuthRefreshTokenCacheStub{},
|
||||
&config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
ExpireHour: 1,
|
||||
AccessTokenExpireMinutes: 60,
|
||||
RefreshTokenExpireDays: 7,
|
||||
},
|
||||
Default: config.DefaultConfig{
|
||||
UserBalance: 0,
|
||||
UserConcurrency: 1,
|
||||
},
|
||||
},
|
||||
settingSvc,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
return &AuthHandler{
|
||||
authService: authSvc,
|
||||
settingSvc: settingSvc,
|
||||
}, client
|
||||
}
|
||||
|
||||
func encodedCookie(name, value string) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: name,
|
||||
Value: encodeCookieValue(value),
|
||||
Path: "/",
|
||||
}
|
||||
}
|
||||
|
||||
func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == name {
|
||||
return cookie
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeCookieValueForTest(t *testing.T, value string) string {
|
||||
t.Helper()
|
||||
raw, err := base64.RawURLEncoding.DecodeString(value)
|
||||
require.NoError(t, err)
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
type wechatOAuthSettingRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
|
||||
return nil, service.ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
|
||||
value, ok := s.values[key]
|
||||
if !ok {
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
result := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
if value, ok := s.values[key]; ok {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
|
||||
result := make(map[string]string, len(s.values))
|
||||
for key, value := range s.values {
|
||||
result[key] = value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type wechatOAuthRefreshTokenCacheStub struct{}
|
||||
|
||||
func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
|
||||
return nil, service.ErrRefreshTokenNotFound
|
||||
}
|
||||
|
||||
func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
@@ -189,6 +189,7 @@ type PublicSettings struct {
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
|
||||
SoraClientEnabled bool `json:"sora_client_enabled"`
|
||||
|
||||
@@ -120,7 +120,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
|
||||
// This allows looking up the correct provider instance before verification.
|
||||
func extractOutTradeNo(rawBody, providerKey string) string {
|
||||
switch providerKey {
|
||||
case payment.TypeEasyPay:
|
||||
case payment.TypeEasyPay, payment.TypeAlipay:
|
||||
values, err := url.ParseQuery(rawBody)
|
||||
if err == nil {
|
||||
return values.Get("out_trade_no")
|
||||
|
||||
@@ -97,3 +97,37 @@ func TestWebhookConstants(t *testing.T) {
|
||||
assert.Equal(t, 200, webhookLogTruncateLen)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractOutTradeNo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
providerKey string
|
||||
rawBody string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "easypay query payload",
|
||||
providerKey: "easypay",
|
||||
rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS",
|
||||
want: "sub2_123",
|
||||
},
|
||||
{
|
||||
name: "alipay query payload",
|
||||
providerKey: "alipay",
|
||||
rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456",
|
||||
want: "sub2_456",
|
||||
},
|
||||
{
|
||||
name: "unknown provider",
|
||||
providerKey: "wxpay",
|
||||
rawBody: "{}",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, extractOutTradeNo(tt.rawBody, tt.providerKey))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,6 +56,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
|
||||
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
|
||||
OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
|
||||
@@ -34,10 +34,16 @@ type ChangePasswordRequest struct {
|
||||
// UpdateProfileRequest represents the update profile request payload
|
||||
type UpdateProfileRequest struct {
|
||||
Username *string `json:"username"`
|
||||
AvatarURL *string `json:"avatar_url"`
|
||||
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
|
||||
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
|
||||
}
|
||||
|
||||
type userProfileResponse struct {
|
||||
dto.User
|
||||
AvatarURL string `json:"avatar_url,omitempty"`
|
||||
}
|
||||
|
||||
// GetProfile handles getting user profile
|
||||
// GET /api/v1/users/me
|
||||
func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
@@ -47,13 +53,13 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||
userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(userData))
|
||||
response.Success(c, userProfileResponseFromService(userData))
|
||||
}
|
||||
|
||||
// ChangePassword handles changing user password
|
||||
@@ -101,6 +107,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
|
||||
svcReq := service.UpdateProfileRequest{
|
||||
Username: req.Username,
|
||||
AvatarURL: req.AvatarURL,
|
||||
BalanceNotifyEnabled: req.BalanceNotifyEnabled,
|
||||
BalanceNotifyThreshold: req.BalanceNotifyThreshold,
|
||||
}
|
||||
@@ -110,7 +117,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(updatedUser))
|
||||
response.Success(c, userProfileResponseFromService(updatedUser))
|
||||
}
|
||||
|
||||
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
|
||||
@@ -176,7 +183,7 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(updatedUser))
|
||||
response.Success(c, userProfileResponseFromService(updatedUser))
|
||||
}
|
||||
|
||||
// RemoveNotifyEmailRequest represents the request to remove a notify email
|
||||
@@ -212,7 +219,7 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(updatedUser))
|
||||
response.Success(c, userProfileResponseFromService(updatedUser))
|
||||
}
|
||||
|
||||
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
|
||||
@@ -248,5 +255,16 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(updatedUser))
|
||||
response.Success(c, userProfileResponseFromService(updatedUser))
|
||||
}
|
||||
|
||||
func userProfileResponseFromService(user *service.User) userProfileResponse {
|
||||
base := dto.UserFromService(user)
|
||||
if base == nil {
|
||||
return userProfileResponse{}
|
||||
}
|
||||
return userProfileResponse{
|
||||
User: *base,
|
||||
AvatarURL: user.AvatarURL,
|
||||
}
|
||||
}
|
||||
|
||||
136
backend/internal/handler/user_handler_test.go
Normal file
136
backend/internal/handler/user_handler_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type userHandlerRepoStub struct {
|
||||
user *service.User
|
||||
}
|
||||
|
||||
func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil }
|
||||
func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) {
|
||||
cloned := *s.user
|
||||
return &cloned, nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) {
|
||||
cloned := *s.user
|
||||
return &cloned, nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
|
||||
cloned := *s.user
|
||||
return &cloned, nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error {
|
||||
cloned := *user
|
||||
s.user = &cloned
|
||||
return nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil }
|
||||
func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
|
||||
if s.user == nil || s.user.AvatarURL == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return &service.UserAvatar{
|
||||
StorageProvider: s.user.AvatarSource,
|
||||
URL: s.user.AvatarURL,
|
||||
ContentType: s.user.AvatarMIME,
|
||||
ByteSize: s.user.AvatarByteSize,
|
||||
SHA256: s.user.AvatarSHA256,
|
||||
}, nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
|
||||
s.user.AvatarURL = input.URL
|
||||
s.user.AvatarSource = input.StorageProvider
|
||||
s.user.AvatarMIME = input.ContentType
|
||||
s.user.AvatarByteSize = input.ByteSize
|
||||
s.user.AvatarSHA256 = input.SHA256
|
||||
return &service.UserAvatar{
|
||||
StorageProvider: input.StorageProvider,
|
||||
URL: input.URL,
|
||||
ContentType: input.ContentType,
|
||||
ByteSize: input.ByteSize,
|
||||
SHA256: input.SHA256,
|
||||
}, nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error {
|
||||
s.user.AvatarURL = ""
|
||||
s.user.AvatarSource = ""
|
||||
s.user.AvatarMIME = ""
|
||||
s.user.AvatarByteSize = 0
|
||||
s.user.AvatarSHA256 = ""
|
||||
return nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||
func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||
func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
|
||||
func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
|
||||
return nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
||||
return nil
|
||||
}
|
||||
func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
|
||||
func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil }
|
||||
func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil }
|
||||
|
||||
func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
repo := &userHandlerRepoStub{
|
||||
user: &service.User{
|
||||
ID: 11,
|
||||
Email: "handler-avatar@example.com",
|
||||
Username: "handler-avatar",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}
|
||||
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
|
||||
|
||||
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
|
||||
|
||||
handler.UpdateProfile(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
Username string `json:"username"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL)
|
||||
require.Equal(t, "handler-avatar", resp.Data.Username)
|
||||
}
|
||||
Reference in New Issue
Block a user