Merge pull request #2857 from QuantumNous/feat/custom-oauth

feat(oauth): implement custom OAuth provider
This commit is contained in:
Calcium-Ion
2026-02-08 00:13:20 +08:00
committed by GitHub
7 changed files with 192 additions and 45 deletions

View File

@@ -166,21 +166,21 @@ func CreateCustomOAuthProvider(c *gin.Context) {
// UpdateCustomOAuthProviderRequest is the request structure for updating a custom OAuth provider
type UpdateCustomOAuthProviderRequest struct {
Name string `json:"name"`
Slug string `json:"slug"`
Enabled bool `json:"enabled"`
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"user_info_endpoint"`
Scopes string `json:"scopes"`
UserIdField string `json:"user_id_field"`
UsernameField string `json:"username_field"`
DisplayNameField string `json:"display_name_field"`
EmailField string `json:"email_field"`
WellKnown string `json:"well_known"`
AuthStyle int `json:"auth_style"`
Name string `json:"name"`
Slug string `json:"slug"`
Enabled *bool `json:"enabled"` // Optional: if nil, keep existing
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"user_info_endpoint"`
Scopes string `json:"scopes"`
UserIdField string `json:"user_id_field"`
UsernameField string `json:"username_field"`
DisplayNameField string `json:"display_name_field"`
EmailField string `json:"email_field"`
WellKnown *string `json:"well_known"` // Optional: if nil, keep existing
AuthStyle *int `json:"auth_style"` // Optional: if nil, keep existing
}
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
@@ -227,7 +227,9 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
if req.Slug != "" {
provider.Slug = req.Slug
}
provider.Enabled = req.Enabled
if req.Enabled != nil {
provider.Enabled = *req.Enabled
}
if req.ClientId != "" {
provider.ClientId = req.ClientId
}
@@ -258,8 +260,12 @@ func UpdateCustomOAuthProvider(c *gin.Context) {
if req.EmailField != "" {
provider.EmailField = req.EmailField
}
provider.WellKnown = req.WellKnown
provider.AuthStyle = req.AuthStyle
if req.WellKnown != nil {
provider.WellKnown = *req.WellKnown
}
if req.AuthStyle != nil {
provider.AuthStyle = *req.AuthStyle
}
if err := model.UpdateCustomOAuthProvider(provider); err != nil {
common.ApiError(c, err)
@@ -296,7 +302,12 @@ func DeleteCustomOAuthProvider(c *gin.Context) {
}
// Check if there are any user bindings
count, _ := model.GetBindingCountByProviderId(id)
count, err := model.GetBindingCountByProviderId(id)
if err != nil {
common.SysError("Failed to get binding count for provider " + strconv.Itoa(id) + ": " + err.Error())
common.ApiErrorMsg(c, "检查用户绑定时发生错误,请稍后重试")
return
}
if count > 0 {
common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
return

View File

@@ -11,6 +11,7 @@ import (
"github.com/QuantumNous/new-api/oauth"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// providerParams returns map with Provider key for i18n templates
@@ -256,27 +257,62 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
}
if err := user.Insert(inviterId); err != nil {
return nil, err
}
// For custom providers, create the binding after user is created
// Use transaction to ensure user creation and OAuth binding are atomic
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
binding := &model.UserOAuthBinding{
UserId: user.Id,
ProviderId: genericProvider.GetProviderId(),
ProviderUserId: oauthUser.ProviderUserID,
}
if err := model.CreateUserOAuthBinding(binding); err != nil {
common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error()))
// Don't fail the registration, just log the error
// Custom provider: create user and binding in a transaction
err := model.DB.Transaction(func(tx *gorm.DB) error {
// Create user
if err := user.InsertWithTx(tx, inviterId); err != nil {
return err
}
// Create OAuth binding
binding := &model.UserOAuthBinding{
UserId: user.Id,
ProviderId: genericProvider.GetProviderId(),
ProviderUserId: oauthUser.ProviderUserID,
}
if err := model.CreateUserOAuthBindingWithTx(tx, binding); err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// Perform post-transaction tasks (logs, sidebar config, inviter rewards)
user.FinalizeOAuthUserCreation(inviterId)
} else {
// Built-in provider: set the provider user ID on the user model
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
if err := user.Update(false); err != nil {
common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error()))
// Built-in provider: create user and update provider ID in a transaction
err := model.DB.Transaction(func(tx *gorm.DB) error {
// Create user
if err := user.InsertWithTx(tx, inviterId); err != nil {
return err
}
// Set the provider user ID on the user model and update
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
if err := tx.Model(user).Updates(map[string]interface{}{
"github_id": user.GitHubId,
"discord_id": user.DiscordId,
"oidc_id": user.OidcId,
"linux_do_id": user.LinuxDOId,
"wechat_id": user.WeChatId,
"telegram_id": user.TelegramId,
}).Error; err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// Perform post-transaction tasks
user.FinalizeOAuthUserCreation(inviterId)
}
return user, nil