fix(oauth): enhance error handling and transaction management for OAuth user creation and binding
- Improve error handling in DeleteCustomOAuthProvider to log and return errors when fetching binding counts. - Refactor user creation and OAuth binding logic to use transactions for atomic operations, ensuring data integrity. - Add unique constraints to UserOAuthBinding model to prevent duplicate bindings. - Enhance GitHub OAuth provider error logging for non-200 responses. - Update AccountManagement component to provide clearer error messages on API failures.
This commit is contained in:
@@ -296,7 +296,12 @@ func DeleteCustomOAuthProvider(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if there are any user bindings
|
// Check if there are any user bindings
|
||||||
count, _ := model.GetBindingCountByProviderId(id)
|
count, err := model.GetBindingCountByProviderId(id)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("Failed to get binding count for provider " + strconv.Itoa(id) + ": " + err.Error())
|
||||||
|
common.ApiErrorMsg(c, "检查用户绑定时发生错误,请稍后重试")
|
||||||
|
return
|
||||||
|
}
|
||||||
if count > 0 {
|
if count > 0 {
|
||||||
common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
|
common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/QuantumNous/new-api/oauth"
|
"github.com/QuantumNous/new-api/oauth"
|
||||||
"github.com/gin-contrib/sessions"
|
"github.com/gin-contrib/sessions"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// providerParams returns map with Provider key for i18n templates
|
// providerParams returns map with Provider key for i18n templates
|
||||||
@@ -256,27 +257,62 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
|
|||||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.Insert(inviterId); err != nil {
|
// Use transaction to ensure user creation and OAuth binding are atomic
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// For custom providers, create the binding after user is created
|
|
||||||
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
|
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
|
||||||
binding := &model.UserOAuthBinding{
|
// Custom provider: create user and binding in a transaction
|
||||||
UserId: user.Id,
|
err := model.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
ProviderId: genericProvider.GetProviderId(),
|
// Create user
|
||||||
ProviderUserId: oauthUser.ProviderUserID,
|
if err := user.InsertWithTx(tx, inviterId); err != nil {
|
||||||
}
|
return err
|
||||||
if err := model.CreateUserOAuthBinding(binding); err != nil {
|
}
|
||||||
common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error()))
|
|
||||||
// Don't fail the registration, just log the error
|
// Create OAuth binding
|
||||||
|
binding := &model.UserOAuthBinding{
|
||||||
|
UserId: user.Id,
|
||||||
|
ProviderId: genericProvider.GetProviderId(),
|
||||||
|
ProviderUserId: oauthUser.ProviderUserID,
|
||||||
|
}
|
||||||
|
if err := model.CreateUserOAuthBindingWithTx(tx, binding); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Perform post-transaction tasks (logs, sidebar config, inviter rewards)
|
||||||
|
user.FinalizeOAuthUserCreation(inviterId)
|
||||||
} else {
|
} else {
|
||||||
// Built-in provider: set the provider user ID on the user model
|
// Built-in provider: create user and update provider ID in a transaction
|
||||||
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
|
err := model.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
if err := user.Update(false); err != nil {
|
// Create user
|
||||||
common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error()))
|
if err := user.InsertWithTx(tx, inviterId); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the provider user ID on the user model and update
|
||||||
|
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
|
||||||
|
if err := tx.Model(user).Updates(map[string]interface{}{
|
||||||
|
"github_id": user.GitHubId,
|
||||||
|
"discord_id": user.DiscordId,
|
||||||
|
"oidc_id": user.OidcId,
|
||||||
|
"linux_do_id": user.LinuxDOId,
|
||||||
|
"wechat_id": user.WeChatId,
|
||||||
|
"telegram_id": user.TelegramId,
|
||||||
|
}).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Perform post-transaction tasks
|
||||||
|
user.FinalizeOAuthUserCreation(inviterId)
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
|
|||||||
@@ -97,13 +97,18 @@ func DeleteCustomOAuthProvider(id int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// IsSlugTaken checks if a slug is already taken by another provider
|
// IsSlugTaken checks if a slug is already taken by another provider
|
||||||
|
// Returns true on DB errors (fail-closed) to prevent slug conflicts
|
||||||
func IsSlugTaken(slug string, excludeId int) bool {
|
func IsSlugTaken(slug string, excludeId int) bool {
|
||||||
var count int64
|
var count int64
|
||||||
query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
|
query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
|
||||||
if excludeId > 0 {
|
if excludeId > 0 {
|
||||||
query = query.Where("id != ?", excludeId)
|
query = query.Where("id != ?", excludeId)
|
||||||
}
|
}
|
||||||
query.Count(&count)
|
res := query.Count(&count)
|
||||||
|
if res.Error != nil {
|
||||||
|
// Fail-closed: treat DB errors as slug being taken to prevent conflicts
|
||||||
|
return true
|
||||||
|
}
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -429,6 +429,65 @@ func (user *User) Insert(inviterId int) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InsertWithTx inserts a new user within an existing transaction.
|
||||||
|
// This is used for OAuth registration where user creation and binding need to be atomic.
|
||||||
|
// Post-creation tasks (sidebar config, logs, inviter rewards) are handled after the transaction commits.
|
||||||
|
func (user *User) InsertWithTx(tx *gorm.DB, inviterId int) error {
|
||||||
|
var err error
|
||||||
|
if user.Password != "" {
|
||||||
|
user.Password, err = common.Password2Hash(user.Password)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
user.Quota = common.QuotaForNewUser
|
||||||
|
user.AffCode = common.GetRandomString(4)
|
||||||
|
|
||||||
|
// 初始化用户设置
|
||||||
|
if user.Setting == "" {
|
||||||
|
defaultSetting := dto.UserSetting{}
|
||||||
|
user.SetSetting(defaultSetting)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := tx.Create(user)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinalizeOAuthUserCreation performs post-transaction tasks for OAuth user creation.
|
||||||
|
// This should be called after the transaction commits successfully.
|
||||||
|
func (user *User) FinalizeOAuthUserCreation(inviterId int) {
|
||||||
|
// 用户创建成功后,根据角色初始化边栏配置
|
||||||
|
var createdUser User
|
||||||
|
if err := DB.Where("id = ?", user.Id).First(&createdUser).Error; err == nil {
|
||||||
|
defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role)
|
||||||
|
if defaultSidebarConfig != "" {
|
||||||
|
currentSetting := createdUser.GetSetting()
|
||||||
|
currentSetting.SidebarModules = defaultSidebarConfig
|
||||||
|
createdUser.SetSetting(currentSetting)
|
||||||
|
createdUser.Update(false)
|
||||||
|
common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.QuotaForNewUser > 0 {
|
||||||
|
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
|
||||||
|
}
|
||||||
|
if inviterId != 0 {
|
||||||
|
if common.QuotaForInvitee > 0 {
|
||||||
|
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
|
||||||
|
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
|
||||||
|
}
|
||||||
|
if common.QuotaForInviter > 0 {
|
||||||
|
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
|
||||||
|
_ = inviteUser(inviterId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (user *User) Update(updatePassword bool) error {
|
func (user *User) Update(updatePassword bool) error {
|
||||||
var err error
|
var err error
|
||||||
if updatePassword {
|
if updatePassword {
|
||||||
|
|||||||
@@ -3,18 +3,17 @@ package model
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UserOAuthBinding stores the binding relationship between users and custom OAuth providers
|
// UserOAuthBinding stores the binding relationship between users and custom OAuth providers
|
||||||
type UserOAuthBinding struct {
|
type UserOAuthBinding struct {
|
||||||
Id int `json:"id" gorm:"primaryKey"`
|
Id int `json:"id" gorm:"primaryKey"`
|
||||||
UserId int `json:"user_id" gorm:"index;not null"` // User ID
|
UserId int `json:"user_id" gorm:"not null;uniqueIndex:ux_user_provider"` // User ID - one binding per user per provider
|
||||||
ProviderId int `json:"provider_id" gorm:"index;not null"` // Custom OAuth provider ID
|
ProviderId int `json:"provider_id" gorm:"not null;uniqueIndex:ux_user_provider;uniqueIndex:ux_provider_userid"` // Custom OAuth provider ID
|
||||||
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null"` // User ID from OAuth provider
|
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null;uniqueIndex:ux_provider_userid"` // User ID from OAuth provider - one OAuth account per provider
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
|
||||||
// Composite unique index to prevent duplicate bindings
|
|
||||||
// One OAuth account can only be bound to one user
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (UserOAuthBinding) TableName() string {
|
func (UserOAuthBinding) TableName() string {
|
||||||
@@ -82,6 +81,29 @@ func CreateUserOAuthBinding(binding *UserOAuthBinding) error {
|
|||||||
return DB.Create(binding).Error
|
return DB.Create(binding).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateUserOAuthBindingWithTx creates a new OAuth binding within a transaction
|
||||||
|
func CreateUserOAuthBindingWithTx(tx *gorm.DB, binding *UserOAuthBinding) error {
|
||||||
|
if binding.UserId == 0 {
|
||||||
|
return errors.New("user ID is required")
|
||||||
|
}
|
||||||
|
if binding.ProviderId == 0 {
|
||||||
|
return errors.New("provider ID is required")
|
||||||
|
}
|
||||||
|
if binding.ProviderUserId == "" {
|
||||||
|
return errors.New("provider user ID is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this provider user ID is already taken (use tx to check within the same transaction)
|
||||||
|
var count int64
|
||||||
|
tx.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", binding.ProviderId, binding.ProviderUserId).Count(&count)
|
||||||
|
if count > 0 {
|
||||||
|
return errors.New("this OAuth account is already bound to another user")
|
||||||
|
}
|
||||||
|
|
||||||
|
binding.CreatedAt = time.Now()
|
||||||
|
return tx.Create(binding).Error
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account)
|
// UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account)
|
||||||
func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error {
|
func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error {
|
||||||
// Check if the new provider user ID is already taken by another user
|
// Check if the new provider user ID is already taken by another user
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@@ -122,6 +123,17 @@ func (p *GitHubProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*O
|
|||||||
|
|
||||||
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo response status: %d", res.StatusCode)
|
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo response status: %d", res.StatusCode)
|
||||||
|
|
||||||
|
// Check for non-200 status codes before attempting to decode
|
||||||
|
if res.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(res.Body)
|
||||||
|
bodyStr := string(body)
|
||||||
|
if len(bodyStr) > 500 {
|
||||||
|
bodyStr = bodyStr[:500] + "..."
|
||||||
|
}
|
||||||
|
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo failed: status=%d, body=%s", res.StatusCode, bodyStr))
|
||||||
|
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthGetUserErr, map[string]any{"Provider": "GitHub"}, fmt.Sprintf("status %d", res.StatusCode))
|
||||||
|
}
|
||||||
|
|
||||||
var githubUser gitHubUser
|
var githubUser gitHubUser
|
||||||
err = json.NewDecoder(res.Body).Decode(&githubUser)
|
err = json.NewDecoder(res.Body).Decode(&githubUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -107,9 +107,11 @@ const AccountManagement = ({
|
|||||||
const res = await API.get('/api/user/oauth/bindings');
|
const res = await API.get('/api/user/oauth/bindings');
|
||||||
if (res.data.success) {
|
if (res.data.success) {
|
||||||
setCustomOAuthBindings(res.data.data || []);
|
setCustomOAuthBindings(res.data.data || []);
|
||||||
|
} else {
|
||||||
|
showError(res.data.message || t('获取绑定信息失败'));
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// ignore
|
showError(error.response?.data?.message || error.message || t('获取绑定信息失败'));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -131,7 +133,7 @@ const AccountManagement = ({
|
|||||||
showError(res.data.message);
|
showError(res.data.message);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
showError(t('操作失败'));
|
showError(error.response?.data?.message || error.message || t('操作失败'));
|
||||||
} finally {
|
} finally {
|
||||||
setCustomOAuthLoading((prev) => ({ ...prev, [providerId]: false }));
|
setCustomOAuthLoading((prev) => ({ ...prev, [providerId]: false }));
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user