diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index ff9eec7e..57620005 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -25,6 +25,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary) router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports) router.GET("/api/v1/admin/users/:id", userHandler.GetByID) + router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity) router.POST("/api/v1/admin/users", userHandler.Create) router.PUT("/api/v1/admin/users/:id", userHandler.Update) router.DELETE("/api/v1/admin/users/:id", userHandler.Delete) @@ -87,8 +88,26 @@ func TestUserHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + bindBody := map[string]any{ + "provider_type": "wechat", + "provider_key": "wechat-main", + "provider_subject": "union-123", + "metadata": map[string]any{"source": "admin-repair"}, + "channel": map[string]any{ + "channel": "open", + "channel_app_id": "wx-open", + "channel_subject": "openid-123", + }, + } + body, _ := json.Marshal(bindBody) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2} - body, _ := json.Marshal(createBody) + body, _ = json.Marshal(createBody) rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") @@ -125,6 +144,33 @@ func TestUserHandlerEndpoints(t *testing.T) { require.Equal(t, http.StatusOK, rec.Code) } +func TestUserHandlerBindAuthIdentityMapsRequest(t *testing.T) { + router, adminSvc := setupAdminRouter() + + body, err := json.Marshal(map[string]any{ + "provider_type": "oidc", + "provider_key": "https://issuer.example", + "provider_subject": "subject-123", + "issuer": "https://issuer.example", + "metadata": map[string]any{"report_id": 12}, + }) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/9/auth-identities", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, int64(9), adminSvc.boundAuthIdentityFor) + require.NotNil(t, adminSvc.boundAuthIdentity) + require.Equal(t, "oidc", adminSvc.boundAuthIdentity.ProviderType) + require.Equal(t, "https://issuer.example", adminSvc.boundAuthIdentity.ProviderKey) + require.Equal(t, "subject-123", adminSvc.boundAuthIdentity.ProviderSubject) + require.Nil(t, adminSvc.boundAuthIdentity.Channel) + require.Equal(t, float64(12), adminSvc.boundAuthIdentity.Metadata["report_id"]) +} + func TestGroupHandlerEndpoints(t *testing.T) { router, _ := setupAdminRouter() diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 681c25c6..c8c7a247 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -18,6 +18,8 @@ type stubAdminService struct { proxyCounts []service.ProxyWithAccountCount redeems []service.RedeemCode migrationReports []service.AuthIdentityMigrationReport + boundAuthIdentity *service.AdminBindAuthIdentityInput + boundAuthIdentityFor int64 createdAccounts []*service.CreateAccountInput createdProxies []*service.CreateProxyInput updatedProxyIDs []int64 @@ -201,6 +203,52 @@ func (s *stubAdminService) GetAuthIdentityMigrationReportSummary(ctx context.Con return summary, nil } +func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) { + s.boundAuthIdentityFor = userID + copied := input + if input.Metadata != nil { + copied.Metadata = map[string]any{} + for key, value := range input.Metadata { + copied.Metadata[key] = value + } + } + if input.Channel != nil { + channel := *input.Channel + if input.Channel.Metadata != nil { + channel.Metadata = map[string]any{} + for key, value := range input.Channel.Metadata { + channel.Metadata[key] = value + } + } + copied.Channel = &channel + } + s.boundAuthIdentity = &copied + + now := time.Now().UTC() + result := &service.AdminBoundAuthIdentity{ + UserID: userID, + ProviderType: input.ProviderType, + ProviderKey: input.ProviderKey, + ProviderSubject: input.ProviderSubject, + VerifiedAt: &now, + Issuer: input.Issuer, + Metadata: input.Metadata, + CreatedAt: now, + UpdatedAt: now, + } + if input.Channel != nil { + result.Channel = &service.AdminBoundAuthIdentityChannel{ + Channel: input.Channel.Channel, + ChannelAppID: input.Channel.ChannelAppID, + ChannelSubject: input.Channel.ChannelSubject, + Metadata: input.Channel.Metadata, + CreatedAt: now, + UpdatedAt: now, + } + } + return result, nil +} + func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) { return s.groups, int64(len(s.groups)), nil } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index ee3fbb1e..321214af 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -66,6 +66,22 @@ type UpdateBalanceRequest struct { Notes string `json:"notes"` } +type BindUserAuthIdentityRequest struct { + ProviderType string `json:"provider_type"` + ProviderKey string `json:"provider_key"` + ProviderSubject string `json:"provider_subject"` + Issuer *string `json:"issuer"` + Metadata map[string]any `json:"metadata"` + Channel *BindUserAuthIdentityChannelRequest `json:"channel"` +} + +type BindUserAuthIdentityChannelRequest struct { + Channel string `json:"channel"` + ChannelAppID string `json:"channel_app_id"` + ChannelSubject string `json:"channel_subject"` + Metadata map[string]any `json:"metadata"` +} + // List handles listing all users with pagination // GET /api/v1/admin/users // Query params: @@ -197,6 +213,45 @@ func (h *UserHandler) ListAuthIdentityMigrationReports(c *gin.Context) { response.Paginated(c, reports, total, page, pageSize) } +// BindAuthIdentity manually binds a canonical auth identity to a user. +// POST /api/v1/admin/users/:id/auth-identities +func (h *UserHandler) BindAuthIdentity(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + var req BindUserAuthIdentityRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + input := service.AdminBindAuthIdentityInput{ + ProviderType: req.ProviderType, + ProviderKey: req.ProviderKey, + ProviderSubject: req.ProviderSubject, + Issuer: req.Issuer, + Metadata: req.Metadata, + } + if req.Channel != nil { + input.Channel = &service.AdminBindAuthIdentityChannelInput{ + Channel: req.Channel.Channel, + ChannelAppID: req.Channel.ChannelAppID, + ChannelSubject: req.Channel.ChannelSubject, + Metadata: req.Channel.Metadata, + } + } + + result, err := h.adminService.BindUserAuthIdentity(c.Request.Context(), userID, input) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + // Create handles creating a new user // POST /api/v1/admin/users func (h *UserHandler) Create(c *gin.Context) { diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 0b5aaf09..c78fba33 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -214,6 +214,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports) users.GET("", h.Admin.User.List) users.GET("/:id", h.Admin.User.GetByID) + users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity) users.POST("", h.Admin.User.Create) users.PUT("/:id", h.Admin.User.Update) users.DELETE("/:id", h.Admin.User.Delete) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 972681a5..9ff26861 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -13,6 +13,8 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -39,6 +41,7 @@ type AdminService interface { GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) + BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) // Group management ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) @@ -146,6 +149,44 @@ type AuthIdentityMigrationReportSummary struct { ByType map[string]int64 `json:"by_type"` } +type AdminBindAuthIdentityInput struct { + ProviderType string + ProviderKey string + ProviderSubject string + Issuer *string + Metadata map[string]any + Channel *AdminBindAuthIdentityChannelInput +} + +type AdminBindAuthIdentityChannelInput struct { + Channel string + ChannelAppID string + ChannelSubject string + Metadata map[string]any +} + +type AdminBoundAuthIdentity struct { + UserID int64 `json:"user_id"` + ProviderType string `json:"provider_type"` + ProviderKey string `json:"provider_key"` + ProviderSubject string `json:"provider_subject"` + VerifiedAt *time.Time `json:"verified_at,omitempty"` + Issuer *string `json:"issuer,omitempty"` + Metadata map[string]any `json:"metadata"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Channel *AdminBoundAuthIdentityChannel `json:"channel,omitempty"` +} + +type AdminBoundAuthIdentityChannel struct { + Channel string `json:"channel"` + ChannelAppID string `json:"channel_app_id"` + ChannelSubject string `json:"channel_subject"` + Metadata map[string]any `json:"metadata"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + type CreateGroupInput struct { Name string Description string @@ -895,6 +936,143 @@ ORDER BY report_type ASC`) return summary, nil } +func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) { + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0") + } + if s == nil || s.entClient == nil || s.userRepo == nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable") + } + if _, err := s.userRepo.GetByID(ctx, userID); err != nil { + return nil, err + } + + providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType) + providerKey := strings.TrimSpace(input.ProviderKey) + providerSubject := strings.TrimSpace(input.ProviderSubject) + if providerType == "" { + return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat") + } + if providerKey == "" || providerSubject == "" { + return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required") + } + + var issuer *string + if input.Issuer != nil { + trimmed := strings.TrimSpace(*input.Issuer) + if trimmed != "" { + issuer = &trimmed + } + } + + channelInput := normalizeAdminBindChannelInput(input.Channel) + if input.Channel != nil && channelInput == nil { + return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided") + } + + verifiedAt := time.Now().UTC() + tx, err := s.entClient.Tx(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err) + } + defer func() { _ = tx.Rollback() }() + + identity, err := tx.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderSubjectEQ(providerSubject), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + if identity != nil && identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + + if identity == nil { + create := tx.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetProviderSubject(providerSubject). + SetVerifiedAt(verifiedAt) + if issuer != nil { + create = create.SetIssuer(*issuer) + } + if input.Metadata != nil { + create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) + } + identity, err = create.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) + } + } else { + update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt) + if issuer != nil { + update = update.SetIssuer(*issuer) + } + if input.Metadata != nil { + update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) + } + } + + var channel *dbent.AuthIdentityChannel + if channelInput != nil { + channel, err = tx.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyEQ(providerKey), + authidentitychannel.ChannelEQ(channelInput.Channel), + authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID), + authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject), + ). + WithIdentity(). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) + } + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + if channel == nil { + create := tx.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetChannel(channelInput.Channel). + SetChannelAppID(channelInput.ChannelAppID). + SetChannelSubject(channelInput.ChannelSubject) + if channelInput.Metadata != nil { + create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) + } + channel, err = create.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) + } + } else { + update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID) + if channelInput.Metadata != nil { + update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) + } + channel, err = update.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) + } + } + } + + if err := tx.Commit(); err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err) + } + return buildAdminBoundAuthIdentity(identity, channel), nil +} + func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) { if s == nil || s.entClient == nil { return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready") @@ -906,6 +1084,90 @@ func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) { return driver.DB(), nil } +func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput { + if input == nil { + return nil + } + channel := &AdminBindAuthIdentityChannelInput{ + Channel: strings.TrimSpace(input.Channel), + ChannelAppID: strings.TrimSpace(input.ChannelAppID), + ChannelSubject: strings.TrimSpace(input.ChannelSubject), + Metadata: cloneAdminAuthIdentityMetadata(input.Metadata), + } + if channel.Channel == "" || channel.ChannelAppID == "" || channel.ChannelSubject == "" { + return nil + } + return channel +} + +func normalizeAdminAuthIdentityProviderType(input string) string { + switch strings.ToLower(strings.TrimSpace(input)) { + case "email": + return "email" + case "linuxdo": + return "linuxdo" + case "oidc": + return "oidc" + case "wechat": + return "wechat" + default: + return "" + } +} + +func buildAdminBoundAuthIdentity(identity *dbent.AuthIdentity, channel *dbent.AuthIdentityChannel) *AdminBoundAuthIdentity { + if identity == nil { + return nil + } + result := &AdminBoundAuthIdentity{ + UserID: identity.UserID, + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + VerifiedAt: identity.VerifiedAt, + Issuer: identity.Issuer, + Metadata: cloneAdminAuthIdentityMetadata(identity.Metadata), + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + } + if channel != nil { + result.Channel = &AdminBoundAuthIdentityChannel{ + Channel: strings.TrimSpace(channel.Channel), + ChannelAppID: strings.TrimSpace(channel.ChannelAppID), + ChannelSubject: strings.TrimSpace(channel.ChannelSubject), + Metadata: cloneAdminAuthIdentityMetadata(channel.Metadata), + CreatedAt: channel.CreatedAt, + UpdatedAt: channel.UpdatedAt, + } + } + return result +} + +func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any { + if input == nil { + return nil + } + if len(input) == 0 { + return map[string]any{} + } + data, err := json.Marshal(input) + if err != nil { + out := make(map[string]any, len(input)) + for key, value := range input { + out[key] = value + } + return out + } + var out map[string]any + if err := json.Unmarshal(data, &out); err != nil { + out = make(map[string]any, len(input)) + for key, value := range input { + out[key] = value + } + } + return out +} + func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) { var ( report AuthIdentityMigrationReport diff --git a/backend/internal/service/admin_service_auth_identity_binding_test.go b/backend/internal/service/admin_service_auth_identity_binding_test.go new file mode 100644 index 00000000..f8ce3935 --- /dev/null +++ b/backend/internal/service/admin_service_auth_identity_binding_test.go @@ -0,0 +1,215 @@ +//go:build unit + +package service + +import ( + "context" + "database/sql" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/enttest" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAdminServiceAuthIdentityBindingTestClient(t *testing.T) *dbent.Client { + t.Helper() + + db, err := sql.Open("sqlite", "file:admin_service_auth_identity_binding?mode=memory&cache=shared&_fk=1") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +func TestAdminServiceBindUserAuthIdentityCreatesCanonicalAndChannelBinding(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("bind-target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-123", + Metadata: map[string]any{"source": "admin-repair"}, + Channel: &AdminBindAuthIdentityChannelInput{ + Channel: "open", + ChannelAppID: "wx-open", + ChannelSubject: "openid-123", + Metadata: map[string]any{"scene": "migration"}, + }, + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, user.ID, result.UserID) + require.Equal(t, "wechat", result.ProviderType) + require.Equal(t, "wechat-main", result.ProviderKey) + require.NotNil(t, result.VerifiedAt) + require.NotNil(t, result.Channel) + require.Equal(t, "open", result.Channel.Channel) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ("wechat-main"), + authidentity.ProviderSubjectEQ("union-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) + require.Equal(t, "admin-repair", identity.Metadata["source"]) + require.NotNil(t, identity.VerifiedAt) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ("wechat-main"), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, identity.ID, channel.IdentityID) + require.Equal(t, "migration", channel.Metadata["scene"]) +} + +func TestAdminServiceBindUserAuthIdentityRejectsOtherOwner(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + target, err := client.User.Create(). + SetEmail("target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("subject-1"). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: target.ID, Email: target.Email, Status: StatusActive}}, + entClient: client, + } + + _, err = svc.BindUserAuthIdentity(ctx, target.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-1", + }) + require.Error(t, err) + require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", infraerrors.Reason(err)) +} + +func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("same-user@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + first, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-2", + Metadata: map[string]any{"source": "first"}, + }) + require.NoError(t, err) + + second, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-2", + Metadata: map[string]any{"source": "second"}, + }) + require.NoError(t, err) + require.Equal(t, first.UserID, second.UserID) + require.Equal(t, "second", second.Metadata["source"]) + + identities, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("subject-2"), + ). + All(ctx) + require.NoError(t, err) + require.Len(t, identities, 1) + require.Equal(t, "second", identities[0].Metadata["source"]) +} + +func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("invalid-provider@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + _, err = svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "github", + ProviderKey: "github-main", + ProviderSubject: "subject-3", + }) + require.Error(t, err) + require.Equal(t, "INVALID_INPUT", infraerrors.Reason(err)) +}