feat: add admin auth identity repair binding
This commit is contained in:
@@ -25,6 +25,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
|||||||
router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary)
|
router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary)
|
||||||
router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports)
|
router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports)
|
||||||
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
|
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
|
||||||
|
router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity)
|
||||||
router.POST("/api/v1/admin/users", userHandler.Create)
|
router.POST("/api/v1/admin/users", userHandler.Create)
|
||||||
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
|
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
|
||||||
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
|
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
|
||||||
@@ -87,8 +88,26 @@ func TestUserHandlerEndpoints(t *testing.T) {
|
|||||||
router.ServeHTTP(rec, req)
|
router.ServeHTTP(rec, req)
|
||||||
require.Equal(t, http.StatusOK, rec.Code)
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
bindBody := map[string]any{
|
||||||
|
"provider_type": "wechat",
|
||||||
|
"provider_key": "wechat-main",
|
||||||
|
"provider_subject": "union-123",
|
||||||
|
"metadata": map[string]any{"source": "admin-repair"},
|
||||||
|
"channel": map[string]any{
|
||||||
|
"channel": "open",
|
||||||
|
"channel_app_id": "wx-open",
|
||||||
|
"channel_subject": "openid-123",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(bindBody)
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
|
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
|
||||||
body, _ := json.Marshal(createBody)
|
body, _ = json.Marshal(createBody)
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
|
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -125,6 +144,33 @@ func TestUserHandlerEndpoints(t *testing.T) {
|
|||||||
require.Equal(t, http.StatusOK, rec.Code)
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUserHandlerBindAuthIdentityMapsRequest(t *testing.T) {
|
||||||
|
router, adminSvc := setupAdminRouter()
|
||||||
|
|
||||||
|
body, err := json.Marshal(map[string]any{
|
||||||
|
"provider_type": "oidc",
|
||||||
|
"provider_key": "https://issuer.example",
|
||||||
|
"provider_subject": "subject-123",
|
||||||
|
"issuer": "https://issuer.example",
|
||||||
|
"metadata": map[string]any{"report_id": 12},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/9/auth-identities", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
require.Equal(t, int64(9), adminSvc.boundAuthIdentityFor)
|
||||||
|
require.NotNil(t, adminSvc.boundAuthIdentity)
|
||||||
|
require.Equal(t, "oidc", adminSvc.boundAuthIdentity.ProviderType)
|
||||||
|
require.Equal(t, "https://issuer.example", adminSvc.boundAuthIdentity.ProviderKey)
|
||||||
|
require.Equal(t, "subject-123", adminSvc.boundAuthIdentity.ProviderSubject)
|
||||||
|
require.Nil(t, adminSvc.boundAuthIdentity.Channel)
|
||||||
|
require.Equal(t, float64(12), adminSvc.boundAuthIdentity.Metadata["report_id"])
|
||||||
|
}
|
||||||
|
|
||||||
func TestGroupHandlerEndpoints(t *testing.T) {
|
func TestGroupHandlerEndpoints(t *testing.T) {
|
||||||
router, _ := setupAdminRouter()
|
router, _ := setupAdminRouter()
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ type stubAdminService struct {
|
|||||||
proxyCounts []service.ProxyWithAccountCount
|
proxyCounts []service.ProxyWithAccountCount
|
||||||
redeems []service.RedeemCode
|
redeems []service.RedeemCode
|
||||||
migrationReports []service.AuthIdentityMigrationReport
|
migrationReports []service.AuthIdentityMigrationReport
|
||||||
|
boundAuthIdentity *service.AdminBindAuthIdentityInput
|
||||||
|
boundAuthIdentityFor int64
|
||||||
createdAccounts []*service.CreateAccountInput
|
createdAccounts []*service.CreateAccountInput
|
||||||
createdProxies []*service.CreateProxyInput
|
createdProxies []*service.CreateProxyInput
|
||||||
updatedProxyIDs []int64
|
updatedProxyIDs []int64
|
||||||
@@ -201,6 +203,52 @@ func (s *stubAdminService) GetAuthIdentityMigrationReportSummary(ctx context.Con
|
|||||||
return summary, nil
|
return summary, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) {
|
||||||
|
s.boundAuthIdentityFor = userID
|
||||||
|
copied := input
|
||||||
|
if input.Metadata != nil {
|
||||||
|
copied.Metadata = map[string]any{}
|
||||||
|
for key, value := range input.Metadata {
|
||||||
|
copied.Metadata[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if input.Channel != nil {
|
||||||
|
channel := *input.Channel
|
||||||
|
if input.Channel.Metadata != nil {
|
||||||
|
channel.Metadata = map[string]any{}
|
||||||
|
for key, value := range input.Channel.Metadata {
|
||||||
|
channel.Metadata[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
copied.Channel = &channel
|
||||||
|
}
|
||||||
|
s.boundAuthIdentity = &copied
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
result := &service.AdminBoundAuthIdentity{
|
||||||
|
UserID: userID,
|
||||||
|
ProviderType: input.ProviderType,
|
||||||
|
ProviderKey: input.ProviderKey,
|
||||||
|
ProviderSubject: input.ProviderSubject,
|
||||||
|
VerifiedAt: &now,
|
||||||
|
Issuer: input.Issuer,
|
||||||
|
Metadata: input.Metadata,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
if input.Channel != nil {
|
||||||
|
result.Channel = &service.AdminBoundAuthIdentityChannel{
|
||||||
|
Channel: input.Channel.Channel,
|
||||||
|
ChannelAppID: input.Channel.ChannelAppID,
|
||||||
|
ChannelSubject: input.Channel.ChannelSubject,
|
||||||
|
Metadata: input.Channel.Metadata,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
|
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
|
||||||
return s.groups, int64(len(s.groups)), nil
|
return s.groups, int64(len(s.groups)), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -66,6 +66,22 @@ type UpdateBalanceRequest struct {
|
|||||||
Notes string `json:"notes"`
|
Notes string `json:"notes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BindUserAuthIdentityRequest struct {
|
||||||
|
ProviderType string `json:"provider_type"`
|
||||||
|
ProviderKey string `json:"provider_key"`
|
||||||
|
ProviderSubject string `json:"provider_subject"`
|
||||||
|
Issuer *string `json:"issuer"`
|
||||||
|
Metadata map[string]any `json:"metadata"`
|
||||||
|
Channel *BindUserAuthIdentityChannelRequest `json:"channel"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BindUserAuthIdentityChannelRequest struct {
|
||||||
|
Channel string `json:"channel"`
|
||||||
|
ChannelAppID string `json:"channel_app_id"`
|
||||||
|
ChannelSubject string `json:"channel_subject"`
|
||||||
|
Metadata map[string]any `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
// List handles listing all users with pagination
|
// List handles listing all users with pagination
|
||||||
// GET /api/v1/admin/users
|
// GET /api/v1/admin/users
|
||||||
// Query params:
|
// Query params:
|
||||||
@@ -197,6 +213,45 @@ func (h *UserHandler) ListAuthIdentityMigrationReports(c *gin.Context) {
|
|||||||
response.Paginated(c, reports, total, page, pageSize)
|
response.Paginated(c, reports, total, page, pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BindAuthIdentity manually binds a canonical auth identity to a user.
|
||||||
|
// POST /api/v1/admin/users/:id/auth-identities
|
||||||
|
func (h *UserHandler) BindAuthIdentity(c *gin.Context) {
|
||||||
|
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "Invalid user ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req BindUserAuthIdentityRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
input := service.AdminBindAuthIdentityInput{
|
||||||
|
ProviderType: req.ProviderType,
|
||||||
|
ProviderKey: req.ProviderKey,
|
||||||
|
ProviderSubject: req.ProviderSubject,
|
||||||
|
Issuer: req.Issuer,
|
||||||
|
Metadata: req.Metadata,
|
||||||
|
}
|
||||||
|
if req.Channel != nil {
|
||||||
|
input.Channel = &service.AdminBindAuthIdentityChannelInput{
|
||||||
|
Channel: req.Channel.Channel,
|
||||||
|
ChannelAppID: req.Channel.ChannelAppID,
|
||||||
|
ChannelSubject: req.Channel.ChannelSubject,
|
||||||
|
Metadata: req.Channel.Metadata,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.adminService.BindUserAuthIdentity(c.Request.Context(), userID, input)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|
||||||
// Create handles creating a new user
|
// Create handles creating a new user
|
||||||
// POST /api/v1/admin/users
|
// POST /api/v1/admin/users
|
||||||
func (h *UserHandler) Create(c *gin.Context) {
|
func (h *UserHandler) Create(c *gin.Context) {
|
||||||
|
|||||||
@@ -214,6 +214,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports)
|
users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports)
|
||||||
users.GET("", h.Admin.User.List)
|
users.GET("", h.Admin.User.List)
|
||||||
users.GET("/:id", h.Admin.User.GetByID)
|
users.GET("/:id", h.Admin.User.GetByID)
|
||||||
|
users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity)
|
||||||
users.POST("", h.Admin.User.Create)
|
users.POST("", h.Admin.User.Create)
|
||||||
users.PUT("/:id", h.Admin.User.Update)
|
users.PUT("/:id", h.Admin.User.Update)
|
||||||
users.DELETE("/:id", h.Admin.User.Delete)
|
users.DELETE("/:id", h.Admin.User.Delete)
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||||
@@ -39,6 +41,7 @@ type AdminService interface {
|
|||||||
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
|
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
|
||||||
ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error)
|
ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error)
|
||||||
GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error)
|
GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error)
|
||||||
|
BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error)
|
||||||
|
|
||||||
// Group management
|
// Group management
|
||||||
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
|
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
|
||||||
@@ -146,6 +149,44 @@ type AuthIdentityMigrationReportSummary struct {
|
|||||||
ByType map[string]int64 `json:"by_type"`
|
ByType map[string]int64 `json:"by_type"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AdminBindAuthIdentityInput struct {
|
||||||
|
ProviderType string
|
||||||
|
ProviderKey string
|
||||||
|
ProviderSubject string
|
||||||
|
Issuer *string
|
||||||
|
Metadata map[string]any
|
||||||
|
Channel *AdminBindAuthIdentityChannelInput
|
||||||
|
}
|
||||||
|
|
||||||
|
type AdminBindAuthIdentityChannelInput struct {
|
||||||
|
Channel string
|
||||||
|
ChannelAppID string
|
||||||
|
ChannelSubject string
|
||||||
|
Metadata map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
type AdminBoundAuthIdentity struct {
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
ProviderType string `json:"provider_type"`
|
||||||
|
ProviderKey string `json:"provider_key"`
|
||||||
|
ProviderSubject string `json:"provider_subject"`
|
||||||
|
VerifiedAt *time.Time `json:"verified_at,omitempty"`
|
||||||
|
Issuer *string `json:"issuer,omitempty"`
|
||||||
|
Metadata map[string]any `json:"metadata"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
Channel *AdminBoundAuthIdentityChannel `json:"channel,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AdminBoundAuthIdentityChannel struct {
|
||||||
|
Channel string `json:"channel"`
|
||||||
|
ChannelAppID string `json:"channel_app_id"`
|
||||||
|
ChannelSubject string `json:"channel_subject"`
|
||||||
|
Metadata map[string]any `json:"metadata"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
type CreateGroupInput struct {
|
type CreateGroupInput struct {
|
||||||
Name string
|
Name string
|
||||||
Description string
|
Description string
|
||||||
@@ -895,6 +936,143 @@ ORDER BY report_type ASC`)
|
|||||||
return summary, nil
|
return summary, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
|
||||||
|
if userID <= 0 {
|
||||||
|
return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0")
|
||||||
|
}
|
||||||
|
if s == nil || s.entClient == nil || s.userRepo == nil {
|
||||||
|
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable")
|
||||||
|
}
|
||||||
|
if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType)
|
||||||
|
providerKey := strings.TrimSpace(input.ProviderKey)
|
||||||
|
providerSubject := strings.TrimSpace(input.ProviderSubject)
|
||||||
|
if providerType == "" {
|
||||||
|
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat")
|
||||||
|
}
|
||||||
|
if providerKey == "" || providerSubject == "" {
|
||||||
|
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
|
||||||
|
}
|
||||||
|
|
||||||
|
var issuer *string
|
||||||
|
if input.Issuer != nil {
|
||||||
|
trimmed := strings.TrimSpace(*input.Issuer)
|
||||||
|
if trimmed != "" {
|
||||||
|
issuer = &trimmed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
channelInput := normalizeAdminBindChannelInput(input.Channel)
|
||||||
|
if input.Channel != nil && channelInput == nil {
|
||||||
|
return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
verifiedAt := time.Now().UTC()
|
||||||
|
tx, err := s.entClient.Tx(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err)
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
|
identity, err := tx.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ(providerType),
|
||||||
|
authidentity.ProviderKeyEQ(providerKey),
|
||||||
|
authidentity.ProviderSubjectEQ(providerSubject),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil && !dbent.IsNotFound(err) {
|
||||||
|
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
|
||||||
|
}
|
||||||
|
if identity != nil && identity.UserID != userID {
|
||||||
|
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
|
||||||
|
}
|
||||||
|
|
||||||
|
if identity == nil {
|
||||||
|
create := tx.AuthIdentity.Create().
|
||||||
|
SetUserID(userID).
|
||||||
|
SetProviderType(providerType).
|
||||||
|
SetProviderKey(providerKey).
|
||||||
|
SetProviderSubject(providerSubject).
|
||||||
|
SetVerifiedAt(verifiedAt)
|
||||||
|
if issuer != nil {
|
||||||
|
create = create.SetIssuer(*issuer)
|
||||||
|
}
|
||||||
|
if input.Metadata != nil {
|
||||||
|
create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
|
||||||
|
}
|
||||||
|
identity, err = create.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt)
|
||||||
|
if issuer != nil {
|
||||||
|
update = update.SetIssuer(*issuer)
|
||||||
|
}
|
||||||
|
if input.Metadata != nil {
|
||||||
|
update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
|
||||||
|
}
|
||||||
|
identity, err = update.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var channel *dbent.AuthIdentityChannel
|
||||||
|
if channelInput != nil {
|
||||||
|
channel, err = tx.AuthIdentityChannel.Query().
|
||||||
|
Where(
|
||||||
|
authidentitychannel.ProviderTypeEQ(providerType),
|
||||||
|
authidentitychannel.ProviderKeyEQ(providerKey),
|
||||||
|
authidentitychannel.ChannelEQ(channelInput.Channel),
|
||||||
|
authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
|
||||||
|
authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
|
||||||
|
).
|
||||||
|
WithIdentity().
|
||||||
|
Only(ctx)
|
||||||
|
if err != nil && !dbent.IsNotFound(err) {
|
||||||
|
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
|
||||||
|
}
|
||||||
|
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
|
||||||
|
return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
|
||||||
|
}
|
||||||
|
if channel == nil {
|
||||||
|
create := tx.AuthIdentityChannel.Create().
|
||||||
|
SetIdentityID(identity.ID).
|
||||||
|
SetProviderType(providerType).
|
||||||
|
SetProviderKey(providerKey).
|
||||||
|
SetChannel(channelInput.Channel).
|
||||||
|
SetChannelAppID(channelInput.ChannelAppID).
|
||||||
|
SetChannelSubject(channelInput.ChannelSubject)
|
||||||
|
if channelInput.Metadata != nil {
|
||||||
|
create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
|
||||||
|
}
|
||||||
|
channel, err = create.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID)
|
||||||
|
if channelInput.Metadata != nil {
|
||||||
|
update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
|
||||||
|
}
|
||||||
|
channel, err = update.Save(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err)
|
||||||
|
}
|
||||||
|
return buildAdminBoundAuthIdentity(identity, channel), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) {
|
func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) {
|
||||||
if s == nil || s.entClient == nil {
|
if s == nil || s.entClient == nil {
|
||||||
return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready")
|
return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready")
|
||||||
@@ -906,6 +1084,90 @@ func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) {
|
|||||||
return driver.DB(), nil
|
return driver.DB(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
|
||||||
|
if input == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
channel := &AdminBindAuthIdentityChannelInput{
|
||||||
|
Channel: strings.TrimSpace(input.Channel),
|
||||||
|
ChannelAppID: strings.TrimSpace(input.ChannelAppID),
|
||||||
|
ChannelSubject: strings.TrimSpace(input.ChannelSubject),
|
||||||
|
Metadata: cloneAdminAuthIdentityMetadata(input.Metadata),
|
||||||
|
}
|
||||||
|
if channel.Channel == "" || channel.ChannelAppID == "" || channel.ChannelSubject == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return channel
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAdminAuthIdentityProviderType(input string) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(input)) {
|
||||||
|
case "email":
|
||||||
|
return "email"
|
||||||
|
case "linuxdo":
|
||||||
|
return "linuxdo"
|
||||||
|
case "oidc":
|
||||||
|
return "oidc"
|
||||||
|
case "wechat":
|
||||||
|
return "wechat"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAdminBoundAuthIdentity(identity *dbent.AuthIdentity, channel *dbent.AuthIdentityChannel) *AdminBoundAuthIdentity {
|
||||||
|
if identity == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := &AdminBoundAuthIdentity{
|
||||||
|
UserID: identity.UserID,
|
||||||
|
ProviderType: strings.TrimSpace(identity.ProviderType),
|
||||||
|
ProviderKey: strings.TrimSpace(identity.ProviderKey),
|
||||||
|
ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
|
||||||
|
VerifiedAt: identity.VerifiedAt,
|
||||||
|
Issuer: identity.Issuer,
|
||||||
|
Metadata: cloneAdminAuthIdentityMetadata(identity.Metadata),
|
||||||
|
CreatedAt: identity.CreatedAt,
|
||||||
|
UpdatedAt: identity.UpdatedAt,
|
||||||
|
}
|
||||||
|
if channel != nil {
|
||||||
|
result.Channel = &AdminBoundAuthIdentityChannel{
|
||||||
|
Channel: strings.TrimSpace(channel.Channel),
|
||||||
|
ChannelAppID: strings.TrimSpace(channel.ChannelAppID),
|
||||||
|
ChannelSubject: strings.TrimSpace(channel.ChannelSubject),
|
||||||
|
Metadata: cloneAdminAuthIdentityMetadata(channel.Metadata),
|
||||||
|
CreatedAt: channel.CreatedAt,
|
||||||
|
UpdatedAt: channel.UpdatedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any {
|
||||||
|
if input == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(input) == 0 {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(input)
|
||||||
|
if err != nil {
|
||||||
|
out := make(map[string]any, len(input))
|
||||||
|
for key, value := range input {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
var out map[string]any
|
||||||
|
if err := json.Unmarshal(data, &out); err != nil {
|
||||||
|
out = make(map[string]any, len(input))
|
||||||
|
for key, value := range input {
|
||||||
|
out[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) {
|
func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) {
|
||||||
var (
|
var (
|
||||||
report AuthIdentityMigrationReport
|
report AuthIdentityMigrationReport
|
||||||
|
|||||||
@@ -0,0 +1,215 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
|
||||||
|
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||||
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"entgo.io/ent/dialect"
|
||||||
|
entsql "entgo.io/ent/dialect/sql"
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAdminServiceAuthIdentityBindingTestClient(t *testing.T) *dbent.Client {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite", "file:admin_service_auth_identity_binding?mode=memory&cache=shared&_fk=1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
|
||||||
|
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||||
|
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||||
|
t.Cleanup(func() { _ = client.Close() })
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminServiceBindUserAuthIdentityCreatesCanonicalAndChannelBinding(t *testing.T) {
|
||||||
|
client := newAdminServiceAuthIdentityBindingTestClient(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
user, err := client.User.Create().
|
||||||
|
SetEmail("bind-target@example.com").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(RoleUser).
|
||||||
|
SetStatus(StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
|
||||||
|
entClient: client,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
|
||||||
|
ProviderType: "wechat",
|
||||||
|
ProviderKey: "wechat-main",
|
||||||
|
ProviderSubject: "union-123",
|
||||||
|
Metadata: map[string]any{"source": "admin-repair"},
|
||||||
|
Channel: &AdminBindAuthIdentityChannelInput{
|
||||||
|
Channel: "open",
|
||||||
|
ChannelAppID: "wx-open",
|
||||||
|
ChannelSubject: "openid-123",
|
||||||
|
Metadata: map[string]any{"scene": "migration"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, user.ID, result.UserID)
|
||||||
|
require.Equal(t, "wechat", result.ProviderType)
|
||||||
|
require.Equal(t, "wechat-main", result.ProviderKey)
|
||||||
|
require.NotNil(t, result.VerifiedAt)
|
||||||
|
require.NotNil(t, result.Channel)
|
||||||
|
require.Equal(t, "open", result.Channel.Channel)
|
||||||
|
|
||||||
|
identity, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ("wechat"),
|
||||||
|
authidentity.ProviderKeyEQ("wechat-main"),
|
||||||
|
authidentity.ProviderSubjectEQ("union-123"),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, user.ID, identity.UserID)
|
||||||
|
require.Equal(t, "admin-repair", identity.Metadata["source"])
|
||||||
|
require.NotNil(t, identity.VerifiedAt)
|
||||||
|
|
||||||
|
channel, err := client.AuthIdentityChannel.Query().
|
||||||
|
Where(
|
||||||
|
authidentitychannel.ProviderTypeEQ("wechat"),
|
||||||
|
authidentitychannel.ProviderKeyEQ("wechat-main"),
|
||||||
|
authidentitychannel.ChannelEQ("open"),
|
||||||
|
authidentitychannel.ChannelAppIDEQ("wx-open"),
|
||||||
|
authidentitychannel.ChannelSubjectEQ("openid-123"),
|
||||||
|
).
|
||||||
|
Only(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, identity.ID, channel.IdentityID)
|
||||||
|
require.Equal(t, "migration", channel.Metadata["scene"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminServiceBindUserAuthIdentityRejectsOtherOwner(t *testing.T) {
|
||||||
|
client := newAdminServiceAuthIdentityBindingTestClient(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
owner, err := client.User.Create().
|
||||||
|
SetEmail("owner@example.com").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(RoleUser).
|
||||||
|
SetStatus(StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
target, err := client.User.Create().
|
||||||
|
SetEmail("target@example.com").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(RoleUser).
|
||||||
|
SetStatus(StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = client.AuthIdentity.Create().
|
||||||
|
SetUserID(owner.ID).
|
||||||
|
SetProviderType("oidc").
|
||||||
|
SetProviderKey("https://issuer.example").
|
||||||
|
SetProviderSubject("subject-1").
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
userRepo: &userRepoStub{user: &User{ID: target.ID, Email: target.Email, Status: StatusActive}},
|
||||||
|
entClient: client,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = svc.BindUserAuthIdentity(ctx, target.ID, AdminBindAuthIdentityInput{
|
||||||
|
ProviderType: "oidc",
|
||||||
|
ProviderKey: "https://issuer.example",
|
||||||
|
ProviderSubject: "subject-1",
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", infraerrors.Reason(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) {
|
||||||
|
client := newAdminServiceAuthIdentityBindingTestClient(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
user, err := client.User.Create().
|
||||||
|
SetEmail("same-user@example.com").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(RoleUser).
|
||||||
|
SetStatus(StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
|
||||||
|
entClient: client,
|
||||||
|
}
|
||||||
|
|
||||||
|
first, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
|
||||||
|
ProviderType: "oidc",
|
||||||
|
ProviderKey: "https://issuer.example",
|
||||||
|
ProviderSubject: "subject-2",
|
||||||
|
Metadata: map[string]any{"source": "first"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
second, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
|
||||||
|
ProviderType: "oidc",
|
||||||
|
ProviderKey: "https://issuer.example",
|
||||||
|
ProviderSubject: "subject-2",
|
||||||
|
Metadata: map[string]any{"source": "second"},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, first.UserID, second.UserID)
|
||||||
|
require.Equal(t, "second", second.Metadata["source"])
|
||||||
|
|
||||||
|
identities, err := client.AuthIdentity.Query().
|
||||||
|
Where(
|
||||||
|
authidentity.ProviderTypeEQ("oidc"),
|
||||||
|
authidentity.ProviderKeyEQ("https://issuer.example"),
|
||||||
|
authidentity.ProviderSubjectEQ("subject-2"),
|
||||||
|
).
|
||||||
|
All(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, identities, 1)
|
||||||
|
require.Equal(t, "second", identities[0].Metadata["source"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) {
|
||||||
|
client := newAdminServiceAuthIdentityBindingTestClient(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
user, err := client.User.Create().
|
||||||
|
SetEmail("invalid-provider@example.com").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetRole(RoleUser).
|
||||||
|
SetStatus(StatusActive).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
svc := &adminServiceImpl{
|
||||||
|
userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
|
||||||
|
entClient: client,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
|
||||||
|
ProviderType: "github",
|
||||||
|
ProviderKey: "github-main",
|
||||||
|
ProviderSubject: "subject-3",
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, "INVALID_INPUT", infraerrors.Reason(err))
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user