feat(oauth): implement custom OAuth provider management #1106
- Add support for custom OAuth providers, including creation, retrieval, updating, and deletion. - Introduce new model and controller for managing custom OAuth providers. - Enhance existing OAuth logic to accommodate custom providers. - Update API routes for custom OAuth provider management. - Include i18n support for custom OAuth-related messages.
This commit is contained in:
386
controller/custom_oauth.go
Normal file
386
controller/custom_oauth.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/oauth"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CustomOAuthProviderResponse is the response structure for custom OAuth providers
|
||||
// It excludes sensitive fields like client_secret
|
||||
type CustomOAuthProviderResponse struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserInfoEndpoint string `json:"user_info_endpoint"`
|
||||
Scopes string `json:"scopes"`
|
||||
UserIdField string `json:"user_id_field"`
|
||||
UsernameField string `json:"username_field"`
|
||||
DisplayNameField string `json:"display_name_field"`
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown string `json:"well_known"`
|
||||
AuthStyle int `json:"auth_style"`
|
||||
}
|
||||
|
||||
func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse {
|
||||
return &CustomOAuthProviderResponse{
|
||||
Id: p.Id,
|
||||
Name: p.Name,
|
||||
Slug: p.Slug,
|
||||
Enabled: p.Enabled,
|
||||
ClientId: p.ClientId,
|
||||
AuthorizationEndpoint: p.AuthorizationEndpoint,
|
||||
TokenEndpoint: p.TokenEndpoint,
|
||||
UserInfoEndpoint: p.UserInfoEndpoint,
|
||||
Scopes: p.Scopes,
|
||||
UserIdField: p.UserIdField,
|
||||
UsernameField: p.UsernameField,
|
||||
DisplayNameField: p.DisplayNameField,
|
||||
EmailField: p.EmailField,
|
||||
WellKnown: p.WellKnown,
|
||||
AuthStyle: p.AuthStyle,
|
||||
}
|
||||
}
|
||||
|
||||
// GetCustomOAuthProviders returns all custom OAuth providers
|
||||
func GetCustomOAuthProviders(c *gin.Context) {
|
||||
providers, err := model.GetAllCustomOAuthProviders()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]*CustomOAuthProviderResponse, len(providers))
|
||||
for i, p := range providers {
|
||||
response[i] = toCustomOAuthProviderResponse(p)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": response,
|
||||
})
|
||||
}
|
||||
|
||||
// GetCustomOAuthProvider returns a single custom OAuth provider by ID
|
||||
func GetCustomOAuthProvider(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := model.GetCustomOAuthProviderById(id)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": toCustomOAuthProviderResponse(provider),
|
||||
})
|
||||
}
|
||||
|
||||
// CreateCustomOAuthProviderRequest is the request structure for creating a custom OAuth provider
|
||||
type CreateCustomOAuthProviderRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Slug string `json:"slug" binding:"required"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id" binding:"required"`
|
||||
ClientSecret string `json:"client_secret" binding:"required"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint" binding:"required"`
|
||||
TokenEndpoint string `json:"token_endpoint" binding:"required"`
|
||||
UserInfoEndpoint string `json:"user_info_endpoint" binding:"required"`
|
||||
Scopes string `json:"scopes"`
|
||||
UserIdField string `json:"user_id_field"`
|
||||
UsernameField string `json:"username_field"`
|
||||
DisplayNameField string `json:"display_name_field"`
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown string `json:"well_known"`
|
||||
AuthStyle int `json:"auth_style"`
|
||||
}
|
||||
|
||||
// CreateCustomOAuthProvider creates a new custom OAuth provider
|
||||
func CreateCustomOAuthProvider(c *gin.Context) {
|
||||
var req CreateCustomOAuthProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Check if slug is already taken
|
||||
if model.IsSlugTaken(req.Slug, 0) {
|
||||
common.ApiErrorMsg(c, "该 Slug 已被使用")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if slug conflicts with built-in providers
|
||||
if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
|
||||
common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
|
||||
return
|
||||
}
|
||||
|
||||
provider := &model.CustomOAuthProvider{
|
||||
Name: req.Name,
|
||||
Slug: req.Slug,
|
||||
Enabled: req.Enabled,
|
||||
ClientId: req.ClientId,
|
||||
ClientSecret: req.ClientSecret,
|
||||
AuthorizationEndpoint: req.AuthorizationEndpoint,
|
||||
TokenEndpoint: req.TokenEndpoint,
|
||||
UserInfoEndpoint: req.UserInfoEndpoint,
|
||||
Scopes: req.Scopes,
|
||||
UserIdField: req.UserIdField,
|
||||
UsernameField: req.UsernameField,
|
||||
DisplayNameField: req.DisplayNameField,
|
||||
EmailField: req.EmailField,
|
||||
WellKnown: req.WellKnown,
|
||||
AuthStyle: req.AuthStyle,
|
||||
}
|
||||
|
||||
if err := model.CreateCustomOAuthProvider(provider); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Register the provider in the OAuth registry
|
||||
oauth.RegisterOrUpdateCustomProvider(provider)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "创建成功",
|
||||
"data": toCustomOAuthProviderResponse(provider),
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateCustomOAuthProviderRequest is the request structure for updating a custom OAuth provider
|
||||
type UpdateCustomOAuthProviderRequest struct {
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserInfoEndpoint string `json:"user_info_endpoint"`
|
||||
Scopes string `json:"scopes"`
|
||||
UserIdField string `json:"user_id_field"`
|
||||
UsernameField string `json:"username_field"`
|
||||
DisplayNameField string `json:"display_name_field"`
|
||||
EmailField string `json:"email_field"`
|
||||
WellKnown string `json:"well_known"`
|
||||
AuthStyle int `json:"auth_style"`
|
||||
}
|
||||
|
||||
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
|
||||
func UpdateCustomOAuthProvider(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateCustomOAuthProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Get existing provider
|
||||
provider, err := model.GetCustomOAuthProviderById(id)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
|
||||
return
|
||||
}
|
||||
|
||||
oldSlug := provider.Slug
|
||||
|
||||
// Check if new slug is taken by another provider
|
||||
if req.Slug != "" && req.Slug != provider.Slug {
|
||||
if model.IsSlugTaken(req.Slug, id) {
|
||||
common.ApiErrorMsg(c, "该 Slug 已被使用")
|
||||
return
|
||||
}
|
||||
// Check if slug conflicts with built-in providers
|
||||
if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
|
||||
common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Update fields
|
||||
if req.Name != "" {
|
||||
provider.Name = req.Name
|
||||
}
|
||||
if req.Slug != "" {
|
||||
provider.Slug = req.Slug
|
||||
}
|
||||
provider.Enabled = req.Enabled
|
||||
if req.ClientId != "" {
|
||||
provider.ClientId = req.ClientId
|
||||
}
|
||||
if req.ClientSecret != "" {
|
||||
provider.ClientSecret = req.ClientSecret
|
||||
}
|
||||
if req.AuthorizationEndpoint != "" {
|
||||
provider.AuthorizationEndpoint = req.AuthorizationEndpoint
|
||||
}
|
||||
if req.TokenEndpoint != "" {
|
||||
provider.TokenEndpoint = req.TokenEndpoint
|
||||
}
|
||||
if req.UserInfoEndpoint != "" {
|
||||
provider.UserInfoEndpoint = req.UserInfoEndpoint
|
||||
}
|
||||
if req.Scopes != "" {
|
||||
provider.Scopes = req.Scopes
|
||||
}
|
||||
if req.UserIdField != "" {
|
||||
provider.UserIdField = req.UserIdField
|
||||
}
|
||||
if req.UsernameField != "" {
|
||||
provider.UsernameField = req.UsernameField
|
||||
}
|
||||
if req.DisplayNameField != "" {
|
||||
provider.DisplayNameField = req.DisplayNameField
|
||||
}
|
||||
if req.EmailField != "" {
|
||||
provider.EmailField = req.EmailField
|
||||
}
|
||||
provider.WellKnown = req.WellKnown
|
||||
provider.AuthStyle = req.AuthStyle
|
||||
|
||||
if err := model.UpdateCustomOAuthProvider(provider); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update the provider in the OAuth registry
|
||||
if oldSlug != provider.Slug {
|
||||
oauth.UnregisterCustomProvider(oldSlug)
|
||||
}
|
||||
oauth.RegisterOrUpdateCustomProvider(provider)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "更新成功",
|
||||
"data": toCustomOAuthProviderResponse(provider),
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteCustomOAuthProvider deletes a custom OAuth provider
|
||||
func DeleteCustomOAuthProvider(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "无效的 ID")
|
||||
return
|
||||
}
|
||||
|
||||
// Get existing provider to get slug
|
||||
provider, err := model.GetCustomOAuthProviderById(id)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if there are any user bindings
|
||||
count, _ := model.GetBindingCountByProviderId(id)
|
||||
if count > 0 {
|
||||
common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
|
||||
return
|
||||
}
|
||||
|
||||
if err := model.DeleteCustomOAuthProvider(id); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Unregister the provider from the OAuth registry
|
||||
oauth.UnregisterCustomProvider(provider.Slug)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "删除成功",
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserOAuthBindings returns all OAuth bindings for the current user
|
||||
func GetUserOAuthBindings(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
if userId == 0 {
|
||||
common.ApiErrorMsg(c, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
bindings, err := model.GetUserOAuthBindingsByUserId(userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build response with provider info
|
||||
type BindingResponse struct {
|
||||
ProviderId int `json:"provider_id"`
|
||||
ProviderName string `json:"provider_name"`
|
||||
ProviderSlug string `json:"provider_slug"`
|
||||
ProviderUserId string `json:"provider_user_id"`
|
||||
}
|
||||
|
||||
response := make([]BindingResponse, 0)
|
||||
for _, binding := range bindings {
|
||||
provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
|
||||
if err != nil {
|
||||
continue // Skip if provider not found
|
||||
}
|
||||
response = append(response, BindingResponse{
|
||||
ProviderId: binding.ProviderId,
|
||||
ProviderName: provider.Name,
|
||||
ProviderSlug: provider.Slug,
|
||||
ProviderUserId: binding.ProviderUserId,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": response,
|
||||
})
|
||||
}
|
||||
|
||||
// UnbindCustomOAuth unbinds a custom OAuth provider from the current user
|
||||
func UnbindCustomOAuth(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
if userId == 0 {
|
||||
common.ApiErrorMsg(c, "未登录")
|
||||
return
|
||||
}
|
||||
|
||||
providerIdStr := c.Param("provider_id")
|
||||
providerId, err := strconv.Atoi(providerIdStr)
|
||||
if err != nil {
|
||||
common.ApiErrorMsg(c, "无效的提供商 ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "解绑成功",
|
||||
})
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/middleware"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/oauth"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/console_setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
@@ -129,6 +130,30 @@ func GetStatus(c *gin.Context) {
|
||||
data["faq"] = console_setting.GetFAQ()
|
||||
}
|
||||
|
||||
// Add enabled custom OAuth providers
|
||||
customProviders := oauth.GetEnabledCustomProviders()
|
||||
if len(customProviders) > 0 {
|
||||
type CustomOAuthInfo struct {
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
ClientId string `json:"client_id"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
Scopes string `json:"scopes"`
|
||||
}
|
||||
providersInfo := make([]CustomOAuthInfo, 0, len(customProviders))
|
||||
for _, p := range customProviders {
|
||||
config := p.GetConfig()
|
||||
providersInfo = append(providersInfo, CustomOAuthInfo{
|
||||
Name: config.Name,
|
||||
Slug: config.Slug,
|
||||
ClientId: config.ClientId,
|
||||
AuthorizationEndpoint: config.AuthorizationEndpoint,
|
||||
Scopes: config.Scopes,
|
||||
})
|
||||
}
|
||||
data["custom_oauth_providers"] = providersInfo
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
|
||||
@@ -171,12 +171,22 @@ func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
|
||||
return
|
||||
}
|
||||
|
||||
// Update user with OAuth ID
|
||||
provider.SetProviderUserID(&user, oauthUser.ProviderUserID)
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
// Handle binding based on provider type
|
||||
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
|
||||
// Custom provider: use user_oauth_bindings table
|
||||
err = model.UpdateUserOAuthBinding(user.Id, genericProvider.GetProviderId(), oauthUser.ProviderUserID)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Built-in provider: update user record directly
|
||||
provider.SetProviderUserID(&user, oauthUser.ProviderUserID)
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
|
||||
@@ -188,7 +198,6 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
|
||||
|
||||
// Check if user already exists with new ID
|
||||
if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
|
||||
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
|
||||
err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -203,7 +212,6 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
|
||||
// Try to find user with legacy ID (for GitHub migration from login to numeric ID)
|
||||
if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
|
||||
if provider.IsUserIDTaken(legacyID) {
|
||||
provider.SetProviderUserID(user, legacyID)
|
||||
err := provider.FillUserByProviderID(user, legacyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -240,7 +248,6 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
|
||||
}
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
|
||||
|
||||
// Handle affiliate code
|
||||
affCode := session.Get("aff")
|
||||
@@ -253,6 +260,25 @@ func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *o
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// For custom providers, create the binding after user is created
|
||||
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
|
||||
binding := &model.UserOAuthBinding{
|
||||
UserId: user.Id,
|
||||
ProviderId: genericProvider.GetProviderId(),
|
||||
ProviderUserId: oauthUser.ProviderUserID,
|
||||
}
|
||||
if err := model.CreateUserOAuthBinding(binding); err != nil {
|
||||
common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error()))
|
||||
// Don't fail the registration, just log the error
|
||||
}
|
||||
} else {
|
||||
// Built-in provider: set the provider user ID on the user model
|
||||
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
|
||||
if err := user.Update(false); err != nil {
|
||||
common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user