feat: add profile auth identity binding flow
This commit is contained in:
@@ -7,13 +7,13 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -24,6 +24,8 @@ var (
|
||||
ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL")
|
||||
ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller")
|
||||
ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
|
||||
ErrIdentityProviderInvalid = infraerrors.BadRequest("IDENTITY_PROVIDER_INVALID", "identity provider is invalid")
|
||||
ErrIdentityRedirectInvalid = infraerrors.BadRequest("IDENTITY_REDIRECT_INVALID", "identity redirect path is invalid")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,6 +35,8 @@ const (
|
||||
// User-level rate limiting for notify email verification codes
|
||||
notifyCodeUserRateLimit = 5
|
||||
notifyCodeUserRateWindow = 10 * time.Minute
|
||||
|
||||
defaultUserIdentityRedirect = "/settings/profile"
|
||||
)
|
||||
|
||||
// UserListFilters contains all filter options for listing users
|
||||
@@ -71,6 +75,7 @@ type UserRepository interface {
|
||||
AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error
|
||||
// RemoveGroupFromUserAllowedGroups 移除单个用户的指定分组权限
|
||||
RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error
|
||||
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
|
||||
|
||||
// TOTP 双因素认证
|
||||
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
|
||||
@@ -78,6 +83,50 @@ type UserRepository interface {
|
||||
DisableTotp(ctx context.Context, userID int64) error
|
||||
}
|
||||
|
||||
type UserAuthIdentityRecord struct {
|
||||
ProviderType string
|
||||
ProviderKey string
|
||||
ProviderSubject string
|
||||
VerifiedAt *time.Time
|
||||
Issuer *string
|
||||
Metadata map[string]any
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type UserIdentitySummary struct {
|
||||
Provider string `json:"provider"`
|
||||
Bound bool `json:"bound"`
|
||||
BoundCount int `json:"bound_count"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
SubjectHint string `json:"subject_hint,omitempty"`
|
||||
ProviderKey string `json:"provider_key,omitempty"`
|
||||
VerifiedAt *time.Time `json:"verified_at,omitempty"`
|
||||
BindStartPath string `json:"bind_start_path,omitempty"`
|
||||
CanBind bool `json:"can_bind"`
|
||||
CanUnbind bool `json:"can_unbind"`
|
||||
Note string `json:"note,omitempty"`
|
||||
}
|
||||
|
||||
type UserIdentitySummarySet struct {
|
||||
Email UserIdentitySummary `json:"email"`
|
||||
LinuxDo UserIdentitySummary `json:"linuxdo"`
|
||||
OIDC UserIdentitySummary `json:"oidc"`
|
||||
WeChat UserIdentitySummary `json:"wechat"`
|
||||
}
|
||||
|
||||
type StartUserIdentityBindingRequest struct {
|
||||
Provider string
|
||||
RedirectTo string
|
||||
}
|
||||
|
||||
type StartUserIdentityBindingResult struct {
|
||||
Provider string `json:"provider"`
|
||||
AuthorizeURL string `json:"authorize_url"`
|
||||
Method string `json:"method"`
|
||||
UseBrowserRedirect bool `json:"use_browser_redirect"`
|
||||
}
|
||||
|
||||
// UpdateProfileRequest 更新用户资料请求
|
||||
type UpdateProfileRequest struct {
|
||||
Email *string `json:"email"`
|
||||
@@ -106,6 +155,10 @@ type UpsertUserAvatarInput struct {
|
||||
SHA256 string
|
||||
}
|
||||
|
||||
type userAuthIdentityReader interface {
|
||||
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
|
||||
}
|
||||
|
||||
// ChangePasswordRequest 修改密码请求
|
||||
type ChangePasswordRequest struct {
|
||||
CurrentPassword string `json:"current_password"`
|
||||
@@ -151,6 +204,47 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID int64, user *User) (UserIdentitySummarySet, error) {
|
||||
if user == nil {
|
||||
var err error
|
||||
user, err = s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return UserIdentitySummarySet{}, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
records, err := s.listUserAuthIdentities(ctx, userID)
|
||||
if err != nil {
|
||||
return UserIdentitySummarySet{}, err
|
||||
}
|
||||
|
||||
return UserIdentitySummarySet{
|
||||
Email: s.buildEmailIdentitySummary(user),
|
||||
LinuxDo: s.buildProviderIdentitySummary("linuxdo", records),
|
||||
OIDC: s.buildProviderIdentitySummary("oidc", records),
|
||||
WeChat: s.buildProviderIdentitySummary("wechat", records),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUserIdentityBindingRequest) (*StartUserIdentityBindingResult, error) {
|
||||
provider := normalizeUserIdentityProvider(req.Provider)
|
||||
if provider == "" {
|
||||
return nil, ErrIdentityProviderInvalid
|
||||
}
|
||||
|
||||
authorizeURL, err := buildUserIdentityBindAuthorizeURL(provider, req.RedirectTo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &StartUserIdentityBindingResult{
|
||||
Provider: provider,
|
||||
AuthorizeURL: authorizeURL,
|
||||
Method: "GET",
|
||||
UseBrowserRedirect: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateProfile 更新用户资料
|
||||
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
@@ -303,6 +397,234 @@ func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *UserService) buildEmailIdentitySummary(user *User) UserIdentitySummary {
|
||||
summary := UserIdentitySummary{
|
||||
Provider: "email",
|
||||
CanBind: false,
|
||||
CanUnbind: false,
|
||||
Note: "Primary account email is managed from the profile form.",
|
||||
}
|
||||
if user == nil {
|
||||
return summary
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(user.Email)
|
||||
if email == "" || isReservedEmail(email) {
|
||||
return summary
|
||||
}
|
||||
|
||||
summary.Bound = true
|
||||
summary.BoundCount = 1
|
||||
summary.DisplayName = email
|
||||
summary.SubjectHint = maskEmailIdentity(email)
|
||||
summary.ProviderKey = "email"
|
||||
return summary
|
||||
}
|
||||
|
||||
func (s *UserService) buildProviderIdentitySummary(provider string, records []UserAuthIdentityRecord) UserIdentitySummary {
|
||||
summary := UserIdentitySummary{
|
||||
Provider: provider,
|
||||
CanUnbind: false,
|
||||
}
|
||||
filtered := filterUserAuthIdentities(records, provider)
|
||||
if len(filtered) == 0 {
|
||||
summary.CanBind = true
|
||||
bindStartPath, err := buildUserIdentityBindAuthorizeURL(provider, "")
|
||||
if err == nil {
|
||||
summary.BindStartPath = bindStartPath
|
||||
}
|
||||
return summary
|
||||
}
|
||||
|
||||
primary := selectPrimaryUserAuthIdentity(filtered)
|
||||
summary.Bound = true
|
||||
summary.BoundCount = len(filtered)
|
||||
summary.DisplayName = userAuthIdentityDisplayName(primary)
|
||||
summary.SubjectHint = maskOpaqueIdentity(primary.ProviderSubject)
|
||||
summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
|
||||
summary.VerifiedAt = primary.VerifiedAt
|
||||
summary.Note = "Unbind is not available yet."
|
||||
return summary
|
||||
}
|
||||
|
||||
func (s *UserService) listUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
|
||||
if userID <= 0 || s == nil || s.userRepo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.userRepo.ListUserAuthIdentities(ctx, userID)
|
||||
}
|
||||
|
||||
func buildUserIdentityBindAuthorizeURL(provider, redirectTo string) (string, error) {
|
||||
provider = normalizeUserIdentityProvider(provider)
|
||||
if provider == "" || provider == "email" {
|
||||
return "", ErrIdentityProviderInvalid
|
||||
}
|
||||
|
||||
redirectTo, err := normalizeUserIdentityRedirect(redirectTo)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
path := ""
|
||||
switch provider {
|
||||
case "linuxdo":
|
||||
path = "/api/v1/auth/oauth/linuxdo/start"
|
||||
case "oidc":
|
||||
path = "/api/v1/auth/oauth/oidc/start"
|
||||
case "wechat":
|
||||
path = "/api/v1/auth/oauth/wechat/start"
|
||||
default:
|
||||
return "", ErrIdentityProviderInvalid
|
||||
}
|
||||
|
||||
query := url.Values{}
|
||||
query.Set("redirect", redirectTo)
|
||||
query.Set("intent", "bind_current_user")
|
||||
return path + "?" + query.Encode(), nil
|
||||
}
|
||||
|
||||
func normalizeUserIdentityProvider(provider string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "linuxdo":
|
||||
return "linuxdo"
|
||||
case "oidc":
|
||||
return "oidc"
|
||||
case "wechat":
|
||||
return "wechat"
|
||||
case "email":
|
||||
return "email"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeUserIdentityRedirect(raw string) (string, error) {
|
||||
redirect := strings.TrimSpace(raw)
|
||||
if redirect == "" {
|
||||
return defaultUserIdentityRedirect, nil
|
||||
}
|
||||
if len(redirect) > 2048 || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
|
||||
return "", ErrIdentityRedirectInvalid
|
||||
}
|
||||
return redirect, nil
|
||||
}
|
||||
|
||||
func filterUserAuthIdentities(records []UserAuthIdentityRecord, provider string) []UserAuthIdentityRecord {
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
filtered := make([]UserAuthIdentityRecord, 0, len(records))
|
||||
for _, record := range records {
|
||||
if strings.EqualFold(strings.TrimSpace(record.ProviderType), provider) {
|
||||
filtered = append(filtered, record)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func selectPrimaryUserAuthIdentity(records []UserAuthIdentityRecord) UserAuthIdentityRecord {
|
||||
if len(records) == 0 {
|
||||
return UserAuthIdentityRecord{}
|
||||
}
|
||||
sort.SliceStable(records, func(i, j int) bool {
|
||||
left := userAuthIdentitySortTime(records[i])
|
||||
right := userAuthIdentitySortTime(records[j])
|
||||
if !left.Equal(right) {
|
||||
return left.After(right)
|
||||
}
|
||||
return records[i].ProviderKey < records[j].ProviderKey
|
||||
})
|
||||
return records[0]
|
||||
}
|
||||
|
||||
func userAuthIdentitySortTime(record UserAuthIdentityRecord) time.Time {
|
||||
if record.VerifiedAt != nil && !record.VerifiedAt.IsZero() {
|
||||
return record.VerifiedAt.UTC()
|
||||
}
|
||||
if !record.UpdatedAt.IsZero() {
|
||||
return record.UpdatedAt.UTC()
|
||||
}
|
||||
if !record.CreatedAt.IsZero() {
|
||||
return record.CreatedAt.UTC()
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func userAuthIdentityDisplayName(record UserAuthIdentityRecord) string {
|
||||
if displayName := firstStringIdentityValue(record.Metadata,
|
||||
"display_name",
|
||||
"suggested_display_name",
|
||||
"username",
|
||||
"name",
|
||||
"nickname",
|
||||
"email",
|
||||
); displayName != "" {
|
||||
return displayName
|
||||
}
|
||||
if subject := strings.TrimSpace(record.ProviderSubject); subject != "" {
|
||||
return subject
|
||||
}
|
||||
return strings.TrimSpace(record.ProviderType)
|
||||
}
|
||||
|
||||
func firstStringIdentityValue(values map[string]any, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
raw, ok := values[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch value := raw.(type) {
|
||||
case string:
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
case fmt.Stringer:
|
||||
if trimmed := strings.TrimSpace(value.String()); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func maskEmailIdentity(email string) string {
|
||||
local, domain, ok := strings.Cut(strings.TrimSpace(email), "@")
|
||||
if !ok || local == "" || domain == "" {
|
||||
return maskOpaqueIdentity(email)
|
||||
}
|
||||
runes := []rune(local)
|
||||
if len(runes) == 1 {
|
||||
return string(runes[0]) + "***@" + domain
|
||||
}
|
||||
return string(runes[0]) + "***" + string(runes[len(runes)-1]) + "@" + domain
|
||||
}
|
||||
|
||||
func maskOpaqueIdentity(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
runes := []rune(value)
|
||||
switch {
|
||||
case len(runes) == 0:
|
||||
return ""
|
||||
case len(runes) <= 4:
|
||||
return string(runes[0]) + "***"
|
||||
case len(runes) <= 8:
|
||||
return string(runes[:2]) + "***" + string(runes[len(runes)-1:])
|
||||
default:
|
||||
return string(runes[:3]) + "***" + string(runes[len(runes)-3:])
|
||||
}
|
||||
}
|
||||
|
||||
func cloneAnyMap(values map[string]any) map[string]any {
|
||||
if len(values) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
cloned := make(map[string]any, len(values))
|
||||
for key, value := range values {
|
||||
cloned[key] = value
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
// Security: Increments TokenVersion to invalidate all existing JWT tokens
|
||||
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
|
||||
|
||||
Reference in New Issue
Block a user