feat: add admin auth identity repair binding

This commit is contained in:
IanShaw027
2026-04-20 22:22:14 +08:00
parent 3bd3027251
commit 452e55a53c
6 changed files with 628 additions and 1 deletions

View File

@@ -25,6 +25,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary) router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary)
router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports) router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID) router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity)
router.POST("/api/v1/admin/users", userHandler.Create) router.POST("/api/v1/admin/users", userHandler.Create)
router.PUT("/api/v1/admin/users/:id", userHandler.Update) router.PUT("/api/v1/admin/users/:id", userHandler.Update)
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete) router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
@@ -87,8 +88,26 @@ func TestUserHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req) router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, http.StatusOK, rec.Code)
bindBody := map[string]any{
"provider_type": "wechat",
"provider_key": "wechat-main",
"provider_subject": "union-123",
"metadata": map[string]any{"source": "admin-repair"},
"channel": map[string]any{
"channel": "open",
"channel_app_id": "wx-open",
"channel_subject": "openid-123",
},
}
body, _ := json.Marshal(bindBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2} createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
body, _ := json.Marshal(createBody) body, _ = json.Marshal(createBody)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body)) req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -125,6 +144,33 @@ func TestUserHandlerEndpoints(t *testing.T) {
require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, http.StatusOK, rec.Code)
} }
func TestUserHandlerBindAuthIdentityMapsRequest(t *testing.T) {
router, adminSvc := setupAdminRouter()
body, err := json.Marshal(map[string]any{
"provider_type": "oidc",
"provider_key": "https://issuer.example",
"provider_subject": "subject-123",
"issuer": "https://issuer.example",
"metadata": map[string]any{"report_id": 12},
})
require.NoError(t, err)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/9/auth-identities", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, int64(9), adminSvc.boundAuthIdentityFor)
require.NotNil(t, adminSvc.boundAuthIdentity)
require.Equal(t, "oidc", adminSvc.boundAuthIdentity.ProviderType)
require.Equal(t, "https://issuer.example", adminSvc.boundAuthIdentity.ProviderKey)
require.Equal(t, "subject-123", adminSvc.boundAuthIdentity.ProviderSubject)
require.Nil(t, adminSvc.boundAuthIdentity.Channel)
require.Equal(t, float64(12), adminSvc.boundAuthIdentity.Metadata["report_id"])
}
func TestGroupHandlerEndpoints(t *testing.T) { func TestGroupHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter() router, _ := setupAdminRouter()

View File

@@ -18,6 +18,8 @@ type stubAdminService struct {
proxyCounts []service.ProxyWithAccountCount proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode redeems []service.RedeemCode
migrationReports []service.AuthIdentityMigrationReport migrationReports []service.AuthIdentityMigrationReport
boundAuthIdentity *service.AdminBindAuthIdentityInput
boundAuthIdentityFor int64
createdAccounts []*service.CreateAccountInput createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64 updatedProxyIDs []int64
@@ -201,6 +203,52 @@ func (s *stubAdminService) GetAuthIdentityMigrationReportSummary(ctx context.Con
return summary, nil return summary, nil
} }
func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) {
s.boundAuthIdentityFor = userID
copied := input
if input.Metadata != nil {
copied.Metadata = map[string]any{}
for key, value := range input.Metadata {
copied.Metadata[key] = value
}
}
if input.Channel != nil {
channel := *input.Channel
if input.Channel.Metadata != nil {
channel.Metadata = map[string]any{}
for key, value := range input.Channel.Metadata {
channel.Metadata[key] = value
}
}
copied.Channel = &channel
}
s.boundAuthIdentity = &copied
now := time.Now().UTC()
result := &service.AdminBoundAuthIdentity{
UserID: userID,
ProviderType: input.ProviderType,
ProviderKey: input.ProviderKey,
ProviderSubject: input.ProviderSubject,
VerifiedAt: &now,
Issuer: input.Issuer,
Metadata: input.Metadata,
CreatedAt: now,
UpdatedAt: now,
}
if input.Channel != nil {
result.Channel = &service.AdminBoundAuthIdentityChannel{
Channel: input.Channel.Channel,
ChannelAppID: input.Channel.ChannelAppID,
ChannelSubject: input.Channel.ChannelSubject,
Metadata: input.Channel.Metadata,
CreatedAt: now,
UpdatedAt: now,
}
}
return result, nil
}
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) { func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil return s.groups, int64(len(s.groups)), nil
} }

View File

@@ -66,6 +66,22 @@ type UpdateBalanceRequest struct {
Notes string `json:"notes"` Notes string `json:"notes"`
} }
type BindUserAuthIdentityRequest struct {
ProviderType string `json:"provider_type"`
ProviderKey string `json:"provider_key"`
ProviderSubject string `json:"provider_subject"`
Issuer *string `json:"issuer"`
Metadata map[string]any `json:"metadata"`
Channel *BindUserAuthIdentityChannelRequest `json:"channel"`
}
type BindUserAuthIdentityChannelRequest struct {
Channel string `json:"channel"`
ChannelAppID string `json:"channel_app_id"`
ChannelSubject string `json:"channel_subject"`
Metadata map[string]any `json:"metadata"`
}
// List handles listing all users with pagination // List handles listing all users with pagination
// GET /api/v1/admin/users // GET /api/v1/admin/users
// Query params: // Query params:
@@ -197,6 +213,45 @@ func (h *UserHandler) ListAuthIdentityMigrationReports(c *gin.Context) {
response.Paginated(c, reports, total, page, pageSize) response.Paginated(c, reports, total, page, pageSize)
} }
// BindAuthIdentity manually binds a canonical auth identity to a user.
// POST /api/v1/admin/users/:id/auth-identities
func (h *UserHandler) BindAuthIdentity(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req BindUserAuthIdentityRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := service.AdminBindAuthIdentityInput{
ProviderType: req.ProviderType,
ProviderKey: req.ProviderKey,
ProviderSubject: req.ProviderSubject,
Issuer: req.Issuer,
Metadata: req.Metadata,
}
if req.Channel != nil {
input.Channel = &service.AdminBindAuthIdentityChannelInput{
Channel: req.Channel.Channel,
ChannelAppID: req.Channel.ChannelAppID,
ChannelSubject: req.Channel.ChannelSubject,
Metadata: req.Channel.Metadata,
}
}
result, err := h.adminService.BindUserAuthIdentity(c.Request.Context(), userID, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// Create handles creating a new user // Create handles creating a new user
// POST /api/v1/admin/users // POST /api/v1/admin/users
func (h *UserHandler) Create(c *gin.Context) { func (h *UserHandler) Create(c *gin.Context) {

View File

@@ -214,6 +214,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports) users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports)
users.GET("", h.Admin.User.List) users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID) users.GET("/:id", h.Admin.User.GetByID)
users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity)
users.POST("", h.Admin.User.Create) users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update) users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete) users.DELETE("/:id", h.Admin.User.Delete)

View File

@@ -13,6 +13,8 @@ import (
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -39,6 +41,7 @@ type AdminService interface {
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error) ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error)
GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error)
BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error)
// Group management // Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
@@ -146,6 +149,44 @@ type AuthIdentityMigrationReportSummary struct {
ByType map[string]int64 `json:"by_type"` ByType map[string]int64 `json:"by_type"`
} }
type AdminBindAuthIdentityInput struct {
ProviderType string
ProviderKey string
ProviderSubject string
Issuer *string
Metadata map[string]any
Channel *AdminBindAuthIdentityChannelInput
}
type AdminBindAuthIdentityChannelInput struct {
Channel string
ChannelAppID string
ChannelSubject string
Metadata map[string]any
}
type AdminBoundAuthIdentity struct {
UserID int64 `json:"user_id"`
ProviderType string `json:"provider_type"`
ProviderKey string `json:"provider_key"`
ProviderSubject string `json:"provider_subject"`
VerifiedAt *time.Time `json:"verified_at,omitempty"`
Issuer *string `json:"issuer,omitempty"`
Metadata map[string]any `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Channel *AdminBoundAuthIdentityChannel `json:"channel,omitempty"`
}
type AdminBoundAuthIdentityChannel struct {
Channel string `json:"channel"`
ChannelAppID string `json:"channel_app_id"`
ChannelSubject string `json:"channel_subject"`
Metadata map[string]any `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type CreateGroupInput struct { type CreateGroupInput struct {
Name string Name string
Description string Description string
@@ -895,6 +936,143 @@ ORDER BY report_type ASC`)
return summary, nil return summary, nil
} }
func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
if userID <= 0 {
return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0")
}
if s == nil || s.entClient == nil || s.userRepo == nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable")
}
if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
return nil, err
}
providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType)
providerKey := strings.TrimSpace(input.ProviderKey)
providerSubject := strings.TrimSpace(input.ProviderSubject)
if providerType == "" {
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat")
}
if providerKey == "" || providerSubject == "" {
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
}
var issuer *string
if input.Issuer != nil {
trimmed := strings.TrimSpace(*input.Issuer)
if trimmed != "" {
issuer = &trimmed
}
}
channelInput := normalizeAdminBindChannelInput(input.Channel)
if input.Channel != nil && channelInput == nil {
return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided")
}
verifiedAt := time.Now().UTC()
tx, err := s.entClient.Tx(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err)
}
defer func() { _ = tx.Rollback() }()
identity, err := tx.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
authidentity.ProviderKeyEQ(providerKey),
authidentity.ProviderSubjectEQ(providerSubject),
).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
if identity != nil && identity.UserID != userID {
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
if identity == nil {
create := tx.AuthIdentity.Create().
SetUserID(userID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetProviderSubject(providerSubject).
SetVerifiedAt(verifiedAt)
if issuer != nil {
create = create.SetIssuer(*issuer)
}
if input.Metadata != nil {
create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
}
identity, err = create.Save(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
}
} else {
update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt)
if issuer != nil {
update = update.SetIssuer(*issuer)
}
if input.Metadata != nil {
update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
}
identity, err = update.Save(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
}
}
var channel *dbent.AuthIdentityChannel
if channelInput != nil {
channel, err = tx.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(providerType),
authidentitychannel.ProviderKeyEQ(providerKey),
authidentitychannel.ChannelEQ(channelInput.Channel),
authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
).
WithIdentity().
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
}
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
}
if channel == nil {
create := tx.AuthIdentityChannel.Create().
SetIdentityID(identity.ID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetChannel(channelInput.Channel).
SetChannelAppID(channelInput.ChannelAppID).
SetChannelSubject(channelInput.ChannelSubject)
if channelInput.Metadata != nil {
create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
}
channel, err = create.Save(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
}
} else {
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID)
if channelInput.Metadata != nil {
update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
}
channel, err = update.Save(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
}
}
}
if err := tx.Commit(); err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err)
}
return buildAdminBoundAuthIdentity(identity, channel), nil
}
func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) { func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) {
if s == nil || s.entClient == nil { if s == nil || s.entClient == nil {
return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready") return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready")
@@ -906,6 +1084,90 @@ func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) {
return driver.DB(), nil return driver.DB(), nil
} }
func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
if input == nil {
return nil
}
channel := &AdminBindAuthIdentityChannelInput{
Channel: strings.TrimSpace(input.Channel),
ChannelAppID: strings.TrimSpace(input.ChannelAppID),
ChannelSubject: strings.TrimSpace(input.ChannelSubject),
Metadata: cloneAdminAuthIdentityMetadata(input.Metadata),
}
if channel.Channel == "" || channel.ChannelAppID == "" || channel.ChannelSubject == "" {
return nil
}
return channel
}
func normalizeAdminAuthIdentityProviderType(input string) string {
switch strings.ToLower(strings.TrimSpace(input)) {
case "email":
return "email"
case "linuxdo":
return "linuxdo"
case "oidc":
return "oidc"
case "wechat":
return "wechat"
default:
return ""
}
}
func buildAdminBoundAuthIdentity(identity *dbent.AuthIdentity, channel *dbent.AuthIdentityChannel) *AdminBoundAuthIdentity {
if identity == nil {
return nil
}
result := &AdminBoundAuthIdentity{
UserID: identity.UserID,
ProviderType: strings.TrimSpace(identity.ProviderType),
ProviderKey: strings.TrimSpace(identity.ProviderKey),
ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
VerifiedAt: identity.VerifiedAt,
Issuer: identity.Issuer,
Metadata: cloneAdminAuthIdentityMetadata(identity.Metadata),
CreatedAt: identity.CreatedAt,
UpdatedAt: identity.UpdatedAt,
}
if channel != nil {
result.Channel = &AdminBoundAuthIdentityChannel{
Channel: strings.TrimSpace(channel.Channel),
ChannelAppID: strings.TrimSpace(channel.ChannelAppID),
ChannelSubject: strings.TrimSpace(channel.ChannelSubject),
Metadata: cloneAdminAuthIdentityMetadata(channel.Metadata),
CreatedAt: channel.CreatedAt,
UpdatedAt: channel.UpdatedAt,
}
}
return result
}
func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any {
if input == nil {
return nil
}
if len(input) == 0 {
return map[string]any{}
}
data, err := json.Marshal(input)
if err != nil {
out := make(map[string]any, len(input))
for key, value := range input {
out[key] = value
}
return out
}
var out map[string]any
if err := json.Unmarshal(data, &out); err != nil {
out = make(map[string]any, len(input))
for key, value := range input {
out[key] = value
}
}
return out
}
func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) { func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) {
var ( var (
report AuthIdentityMigrationReport report AuthIdentityMigrationReport

View File

@@ -0,0 +1,215 @@
//go:build unit
package service
import (
"context"
"database/sql"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/enttest"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func newAdminServiceAuthIdentityBindingTestClient(t *testing.T) *dbent.Client {
t.Helper()
db, err := sql.Open("sqlite", "file:admin_service_auth_identity_binding?mode=memory&cache=shared&_fk=1")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return client
}
func TestAdminServiceBindUserAuthIdentityCreatesCanonicalAndChannelBinding(t *testing.T) {
client := newAdminServiceAuthIdentityBindingTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("bind-target@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
svc := &adminServiceImpl{
userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
entClient: client,
}
result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-123",
Metadata: map[string]any{"source": "admin-repair"},
Channel: &AdminBindAuthIdentityChannelInput{
Channel: "open",
ChannelAppID: "wx-open",
ChannelSubject: "openid-123",
Metadata: map[string]any{"scene": "migration"},
},
})
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, user.ID, result.UserID)
require.Equal(t, "wechat", result.ProviderType)
require.Equal(t, "wechat-main", result.ProviderKey)
require.NotNil(t, result.VerifiedAt)
require.NotNil(t, result.Channel)
require.Equal(t, "open", result.Channel.Channel)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderKeyEQ("wechat-main"),
authidentity.ProviderSubjectEQ("union-123"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, user.ID, identity.UserID)
require.Equal(t, "admin-repair", identity.Metadata["source"])
require.NotNil(t, identity.VerifiedAt)
channel, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ProviderKeyEQ("wechat-main"),
authidentitychannel.ChannelEQ("open"),
authidentitychannel.ChannelAppIDEQ("wx-open"),
authidentitychannel.ChannelSubjectEQ("openid-123"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, identity.ID, channel.IdentityID)
require.Equal(t, "migration", channel.Metadata["scene"])
}
func TestAdminServiceBindUserAuthIdentityRejectsOtherOwner(t *testing.T) {
client := newAdminServiceAuthIdentityBindingTestClient(t)
ctx := context.Background()
owner, err := client.User.Create().
SetEmail("owner@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
target, err := client.User.Create().
SetEmail("target@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(owner.ID).
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("subject-1").
Save(ctx)
require.NoError(t, err)
svc := &adminServiceImpl{
userRepo: &userRepoStub{user: &User{ID: target.ID, Email: target.Email, Status: StatusActive}},
entClient: client,
}
_, err = svc.BindUserAuthIdentity(ctx, target.ID, AdminBindAuthIdentityInput{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-1",
})
require.Error(t, err)
require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", infraerrors.Reason(err))
}
func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) {
client := newAdminServiceAuthIdentityBindingTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("same-user@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
svc := &adminServiceImpl{
userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
entClient: client,
}
first, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-2",
Metadata: map[string]any{"source": "first"},
})
require.NoError(t, err)
second, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-2",
Metadata: map[string]any{"source": "second"},
})
require.NoError(t, err)
require.Equal(t, first.UserID, second.UserID)
require.Equal(t, "second", second.Metadata["source"])
identities, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ("https://issuer.example"),
authidentity.ProviderSubjectEQ("subject-2"),
).
All(ctx)
require.NoError(t, err)
require.Len(t, identities, 1)
require.Equal(t, "second", identities[0].Metadata["source"])
}
func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) {
client := newAdminServiceAuthIdentityBindingTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("invalid-provider@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
svc := &adminServiceImpl{
userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
entClient: client,
}
_, err = svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
ProviderType: "github",
ProviderKey: "github-main",
ProviderSubject: "subject-3",
})
require.Error(t, err)
require.Equal(t, "INVALID_INPUT", infraerrors.Reason(err))
}