feat: add profile auth identity binding flow

This commit is contained in:
IanShaw027
2026-04-20 18:28:44 +08:00
parent 13d9780df4
commit c6d8592484
31 changed files with 3419 additions and 239 deletions

View File

@@ -7,13 +7,13 @@ import (
"encoding/base64"
"encoding/hex"
"fmt"
"log/slog"
"net/url"
"strings"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"log/slog"
"net/url"
"sort"
"strings"
"time"
)
var (
@@ -24,6 +24,8 @@ var (
ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL")
ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller")
ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
ErrIdentityProviderInvalid = infraerrors.BadRequest("IDENTITY_PROVIDER_INVALID", "identity provider is invalid")
ErrIdentityRedirectInvalid = infraerrors.BadRequest("IDENTITY_REDIRECT_INVALID", "identity redirect path is invalid")
)
const (
@@ -33,6 +35,8 @@ const (
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5
notifyCodeUserRateWindow = 10 * time.Minute
defaultUserIdentityRedirect = "/settings/profile"
)
// UserListFilters contains all filter options for listing users
@@ -71,6 +75,7 @@ type UserRepository interface {
AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error
// RemoveGroupFromUserAllowedGroups 移除单个用户的指定分组权限
RemoveGroupFromUserAllowedGroups(ctx context.Context, userID int64, groupID int64) error
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
// TOTP 双因素认证
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
@@ -78,6 +83,50 @@ type UserRepository interface {
DisableTotp(ctx context.Context, userID int64) error
}
type UserAuthIdentityRecord struct {
ProviderType string
ProviderKey string
ProviderSubject string
VerifiedAt *time.Time
Issuer *string
Metadata map[string]any
CreatedAt time.Time
UpdatedAt time.Time
}
type UserIdentitySummary struct {
Provider string `json:"provider"`
Bound bool `json:"bound"`
BoundCount int `json:"bound_count"`
DisplayName string `json:"display_name,omitempty"`
SubjectHint string `json:"subject_hint,omitempty"`
ProviderKey string `json:"provider_key,omitempty"`
VerifiedAt *time.Time `json:"verified_at,omitempty"`
BindStartPath string `json:"bind_start_path,omitempty"`
CanBind bool `json:"can_bind"`
CanUnbind bool `json:"can_unbind"`
Note string `json:"note,omitempty"`
}
type UserIdentitySummarySet struct {
Email UserIdentitySummary `json:"email"`
LinuxDo UserIdentitySummary `json:"linuxdo"`
OIDC UserIdentitySummary `json:"oidc"`
WeChat UserIdentitySummary `json:"wechat"`
}
type StartUserIdentityBindingRequest struct {
Provider string
RedirectTo string
}
type StartUserIdentityBindingResult struct {
Provider string `json:"provider"`
AuthorizeURL string `json:"authorize_url"`
Method string `json:"method"`
UseBrowserRedirect bool `json:"use_browser_redirect"`
}
// UpdateProfileRequest 更新用户资料请求
type UpdateProfileRequest struct {
Email *string `json:"email"`
@@ -106,6 +155,10 @@ type UpsertUserAvatarInput struct {
SHA256 string
}
type userAuthIdentityReader interface {
ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error)
}
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
@@ -151,6 +204,47 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro
return user, nil
}
func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID int64, user *User) (UserIdentitySummarySet, error) {
if user == nil {
var err error
user, err = s.userRepo.GetByID(ctx, userID)
if err != nil {
return UserIdentitySummarySet{}, fmt.Errorf("get user: %w", err)
}
}
records, err := s.listUserAuthIdentities(ctx, userID)
if err != nil {
return UserIdentitySummarySet{}, err
}
return UserIdentitySummarySet{
Email: s.buildEmailIdentitySummary(user),
LinuxDo: s.buildProviderIdentitySummary("linuxdo", records),
OIDC: s.buildProviderIdentitySummary("oidc", records),
WeChat: s.buildProviderIdentitySummary("wechat", records),
}, nil
}
func (s *UserService) PrepareIdentityBindingStart(_ context.Context, req StartUserIdentityBindingRequest) (*StartUserIdentityBindingResult, error) {
provider := normalizeUserIdentityProvider(req.Provider)
if provider == "" {
return nil, ErrIdentityProviderInvalid
}
authorizeURL, err := buildUserIdentityBindAuthorizeURL(provider, req.RedirectTo)
if err != nil {
return nil, err
}
return &StartUserIdentityBindingResult{
Provider: provider,
AuthorizeURL: authorizeURL,
Method: "GET",
UseBrowserRedirect: true,
}, nil
}
// UpdateProfile 更新用户资料
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
@@ -303,6 +397,234 @@ func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
}, nil
}
func (s *UserService) buildEmailIdentitySummary(user *User) UserIdentitySummary {
summary := UserIdentitySummary{
Provider: "email",
CanBind: false,
CanUnbind: false,
Note: "Primary account email is managed from the profile form.",
}
if user == nil {
return summary
}
email := strings.TrimSpace(user.Email)
if email == "" || isReservedEmail(email) {
return summary
}
summary.Bound = true
summary.BoundCount = 1
summary.DisplayName = email
summary.SubjectHint = maskEmailIdentity(email)
summary.ProviderKey = "email"
return summary
}
func (s *UserService) buildProviderIdentitySummary(provider string, records []UserAuthIdentityRecord) UserIdentitySummary {
summary := UserIdentitySummary{
Provider: provider,
CanUnbind: false,
}
filtered := filterUserAuthIdentities(records, provider)
if len(filtered) == 0 {
summary.CanBind = true
bindStartPath, err := buildUserIdentityBindAuthorizeURL(provider, "")
if err == nil {
summary.BindStartPath = bindStartPath
}
return summary
}
primary := selectPrimaryUserAuthIdentity(filtered)
summary.Bound = true
summary.BoundCount = len(filtered)
summary.DisplayName = userAuthIdentityDisplayName(primary)
summary.SubjectHint = maskOpaqueIdentity(primary.ProviderSubject)
summary.ProviderKey = strings.TrimSpace(primary.ProviderKey)
summary.VerifiedAt = primary.VerifiedAt
summary.Note = "Unbind is not available yet."
return summary
}
func (s *UserService) listUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
if userID <= 0 || s == nil || s.userRepo == nil {
return nil, nil
}
return s.userRepo.ListUserAuthIdentities(ctx, userID)
}
func buildUserIdentityBindAuthorizeURL(provider, redirectTo string) (string, error) {
provider = normalizeUserIdentityProvider(provider)
if provider == "" || provider == "email" {
return "", ErrIdentityProviderInvalid
}
redirectTo, err := normalizeUserIdentityRedirect(redirectTo)
if err != nil {
return "", err
}
path := ""
switch provider {
case "linuxdo":
path = "/api/v1/auth/oauth/linuxdo/start"
case "oidc":
path = "/api/v1/auth/oauth/oidc/start"
case "wechat":
path = "/api/v1/auth/oauth/wechat/start"
default:
return "", ErrIdentityProviderInvalid
}
query := url.Values{}
query.Set("redirect", redirectTo)
query.Set("intent", "bind_current_user")
return path + "?" + query.Encode(), nil
}
func normalizeUserIdentityProvider(provider string) string {
switch strings.ToLower(strings.TrimSpace(provider)) {
case "linuxdo":
return "linuxdo"
case "oidc":
return "oidc"
case "wechat":
return "wechat"
case "email":
return "email"
default:
return ""
}
}
func normalizeUserIdentityRedirect(raw string) (string, error) {
redirect := strings.TrimSpace(raw)
if redirect == "" {
return defaultUserIdentityRedirect, nil
}
if len(redirect) > 2048 || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
return "", ErrIdentityRedirectInvalid
}
return redirect, nil
}
func filterUserAuthIdentities(records []UserAuthIdentityRecord, provider string) []UserAuthIdentityRecord {
if len(records) == 0 {
return nil
}
filtered := make([]UserAuthIdentityRecord, 0, len(records))
for _, record := range records {
if strings.EqualFold(strings.TrimSpace(record.ProviderType), provider) {
filtered = append(filtered, record)
}
}
return filtered
}
func selectPrimaryUserAuthIdentity(records []UserAuthIdentityRecord) UserAuthIdentityRecord {
if len(records) == 0 {
return UserAuthIdentityRecord{}
}
sort.SliceStable(records, func(i, j int) bool {
left := userAuthIdentitySortTime(records[i])
right := userAuthIdentitySortTime(records[j])
if !left.Equal(right) {
return left.After(right)
}
return records[i].ProviderKey < records[j].ProviderKey
})
return records[0]
}
func userAuthIdentitySortTime(record UserAuthIdentityRecord) time.Time {
if record.VerifiedAt != nil && !record.VerifiedAt.IsZero() {
return record.VerifiedAt.UTC()
}
if !record.UpdatedAt.IsZero() {
return record.UpdatedAt.UTC()
}
if !record.CreatedAt.IsZero() {
return record.CreatedAt.UTC()
}
return time.Time{}
}
func userAuthIdentityDisplayName(record UserAuthIdentityRecord) string {
if displayName := firstStringIdentityValue(record.Metadata,
"display_name",
"suggested_display_name",
"username",
"name",
"nickname",
"email",
); displayName != "" {
return displayName
}
if subject := strings.TrimSpace(record.ProviderSubject); subject != "" {
return subject
}
return strings.TrimSpace(record.ProviderType)
}
func firstStringIdentityValue(values map[string]any, keys ...string) string {
for _, key := range keys {
raw, ok := values[key]
if !ok {
continue
}
switch value := raw.(type) {
case string:
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
case fmt.Stringer:
if trimmed := strings.TrimSpace(value.String()); trimmed != "" {
return trimmed
}
}
}
return ""
}
func maskEmailIdentity(email string) string {
local, domain, ok := strings.Cut(strings.TrimSpace(email), "@")
if !ok || local == "" || domain == "" {
return maskOpaqueIdentity(email)
}
runes := []rune(local)
if len(runes) == 1 {
return string(runes[0]) + "***@" + domain
}
return string(runes[0]) + "***" + string(runes[len(runes)-1]) + "@" + domain
}
func maskOpaqueIdentity(value string) string {
value = strings.TrimSpace(value)
runes := []rune(value)
switch {
case len(runes) == 0:
return ""
case len(runes) <= 4:
return string(runes[0]) + "***"
case len(runes) <= 8:
return string(runes[:2]) + "***" + string(runes[len(runes)-1:])
default:
return string(runes[:3]) + "***" + string(runes[len(runes)-3:])
}
}
func cloneAnyMap(values map[string]any) map[string]any {
if len(values) == 0 {
return map[string]any{}
}
cloned := make(map[string]any, len(values))
for key, value := range values {
cloned[key] = value
}
return cloned
}
// ChangePassword 修改密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {