first commit: one-api base code + SAAS plan document
This commit is contained in:
240
controller/auth/github.go
Normal file
240
controller/auth/github.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"github.com/songquanpeng/one-api/controller"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
)
|
||||
|
||||
type GitHubOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
type GitHubUser struct {
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code}
|
||||
jsonData, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oAuthResponse GitHubOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
var githubUser GitHubUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&githubUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if githubUser.Login == "" {
|
||||
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
|
||||
}
|
||||
return &githubUser, nil
|
||||
}
|
||||
|
||||
func GitHubOAuth(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
GitHubBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !config.GitHubOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
GitHubId: githubUser.Login,
|
||||
}
|
||||
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
|
||||
err := user.FillUserByGitHubId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if config.RegisterEnabled {
|
||||
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if githubUser.Name != "" {
|
||||
user.DisplayName = githubUser.Name
|
||||
} else {
|
||||
user.DisplayName = "GitHub User"
|
||||
}
|
||||
user.Email = githubUser.Email
|
||||
user.Role = model.RoleCommonUser
|
||||
user.Status = model.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(ctx, 0); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != model.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
controller.SetupLogin(&user, c)
|
||||
}
|
||||
|
||||
func GitHubBind(c *gin.Context) {
|
||||
if !config.GitHubOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
GitHubId: githubUser.Login,
|
||||
}
|
||||
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 GitHub 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.GitHubId = githubUser.Login
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GenerateOAuthCode(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := random.GetRandomString(12)
|
||||
session.Set("oauth_state", state)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": state,
|
||||
})
|
||||
}
|
||||
203
controller/auth/lark.go
Normal file
203
controller/auth/lark.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/controller"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
)
|
||||
|
||||
type LarkOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
|
||||
type LarkUser struct {
|
||||
Name string `json:"name"`
|
||||
OpenID string `json:"open_id"`
|
||||
}
|
||||
|
||||
func getLarkUserInfoByCode(code string) (*LarkUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
values := map[string]string{
|
||||
"client_id": config.LarkClientId,
|
||||
"client_secret": config.LarkClientSecret,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress),
|
||||
}
|
||||
jsonData, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oAuthResponse LarkOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
|
||||
}
|
||||
var larkUser LarkUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&larkUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &larkUser, nil
|
||||
}
|
||||
|
||||
func LarkOAuth(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
LarkBind(c)
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
larkUser, err := getLarkUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
LarkId: larkUser.OpenID,
|
||||
}
|
||||
if model.IsLarkIdAlreadyTaken(user.LarkId) {
|
||||
err := user.FillUserByLarkId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if config.RegisterEnabled {
|
||||
user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if larkUser.Name != "" {
|
||||
user.DisplayName = larkUser.Name
|
||||
} else {
|
||||
user.DisplayName = "Lark User"
|
||||
}
|
||||
user.Role = model.RoleCommonUser
|
||||
user.Status = model.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(ctx, 0); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != model.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
controller.SetupLogin(&user, c)
|
||||
}
|
||||
|
||||
func LarkBind(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
larkUser, err := getLarkUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
LarkId: larkUser.OpenID,
|
||||
}
|
||||
if model.IsLarkIdAlreadyTaken(user.LarkId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该飞书账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.LarkId = larkUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
228
controller/auth/oidc.go
Normal file
228
controller/auth/oidc.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/controller"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
)
|
||||
|
||||
type OidcResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type OidcUser struct {
|
||||
OpenID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
|
||||
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
values := map[string]string{
|
||||
"client_id": config.OidcClientId,
|
||||
"client_secret": config.OidcClientSecret,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress),
|
||||
}
|
||||
jsonData, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oidcResponse OidcResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
var oidcUser OidcUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oidcUser, nil
|
||||
}
|
||||
|
||||
func OidcAuth(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
OidcBind(c)
|
||||
return
|
||||
}
|
||||
if !config.OidcEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
err := user.FillUserByOidcId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if config.RegisterEnabled {
|
||||
user.Email = oidcUser.Email
|
||||
if oidcUser.PreferredUsername != "" {
|
||||
user.Username = oidcUser.PreferredUsername
|
||||
} else {
|
||||
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if oidcUser.Name != "" {
|
||||
user.DisplayName = oidcUser.Name
|
||||
} else {
|
||||
user.DisplayName = "OIDC User"
|
||||
}
|
||||
err := user.Insert(ctx, 0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != model.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
controller.SetupLogin(&user, c)
|
||||
}
|
||||
|
||||
func OidcBind(c *gin.Context) {
|
||||
if !config.OidcEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 OIDC 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.OidcId = oidcUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
169
controller/auth/wechat.go
Normal file
169
controller/auth/wechat.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/controller"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
)
|
||||
|
||||
type wechatLoginResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
func getWeChatIdByCode(code string) (string, error) {
|
||||
if code == "" {
|
||||
return "", errors.New("无效的参数")
|
||||
}
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", config.WeChatServerToken)
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
httpResponse, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer httpResponse.Body.Close()
|
||||
var res wechatLoginResponse
|
||||
err = json.NewDecoder(httpResponse.Body).Decode(&res)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !res.Success {
|
||||
return "", errors.New(res.Message)
|
||||
}
|
||||
if res.Data == "" {
|
||||
return "", errors.New("验证码错误或已过期")
|
||||
}
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
func WeChatAuth(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
if !config.WeChatAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "管理员未开启通过微信登录以及注册",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
wechatId, err := getWeChatIdByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
WeChatId: wechatId,
|
||||
}
|
||||
if model.IsWeChatIdAlreadyTaken(wechatId) {
|
||||
err := user.FillUserByWeChatId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if config.RegisterEnabled {
|
||||
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
user.DisplayName = "WeChat User"
|
||||
user.Role = model.RoleCommonUser
|
||||
user.Status = model.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(ctx, 0); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != model.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
controller.SetupLogin(&user, c)
|
||||
}
|
||||
|
||||
func WeChatBind(c *gin.Context) {
|
||||
if !config.WeChatAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "管理员未开启通过微信登录以及注册",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
wechatId, err := getWeChatIdByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
if model.IsWeChatIdAlreadyTaken(wechatId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该微信账号已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
user := model.User{
|
||||
Id: id,
|
||||
}
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.WeChatId = wechatId
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
97
controller/billing.go
Normal file
97
controller/billing.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
func GetSubscription(c *gin.Context) {
|
||||
var remainQuota int64
|
||||
var usedQuota int64
|
||||
var err error
|
||||
var token *model.Token
|
||||
var expiredTime int64
|
||||
if config.DisplayTokenStatEnabled {
|
||||
tokenId := c.GetInt(ctxkey.TokenId)
|
||||
token, err = model.GetTokenById(tokenId)
|
||||
if err == nil {
|
||||
expiredTime = token.ExpiredTime
|
||||
remainQuota = token.RemainQuota
|
||||
usedQuota = token.UsedQuota
|
||||
}
|
||||
} else {
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
remainQuota, err = model.GetUserQuota(userId)
|
||||
if err != nil {
|
||||
usedQuota, err = model.GetUserUsedQuota(userId)
|
||||
}
|
||||
}
|
||||
if expiredTime <= 0 {
|
||||
expiredTime = 0
|
||||
}
|
||||
if err != nil {
|
||||
Error := relaymodel.Error{
|
||||
Message: err.Error(),
|
||||
Type: "upstream_error",
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"error": Error,
|
||||
})
|
||||
return
|
||||
}
|
||||
quota := remainQuota + usedQuota
|
||||
amount := float64(quota)
|
||||
if config.DisplayInCurrencyEnabled {
|
||||
amount /= config.QuotaPerUnit
|
||||
}
|
||||
if token != nil && token.UnlimitedQuota {
|
||||
amount = 100000000
|
||||
}
|
||||
subscription := OpenAISubscriptionResponse{
|
||||
Object: "billing_subscription",
|
||||
HasPaymentMethod: true,
|
||||
SoftLimitUSD: amount,
|
||||
HardLimitUSD: amount,
|
||||
SystemHardLimitUSD: amount,
|
||||
AccessUntil: expiredTime,
|
||||
}
|
||||
c.JSON(200, subscription)
|
||||
return
|
||||
}
|
||||
|
||||
func GetUsage(c *gin.Context) {
|
||||
var quota int64
|
||||
var err error
|
||||
var token *model.Token
|
||||
if config.DisplayTokenStatEnabled {
|
||||
tokenId := c.GetInt(ctxkey.TokenId)
|
||||
token, err = model.GetTokenById(tokenId)
|
||||
quota = token.UsedQuota
|
||||
} else {
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
quota, err = model.GetUserUsedQuota(userId)
|
||||
}
|
||||
if err != nil {
|
||||
Error := relaymodel.Error{
|
||||
Message: err.Error(),
|
||||
Type: "one_api_error",
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"error": Error,
|
||||
})
|
||||
return
|
||||
}
|
||||
amount := float64(quota)
|
||||
if config.DisplayInCurrencyEnabled {
|
||||
amount /= config.QuotaPerUnit
|
||||
}
|
||||
usage := OpenAIUsageResponse{
|
||||
Object: "list",
|
||||
TotalUsage: amount * 100,
|
||||
}
|
||||
c.JSON(200, usage)
|
||||
return
|
||||
}
|
||||
459
controller/channel-billing.go
Normal file
459
controller/channel-billing.go
Normal file
@@ -0,0 +1,459 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/client"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/monitor"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://github.com/songquanpeng/one-api/issues/79
|
||||
|
||||
type OpenAISubscriptionResponse struct {
|
||||
Object string `json:"object"`
|
||||
HasPaymentMethod bool `json:"has_payment_method"`
|
||||
SoftLimitUSD float64 `json:"soft_limit_usd"`
|
||||
HardLimitUSD float64 `json:"hard_limit_usd"`
|
||||
SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
|
||||
AccessUntil int64 `json:"access_until"`
|
||||
}
|
||||
|
||||
type OpenAIUsageDailyCost struct {
|
||||
Timestamp float64 `json:"timestamp"`
|
||||
LineItems []struct {
|
||||
Name string `json:"name"`
|
||||
Cost float64 `json:"cost"`
|
||||
}
|
||||
}
|
||||
|
||||
type OpenAICreditGrants struct {
|
||||
Object string `json:"object"`
|
||||
TotalGranted float64 `json:"total_granted"`
|
||||
TotalUsed float64 `json:"total_used"`
|
||||
TotalAvailable float64 `json:"total_available"`
|
||||
}
|
||||
|
||||
type OpenAIUsageResponse struct {
|
||||
Object string `json:"object"`
|
||||
//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
|
||||
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
|
||||
}
|
||||
|
||||
type OpenAISBUsageResponse struct {
|
||||
Msg string `json:"msg"`
|
||||
Data *struct {
|
||||
Credit string `json:"credit"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type AIProxyUserOverviewResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
ErrorCode int `json:"error_code"`
|
||||
Data struct {
|
||||
TotalPoints float64 `json:"totalPoints"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type API2GPTUsageResponse struct {
|
||||
Object string `json:"object"`
|
||||
TotalGranted float64 `json:"total_granted"`
|
||||
TotalUsed float64 `json:"total_used"`
|
||||
TotalRemaining float64 `json:"total_remaining"`
|
||||
}
|
||||
|
||||
type APGC2DGPTUsageResponse struct {
|
||||
//Grants interface{} `json:"grants"`
|
||||
Object string `json:"object"`
|
||||
TotalAvailable float64 `json:"total_available"`
|
||||
TotalGranted float64 `json:"total_granted"`
|
||||
TotalUsed float64 `json:"total_used"`
|
||||
}
|
||||
|
||||
type SiliconFlowUsageResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status bool `json:"status"`
|
||||
Data struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Image string `json:"image"`
|
||||
Email string `json:"email"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
Balance string `json:"balance"`
|
||||
Status string `json:"status"`
|
||||
Introduction string `json:"introduction"`
|
||||
Role string `json:"role"`
|
||||
ChargeBalance string `json:"chargeBalance"`
|
||||
TotalBalance string `json:"totalBalance"`
|
||||
Category string `json:"category"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type DeepSeekUsageResponse struct {
|
||||
IsAvailable bool `json:"is_available"`
|
||||
BalanceInfos []struct {
|
||||
Currency string `json:"currency"`
|
||||
TotalBalance string `json:"total_balance"`
|
||||
GrantedBalance string `json:"granted_balance"`
|
||||
ToppedUpBalance string `json:"topped_up_balance"`
|
||||
} `json:"balance_infos"`
|
||||
}
|
||||
|
||||
type OpenRouterResponse struct {
|
||||
Data struct {
|
||||
TotalCredits float64 `json:"total_credits"`
|
||||
TotalUsage float64 `json:"total_usage"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// GetAuthHeader get auth header
|
||||
func GetAuthHeader(token string) http.Header {
|
||||
h := http.Header{}
|
||||
h.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
return h
|
||||
}
|
||||
|
||||
func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
|
||||
req, err := http.NewRequest(method, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k := range headers {
|
||||
req.Header.Add(k, headers.Get(k))
|
||||
}
|
||||
res, err := client.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("status code: %d", res.StatusCode)
|
||||
}
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
|
||||
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := OpenAICreditGrants{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(response.TotalAvailable)
|
||||
return response.TotalAvailable, nil
|
||||
}
|
||||
|
||||
func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
|
||||
url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := OpenAISBUsageResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if response.Data == nil {
|
||||
return 0, errors.New(response.Msg)
|
||||
}
|
||||
balance, err := strconv.ParseFloat(response.Data.Credit, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(balance)
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://aiproxy.io/api/report/getUserOverview"
|
||||
headers := http.Header{}
|
||||
headers.Add("Api-Key", channel.Key)
|
||||
body, err := GetResponseBody("GET", url, channel, headers)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := AIProxyUserOverviewResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !response.Success {
|
||||
return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
|
||||
}
|
||||
channel.UpdateBalance(response.Data.TotalPoints)
|
||||
return response.Data.TotalPoints, nil
|
||||
}
|
||||
|
||||
func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.api2gpt.com/dashboard/billing/credit_grants"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := API2GPTUsageResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(response.TotalRemaining)
|
||||
return response.TotalRemaining, nil
|
||||
}
|
||||
|
||||
func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := APGC2DGPTUsageResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(response.TotalAvailable)
|
||||
return response.TotalAvailable, nil
|
||||
}
|
||||
|
||||
func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.siliconflow.cn/v1/user/info"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := SiliconFlowUsageResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if response.Code != 20000 {
|
||||
return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
|
||||
}
|
||||
balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(balance)
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.deepseek.com/user/balance"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := DeepSeekUsageResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
index := -1
|
||||
for i, balanceInfo := range response.BalanceInfos {
|
||||
if balanceInfo.Currency == "CNY" {
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if index == -1 {
|
||||
return 0, errors.New("currency CNY not found")
|
||||
}
|
||||
balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(balance)
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://openrouter.ai/api/v1/credits"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := OpenRouterResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
balance := response.Data.TotalCredits - response.Data.TotalUsage
|
||||
channel.UpdateBalance(balance)
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
baseURL := channeltype.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() == "" {
|
||||
channel.BaseURL = &baseURL
|
||||
}
|
||||
switch channel.Type {
|
||||
case channeltype.OpenAI:
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
case channeltype.Azure:
|
||||
return 0, errors.New("尚未实现")
|
||||
case channeltype.Custom:
|
||||
baseURL = channel.GetBaseURL()
|
||||
case channeltype.CloseAI:
|
||||
return updateChannelCloseAIBalance(channel)
|
||||
case channeltype.OpenAISB:
|
||||
return updateChannelOpenAISBBalance(channel)
|
||||
case channeltype.AIProxy:
|
||||
return updateChannelAIProxyBalance(channel)
|
||||
case channeltype.API2GPT:
|
||||
return updateChannelAPI2GPTBalance(channel)
|
||||
case channeltype.AIGC2D:
|
||||
return updateChannelAIGC2DBalance(channel)
|
||||
case channeltype.SiliconFlow:
|
||||
return updateChannelSiliconFlowBalance(channel)
|
||||
case channeltype.DeepSeek:
|
||||
return updateChannelDeepSeekBalance(channel)
|
||||
case channeltype.OpenRouter:
|
||||
return updateChannelOpenRouterBalance(channel)
|
||||
default:
|
||||
return 0, errors.New("尚未实现")
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
|
||||
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
subscription := OpenAISubscriptionResponse{}
|
||||
err = json.Unmarshal(body, &subscription)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
now := time.Now()
|
||||
startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
|
||||
endDate := now.Format("2006-01-02")
|
||||
if !subscription.HasPaymentMethod {
|
||||
startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
|
||||
}
|
||||
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
|
||||
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
usage := OpenAIUsageResponse{}
|
||||
err = json.Unmarshal(body, &usage)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
balance := subscription.HardLimitUSD - usage.TotalUsage/100
|
||||
channel.UpdateBalance(balance)
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func UpdateChannelBalance(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
balance, err := updateChannelBalance(channel)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"balance": balance,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func updateAllChannelsBalance() error {
|
||||
channels, err := model.GetAllChannels(0, 0, "all")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, channel := range channels {
|
||||
if channel.Status != model.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
// TODO: support Azure
|
||||
if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom {
|
||||
continue
|
||||
}
|
||||
balance, err := updateChannelBalance(channel)
|
||||
if err != nil {
|
||||
continue
|
||||
} else {
|
||||
// err is nil & balance <= 0 means quota is used up
|
||||
if balance <= 0 {
|
||||
monitor.DisableChannel(channel.Id, channel.Name, "余额不足")
|
||||
}
|
||||
}
|
||||
time.Sleep(config.RequestInterval)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateAllChannelsBalance(c *gin.Context) {
|
||||
//err := updateAllChannelsBalance()
|
||||
//if err != nil {
|
||||
// c.JSON(http.StatusOK, gin.H{
|
||||
// "success": false,
|
||||
// "message": err.Error(),
|
||||
// })
|
||||
// return
|
||||
//}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func AutomaticallyUpdateChannels(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||
logger.SysLog("updating all channels")
|
||||
_ = updateAllChannelsBalance()
|
||||
logger.SysLog("channels update done")
|
||||
}
|
||||
}
|
||||
305
controller/channel-test.go
Normal file
305
controller/channel-test.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/message"
|
||||
"github.com/songquanpeng/one-api/middleware"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/monitor"
|
||||
"github.com/songquanpeng/one-api/relay"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/controller"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
)
|
||||
|
||||
func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
|
||||
if model == "" {
|
||||
model = "gpt-3.5-turbo"
|
||||
}
|
||||
testRequest := &relaymodel.GeneralOpenAIRequest{
|
||||
Model: model,
|
||||
}
|
||||
testMessage := relaymodel.Message{
|
||||
Role: "user",
|
||||
Content: config.TestPrompt,
|
||||
}
|
||||
testRequest.Messages = append(testRequest.Messages, testMessage)
|
||||
return testRequest
|
||||
}
|
||||
|
||||
func parseTestResponse(resp string) (*openai.TextResponse, string, error) {
|
||||
var response openai.TextResponse
|
||||
err := json.Unmarshal([]byte(resp), &response)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if len(response.Choices) == 0 {
|
||||
return nil, "", errors.New("response has no choices")
|
||||
}
|
||||
stringContent, ok := response.Choices[0].Content.(string)
|
||||
if !ok {
|
||||
return nil, "", errors.New("response content is not string")
|
||||
}
|
||||
return &response, stringContent, nil
|
||||
}
|
||||
|
||||
func testChannel(ctx context.Context, channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (responseMessage string, err error, openaiErr *relaymodel.Error) {
|
||||
startTime := time.Now()
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/v1/chat/completions"},
|
||||
Body: nil,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set(ctxkey.Channel, channel.Type)
|
||||
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
|
||||
cfg, _ := channel.LoadConfig()
|
||||
c.Set(ctxkey.Config, cfg)
|
||||
middleware.SetupContextForSelectedChannel(c, channel, "")
|
||||
meta := meta.GetByContext(c)
|
||||
apiType := channeltype.ToAPIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return "", fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
}
|
||||
adaptor.Init(meta)
|
||||
modelName := request.Model
|
||||
modelMap := channel.GetModelMapping()
|
||||
if modelName == "" || !strings.Contains(channel.Models, modelName) {
|
||||
modelNames := strings.Split(channel.Models, ",")
|
||||
if len(modelNames) > 0 {
|
||||
modelName = modelNames[0]
|
||||
}
|
||||
}
|
||||
if modelMap != nil && modelMap[modelName] != "" {
|
||||
modelName = modelMap[modelName]
|
||||
}
|
||||
meta.OriginModelName, meta.ActualModelName = request.Model, modelName
|
||||
request.Model = modelName
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
|
||||
if err != nil {
|
||||
return "", err, nil
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return "", err, nil
|
||||
}
|
||||
defer func() {
|
||||
logContent := fmt.Sprintf("渠道 %s 测试成功,响应:%s", channel.Name, responseMessage)
|
||||
if err != nil || openaiErr != nil {
|
||||
errorMessage := ""
|
||||
if err != nil {
|
||||
errorMessage = err.Error()
|
||||
} else {
|
||||
errorMessage = openaiErr.Message
|
||||
}
|
||||
logContent = fmt.Sprintf("渠道 %s 测试失败,错误:%s", channel.Name, errorMessage)
|
||||
}
|
||||
go model.RecordTestLog(ctx, &model.Log{
|
||||
ChannelId: channel.Id,
|
||||
ModelName: modelName,
|
||||
Content: logContent,
|
||||
ElapsedTime: helper.CalcElapsedTime(startTime),
|
||||
})
|
||||
}()
|
||||
logger.SysLog(string(jsonData))
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
c.Request.Body = io.NopCloser(requestBody)
|
||||
resp, err := adaptor.DoRequest(c, meta, requestBody)
|
||||
if err != nil {
|
||||
return "", err, nil
|
||||
}
|
||||
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||
err := controller.RelayErrorHandler(resp)
|
||||
errorMessage := err.Error.Message
|
||||
if errorMessage != "" {
|
||||
errorMessage = ", error message: " + errorMessage
|
||||
}
|
||||
return "", fmt.Errorf("http status code: %d%s", resp.StatusCode, errorMessage), &err.Error
|
||||
}
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
if respErr != nil {
|
||||
return "", fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
|
||||
}
|
||||
if usage == nil {
|
||||
return "", errors.New("usage is nil"), nil
|
||||
}
|
||||
rawResponse := w.Body.String()
|
||||
_, responseMessage, err = parseTestResponse(rawResponse)
|
||||
if err != nil {
|
||||
logger.SysError(fmt.Sprintf("failed to parse error: %s, \nresponse: %s", err.Error(), rawResponse))
|
||||
return "", err, nil
|
||||
}
|
||||
result := w.Result()
|
||||
// print result.Body
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
return "", err, nil
|
||||
}
|
||||
logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return responseMessage, nil, nil
|
||||
}
|
||||
|
||||
func TestChannel(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
modelName := c.Query("model")
|
||||
testRequest := buildTestRequest(modelName)
|
||||
tik := time.Now()
|
||||
responseMessage, err, _ := testChannel(ctx, channel, testRequest)
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
if err != nil {
|
||||
milliseconds = 0
|
||||
}
|
||||
go channel.UpdateResponseTime(milliseconds)
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
"time": consumedTime,
|
||||
"modelName": modelName,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": responseMessage,
|
||||
"time": consumedTime,
|
||||
"modelName": modelName,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var testAllChannelsLock sync.Mutex
|
||||
var testAllChannelsRunning bool = false
|
||||
|
||||
func testChannels(ctx context.Context, notify bool, scope string) error {
|
||||
if config.RootUserEmail == "" {
|
||||
config.RootUserEmail = model.GetRootUserEmail()
|
||||
}
|
||||
testAllChannelsLock.Lock()
|
||||
if testAllChannelsRunning {
|
||||
testAllChannelsLock.Unlock()
|
||||
return errors.New("测试已在运行中")
|
||||
}
|
||||
testAllChannelsRunning = true
|
||||
testAllChannelsLock.Unlock()
|
||||
channels, err := model.GetAllChannels(0, 0, scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
|
||||
if disableThreshold == 0 {
|
||||
disableThreshold = 10000000 // a impossible value
|
||||
}
|
||||
go func() {
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == model.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
testRequest := buildTestRequest("")
|
||||
_, err, openaiErr := testChannel(ctx, channel, testRequest)
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
if isChannelEnabled && milliseconds > disableThreshold {
|
||||
err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
|
||||
if config.AutomaticDisableChannelEnabled {
|
||||
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||
} else {
|
||||
_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error())
|
||||
}
|
||||
}
|
||||
if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) {
|
||||
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) {
|
||||
monitor.EnableChannel(channel.Id, channel.Name)
|
||||
}
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
time.Sleep(config.RequestInterval)
|
||||
}
|
||||
testAllChannelsLock.Lock()
|
||||
testAllChannelsRunning = false
|
||||
testAllChannelsLock.Unlock()
|
||||
if notify {
|
||||
err := message.Notify(message.ByAll, "渠道测试完成", "", "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常")
|
||||
if err != nil {
|
||||
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestChannels(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
scope := c.Query("scope")
|
||||
if scope == "" {
|
||||
scope = "all"
|
||||
}
|
||||
err := testChannels(ctx, true, scope)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func AutomaticallyTestChannels(frequency int) {
|
||||
ctx := context.Background()
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||
logger.SysLog("testing all channels")
|
||||
_ = testChannels(ctx, false, "all")
|
||||
logger.SysLog("channel test finished")
|
||||
}
|
||||
}
|
||||
172
controller/channel.go
Normal file
172
controller/channel.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetAllChannels(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, "limited")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channels,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SearchChannels(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
channels, err := model.SearchChannels(keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channels,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetChannel(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channel,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func AddChannel(c *gin.Context) {
|
||||
channel := model.Channel{}
|
||||
err := c.ShouldBindJSON(&channel)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
channel.CreatedTime = helper.GetTimestamp()
|
||||
keys := strings.Split(channel.Key, "\n")
|
||||
channels := make([]model.Channel, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
localChannel := channel
|
||||
localChannel.Key = key
|
||||
channels = append(channels, localChannel)
|
||||
}
|
||||
err = model.BatchInsertChannels(channels)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteChannel(c *gin.Context) {
|
||||
id, _ := strconv.Atoi(c.Param("id"))
|
||||
channel := model.Channel{Id: id}
|
||||
err := channel.Delete()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteDisabledChannel(c *gin.Context) {
|
||||
rows, err := model.DeleteDisabledChannel()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": rows,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateChannel(c *gin.Context) {
|
||||
channel := model.Channel{}
|
||||
err := c.ShouldBindJSON(&channel)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channel,
|
||||
})
|
||||
return
|
||||
}
|
||||
19
controller/group.go
Normal file
19
controller/group.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
groupNames := make([]string, 0)
|
||||
for groupName := range billingratio.GroupRatio {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": groupNames,
|
||||
})
|
||||
}
|
||||
169
controller/log.go
Normal file
169
controller/log.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func GetAllLogs(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
username := c.Query("username")
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetUserLogs(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SearchAllLogs(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
logs, err := model.SearchAllLogs(keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SearchUserLogs(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
logs, err := model.SearchUserLogs(userId, keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetLogsStat(c *gin.Context) {
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
tokenName := c.Query("token_name")
|
||||
username := c.Query("username")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"quota": quotaNum,
|
||||
//"token": tokenNum,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetLogsSelfStat(c *gin.Context) {
|
||||
username := c.GetString(ctxkey.Username)
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"quota": quotaNum,
|
||||
//"token": tokenNum,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteHistoryLogs(c *gin.Context) {
|
||||
targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
|
||||
if targetTimestamp == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "target timestamp is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
count, err := model.DeleteOldLog(targetTimestamp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": count,
|
||||
})
|
||||
return
|
||||
}
|
||||
232
controller/misc.go
Normal file
232
controller/misc.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/i18n"
|
||||
"github.com/songquanpeng/one-api/common/message"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetStatus(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": config.EmailVerificationEnabled,
|
||||
"github_oauth": config.GitHubOAuthEnabled,
|
||||
"github_client_id": config.GitHubClientId,
|
||||
"lark_client_id": config.LarkClientId,
|
||||
"system_name": config.SystemName,
|
||||
"logo": config.Logo,
|
||||
"footer_html": config.Footer,
|
||||
"wechat_qrcode": config.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": config.WeChatAuthEnabled,
|
||||
"server_address": config.ServerAddress,
|
||||
"turnstile_check": config.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": config.TurnstileSiteKey,
|
||||
"top_up_link": config.TopUpLink,
|
||||
"chat_link": config.ChatLink,
|
||||
"quota_per_unit": config.QuotaPerUnit,
|
||||
"display_in_currency": config.DisplayInCurrencyEnabled,
|
||||
"oidc": config.OidcEnabled,
|
||||
"oidc_client_id": config.OidcClientId,
|
||||
"oidc_well_known": config.OidcWellKnown,
|
||||
"oidc_authorization_endpoint": config.OidcAuthorizationEndpoint,
|
||||
"oidc_token_endpoint": config.OidcTokenEndpoint,
|
||||
"oidc_userinfo_endpoint": config.OidcUserinfoEndpoint,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetNotice(c *gin.Context) {
|
||||
config.OptionMapRWMutex.RLock()
|
||||
defer config.OptionMapRWMutex.RUnlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": config.OptionMap["Notice"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetAbout(c *gin.Context) {
|
||||
config.OptionMapRWMutex.RLock()
|
||||
defer config.OptionMapRWMutex.RUnlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": config.OptionMap["About"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetHomePageContent(c *gin.Context) {
|
||||
config.OptionMapRWMutex.RLock()
|
||||
defer config.OptionMapRWMutex.RUnlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": config.OptionMap["HomePageContent"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SendEmailVerification(c *gin.Context) {
|
||||
email := c.Query("email")
|
||||
if err := common.Validate.Var(email, "required,email"); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.Translate(c, "invalid_parameter"),
|
||||
})
|
||||
return
|
||||
}
|
||||
if config.EmailDomainRestrictionEnabled {
|
||||
allowed := false
|
||||
for _, domain := range config.EmailDomainWhitelist {
|
||||
if strings.HasSuffix(email, "@"+domain) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if model.IsEmailAlreadyTaken(email) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "邮箱地址已被占用",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := common.GenerateVerificationCode(6)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
|
||||
subject := fmt.Sprintf("%s 邮箱验证邮件", config.SystemName)
|
||||
content := message.EmailTemplate(
|
||||
subject,
|
||||
fmt.Sprintf(`
|
||||
<p>您好!</p>
|
||||
<p>您正在进行 %s 邮箱验证。</p>
|
||||
<p>您的验证码为:</p>
|
||||
<p style="font-size: 24px; font-weight: bold; color: #333; background-color: #f8f8f8; padding: 10px; text-align: center; border-radius: 4px;">%s</p>
|
||||
<p style="color: #666;">验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>
|
||||
`, config.SystemName, code, common.VerificationValidMinutes),
|
||||
)
|
||||
err := message.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SendPasswordResetEmail(c *gin.Context) {
|
||||
email := c.Query("email")
|
||||
if err := common.Validate.Var(email, "required,email"); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.Translate(c, "invalid_parameter"),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !model.IsEmailAlreadyTaken(email) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该邮箱地址未注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := common.GenerateVerificationCode(0)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code)
|
||||
subject := fmt.Sprintf("%s 密码重置", config.SystemName)
|
||||
content := message.EmailTemplate(
|
||||
subject,
|
||||
fmt.Sprintf(`
|
||||
<p>您好!</p>
|
||||
<p>您正在进行 %s 密码重置。</p>
|
||||
<p>请点击下面的按钮进行密码重置:</p>
|
||||
<p style="text-align: center; margin: 30px 0;">
|
||||
<a href="%s" style="background-color: #007bff; color: white; padding: 12px 24px; text-decoration: none; border-radius: 4px; display: inline-block;">重置密码</a>
|
||||
</p>
|
||||
<p style="color: #666;">如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
|
||||
<p style="background-color: #f8f8f8; padding: 10px; border-radius: 4px; word-break: break-all;">%s</p>
|
||||
<p style="color: #666;">重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>
|
||||
`, config.SystemName, link, link, common.VerificationValidMinutes),
|
||||
)
|
||||
err := message.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("%s%s", i18n.Translate(c, "send_email_failed"), err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type PasswordResetRequest struct {
|
||||
Email string `json:"email"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
func ResetPassword(c *gin.Context) {
|
||||
var req PasswordResetRequest
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&req)
|
||||
if req.Email == "" || req.Token == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.Translate(c, "invalid_parameter"),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !common.VerifyCodeWithKey(req.Email, req.Token, common.PasswordResetPurpose) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "重置链接非法或已过期",
|
||||
})
|
||||
return
|
||||
}
|
||||
password := common.GenerateVerificationCode(12)
|
||||
err = model.ResetUserPasswordByEmail(req.Email, password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
common.DeleteKey(req.Email, common.PasswordResetPurpose)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": password,
|
||||
})
|
||||
return
|
||||
}
|
||||
213
controller/model.go
Normal file
213
controller/model.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
relay "github.com/songquanpeng/one-api/relay"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/apitype"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
|
||||
type OpenAIModelPermission struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
AllowCreateEngine bool `json:"allow_create_engine"`
|
||||
AllowSampling bool `json:"allow_sampling"`
|
||||
AllowLogprobs bool `json:"allow_logprobs"`
|
||||
AllowSearchIndices bool `json:"allow_search_indices"`
|
||||
AllowView bool `json:"allow_view"`
|
||||
AllowFineTuning bool `json:"allow_fine_tuning"`
|
||||
Organization string `json:"organization"`
|
||||
Group *string `json:"group"`
|
||||
IsBlocking bool `json:"is_blocking"`
|
||||
}
|
||||
|
||||
type OpenAIModels struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Permission []OpenAIModelPermission `json:"permission"`
|
||||
Root string `json:"root"`
|
||||
Parent *string `json:"parent"`
|
||||
}
|
||||
|
||||
var models []OpenAIModels
|
||||
var modelsMap map[string]OpenAIModels
|
||||
var channelId2Models map[int][]string
|
||||
|
||||
func init() {
|
||||
var permission []OpenAIModelPermission
|
||||
permission = append(permission, OpenAIModelPermission{
|
||||
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
|
||||
Object: "model_permission",
|
||||
Created: 1626777600,
|
||||
AllowCreateEngine: true,
|
||||
AllowSampling: true,
|
||||
AllowLogprobs: true,
|
||||
AllowSearchIndices: false,
|
||||
AllowView: true,
|
||||
AllowFineTuning: false,
|
||||
Organization: "*",
|
||||
Group: nil,
|
||||
IsBlocking: false,
|
||||
})
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
for i := 0; i < apitype.Dummy; i++ {
|
||||
if i == apitype.AIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
adaptor := relay.GetAdaptor(i)
|
||||
channelName := adaptor.GetChannelName()
|
||||
modelNames := adaptor.GetModelList()
|
||||
for _, modelName := range modelNames {
|
||||
models = append(models, OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: channelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, channelType := range openai.CompatibleChannels {
|
||||
if channelType == channeltype.Azure {
|
||||
continue
|
||||
}
|
||||
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
|
||||
for _, modelName := range channelModelList {
|
||||
models = append(models, OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: channelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
modelsMap = make(map[string]OpenAIModels)
|
||||
for _, model := range models {
|
||||
modelsMap[model.Id] = model
|
||||
}
|
||||
channelId2Models = make(map[int][]string)
|
||||
for i := 1; i < channeltype.Dummy; i++ {
|
||||
adaptor := relay.GetAdaptor(channeltype.ToAPIType(i))
|
||||
meta := &meta.Meta{
|
||||
ChannelType: i,
|
||||
}
|
||||
adaptor.Init(meta)
|
||||
channelId2Models[i] = adaptor.GetModelList()
|
||||
}
|
||||
}
|
||||
|
||||
func DashboardListModels(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channelId2Models,
|
||||
})
|
||||
}
|
||||
|
||||
func ListAllModels(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"object": "list",
|
||||
"data": models,
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
var availableModels []string
|
||||
if c.GetString(ctxkey.AvailableModels) != "" {
|
||||
availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",")
|
||||
} else {
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
|
||||
}
|
||||
modelSet := make(map[string]bool)
|
||||
for _, availableModel := range availableModels {
|
||||
modelSet[availableModel] = true
|
||||
}
|
||||
availableOpenAIModels := make([]OpenAIModels, 0)
|
||||
for _, model := range models {
|
||||
if _, ok := modelSet[model.Id]; ok {
|
||||
modelSet[model.Id] = false
|
||||
availableOpenAIModels = append(availableOpenAIModels, model)
|
||||
}
|
||||
}
|
||||
for modelName, ok := range modelSet {
|
||||
if ok {
|
||||
availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"object": "list",
|
||||
"data": availableOpenAIModels,
|
||||
})
|
||||
}
|
||||
|
||||
func RetrieveModel(c *gin.Context) {
|
||||
modelId := c.Param("model")
|
||||
if model, ok := modelsMap[modelId]; ok {
|
||||
c.JSON(200, model)
|
||||
} else {
|
||||
Error := relaymodel.Error{
|
||||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||||
Type: "invalid_request_error",
|
||||
Param: "model",
|
||||
Code: "model_not_found",
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"error": Error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func GetUserAvailableModels(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
userGroup, err := model.CacheGetUserGroup(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
models, err := model.CacheGetGroupModels(ctx, userGroup)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": models,
|
||||
})
|
||||
return
|
||||
}
|
||||
102
controller/option.go
Normal file
102
controller/option.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/i18n"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetOptions(c *gin.Context) {
|
||||
var options []*model.Option
|
||||
config.OptionMapRWMutex.Lock()
|
||||
for k, v := range config.OptionMap {
|
||||
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
|
||||
continue
|
||||
}
|
||||
options = append(options, &model.Option{
|
||||
Key: k,
|
||||
Value: helper.Interface2String(v),
|
||||
})
|
||||
}
|
||||
config.OptionMapRWMutex.Unlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": options,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateOption(c *gin.Context) {
|
||||
var option model.Option
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&option)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.Translate(c, "invalid_parameter"),
|
||||
})
|
||||
return
|
||||
}
|
||||
switch option.Key {
|
||||
case "Theme":
|
||||
if !config.ValidThemes[option.Value] {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的主题",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "GitHubOAuthEnabled":
|
||||
if option.Value == "true" && config.GitHubClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "WeChatAuthEnabled":
|
||||
if option.Value == "true" && config.WeChatServerAddress == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用微信登录,请先填入微信登录相关配置信息!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "TurnstileCheckEnabled":
|
||||
if option.Value == "true" && config.TurnstileSiteKey == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
err = model.UpdateOption(option.Key, option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
195
controller/redemption.go
Normal file
195
controller/redemption.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func GetAllRedemptions(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": redemptions,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SearchRedemptions(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
redemptions, err := model.SearchRedemptions(keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": redemptions,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetRedemption(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
redemption, err := model.GetRedemptionById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": redemption,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func AddRedemption(c *gin.Context) {
|
||||
redemption := model.Redemption{}
|
||||
err := c.ShouldBindJSON(&redemption)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "兑换码名称长度必须在1-20之间",
|
||||
})
|
||||
return
|
||||
}
|
||||
if redemption.Count <= 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "兑换码个数必须大于0",
|
||||
})
|
||||
return
|
||||
}
|
||||
if redemption.Count > 100 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "一次兑换码批量生成的个数不能大于 100",
|
||||
})
|
||||
return
|
||||
}
|
||||
var keys []string
|
||||
for i := 0; i < redemption.Count; i++ {
|
||||
key := random.GetUUID()
|
||||
cleanRedemption := model.Redemption{
|
||||
UserId: c.GetInt(ctxkey.Id),
|
||||
Name: redemption.Name,
|
||||
Key: key,
|
||||
CreatedTime: helper.GetTimestamp(),
|
||||
Quota: redemption.Quota,
|
||||
}
|
||||
err = cleanRedemption.Insert()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
"data": keys,
|
||||
})
|
||||
return
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": keys,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteRedemption(c *gin.Context) {
|
||||
id, _ := strconv.Atoi(c.Param("id"))
|
||||
err := model.DeleteRedemptionById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateRedemption(c *gin.Context) {
|
||||
statusOnly := c.Query("status_only")
|
||||
redemption := model.Redemption{}
|
||||
err := c.ShouldBindJSON(&redemption)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
cleanRedemption, err := model.GetRedemptionById(redemption.Id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if statusOnly != "" {
|
||||
cleanRedemption.Status = redemption.Status
|
||||
} else {
|
||||
// If you add more fields, please also update redemption.Update()
|
||||
cleanRedemption.Name = redemption.Name
|
||||
cleanRedemption.Quota = redemption.Quota
|
||||
}
|
||||
err = cleanRedemption.Update()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": cleanRedemption,
|
||||
})
|
||||
return
|
||||
}
|
||||
156
controller/relay.go
Normal file
156
controller/relay.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/middleware"
|
||||
dbmodel "github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/monitor"
|
||||
"github.com/songquanpeng/one-api/relay/controller"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/chat
|
||||
|
||||
func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
|
||||
var err *model.ErrorWithStatusCode
|
||||
switch relayMode {
|
||||
case relaymode.ImagesGenerations:
|
||||
err = controller.RelayImageHelper(c, relayMode)
|
||||
case relaymode.AudioSpeech:
|
||||
fallthrough
|
||||
case relaymode.AudioTranslation:
|
||||
fallthrough
|
||||
case relaymode.AudioTranscription:
|
||||
err = controller.RelayAudioHelper(c, relayMode)
|
||||
case relaymode.Proxy:
|
||||
err = controller.RelayProxyHelper(c, relayMode)
|
||||
default:
|
||||
err = controller.RelayTextHelper(c)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
relayMode := relaymode.GetByPath(c.Request.URL.Path)
|
||||
if config.DebugEnabled {
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
logger.Debugf(ctx, "request body: %s", string(requestBody))
|
||||
}
|
||||
channelId := c.GetInt(ctxkey.ChannelId)
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
bizErr := relayHelper(c, relayMode)
|
||||
if bizErr == nil {
|
||||
monitor.Emit(channelId, true)
|
||||
return
|
||||
}
|
||||
lastFailedChannelId := channelId
|
||||
channelName := c.GetString(ctxkey.ChannelName)
|
||||
group := c.GetString(ctxkey.Group)
|
||||
originalModel := c.GetString(ctxkey.OriginalModel)
|
||||
go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
|
||||
requestId := c.GetString(helper.RequestIdKey)
|
||||
retryTimes := config.RetryTimes
|
||||
if !shouldRetry(c, bizErr.StatusCode) {
|
||||
logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode)
|
||||
retryTimes = 0
|
||||
}
|
||||
for i := retryTimes; i > 0; i-- {
|
||||
channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %+v", err)
|
||||
break
|
||||
}
|
||||
logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i)
|
||||
if channel.Id == lastFailedChannelId {
|
||||
continue
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
requestBody, err := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
bizErr = relayHelper(c, relayMode)
|
||||
if bizErr == nil {
|
||||
return
|
||||
}
|
||||
channelId := c.GetInt(ctxkey.ChannelId)
|
||||
lastFailedChannelId = channelId
|
||||
channelName := c.GetString(ctxkey.ChannelName)
|
||||
go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr)
|
||||
}
|
||||
if bizErr != nil {
|
||||
if bizErr.StatusCode == http.StatusTooManyRequests {
|
||||
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
|
||||
// BUG: bizErr is in race condition
|
||||
bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
|
||||
c.JSON(bizErr.StatusCode, gin.H{
|
||||
"error": bizErr.Error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func shouldRetry(c *gin.Context, statusCode int) bool {
|
||||
if _, ok := c.Get(ctxkey.SpecificChannelId); ok {
|
||||
return false
|
||||
}
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
return true
|
||||
}
|
||||
if statusCode/100 == 5 {
|
||||
return true
|
||||
}
|
||||
if statusCode == http.StatusBadRequest {
|
||||
return false
|
||||
}
|
||||
if statusCode/100 == 2 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) {
|
||||
logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
|
||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
|
||||
monitor.DisableChannel(channelId, channelName, err.Message)
|
||||
} else {
|
||||
monitor.Emit(channelId, false)
|
||||
}
|
||||
}
|
||||
|
||||
func RelayNotImplemented(c *gin.Context) {
|
||||
err := model.Error{
|
||||
Message: "API not implemented",
|
||||
Type: "one_api_error",
|
||||
Param: "",
|
||||
Code: "api_not_implemented",
|
||||
}
|
||||
c.JSON(http.StatusNotImplemented, gin.H{
|
||||
"error": err,
|
||||
})
|
||||
}
|
||||
|
||||
func RelayNotFound(c *gin.Context) {
|
||||
err := model.Error{
|
||||
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
|
||||
Type: "invalid_request_error",
|
||||
Param: "",
|
||||
Code: "",
|
||||
}
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": err,
|
||||
})
|
||||
}
|
||||
257
controller/token.go
Normal file
257
controller/token.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/network"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func GetAllTokens(c *gin.Context) {
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
|
||||
order := c.Query("order")
|
||||
tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage, order)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": tokens,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SearchTokens(c *gin.Context) {
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
keyword := c.Query("keyword")
|
||||
tokens, err := model.SearchUserTokens(userId, keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": tokens,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetToken(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
token, err := model.GetTokenByIds(id, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": token,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetTokenStatus(c *gin.Context) {
|
||||
tokenId := c.GetInt(ctxkey.TokenId)
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
token, err := model.GetTokenByIds(tokenId, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
expiredAt := token.ExpiredTime
|
||||
if expiredAt == -1 {
|
||||
expiredAt = 0
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "credit_summary",
|
||||
"total_granted": token.RemainQuota,
|
||||
"total_used": 0, // not supported currently
|
||||
"total_available": token.RemainQuota,
|
||||
"expires_at": expiredAt * 1000,
|
||||
})
|
||||
}
|
||||
|
||||
func validateToken(c *gin.Context, token model.Token) error {
|
||||
if len(token.Name) > 30 {
|
||||
return fmt.Errorf("令牌名称过长")
|
||||
}
|
||||
if token.Subnet != nil && *token.Subnet != "" {
|
||||
err := network.IsValidSubnets(*token.Subnet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("无效的网段:%s", err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddToken(c *gin.Context) {
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
err = validateToken(c, token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("参数错误:%s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cleanToken := model.Token{
|
||||
UserId: c.GetInt(ctxkey.Id),
|
||||
Name: token.Name,
|
||||
Key: random.GenerateKey(),
|
||||
CreatedTime: helper.GetTimestamp(),
|
||||
AccessedTime: helper.GetTimestamp(),
|
||||
ExpiredTime: token.ExpiredTime,
|
||||
RemainQuota: token.RemainQuota,
|
||||
UnlimitedQuota: token.UnlimitedQuota,
|
||||
Models: token.Models,
|
||||
Subnet: token.Subnet,
|
||||
}
|
||||
err = cleanToken.Insert()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": cleanToken,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteToken(c *gin.Context) {
|
||||
id, _ := strconv.Atoi(c.Param("id"))
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
err := model.DeleteTokenById(id, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateToken(c *gin.Context) {
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
statusOnly := c.Query("status_only")
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
err = validateToken(c, token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("参数错误:%s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
cleanToken, err := model.GetTokenByIds(token.Id, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if token.Status == model.TokenStatusEnabled {
|
||||
if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
|
||||
})
|
||||
return
|
||||
}
|
||||
if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if statusOnly != "" {
|
||||
cleanToken.Status = token.Status
|
||||
} else {
|
||||
// If you add more fields, please also update token.Update()
|
||||
cleanToken.Name = token.Name
|
||||
cleanToken.ExpiredTime = token.ExpiredTime
|
||||
cleanToken.RemainQuota = token.RemainQuota
|
||||
cleanToken.UnlimitedQuota = token.UnlimitedQuota
|
||||
cleanToken.Models = token.Models
|
||||
cleanToken.Subnet = token.Subnet
|
||||
}
|
||||
err = cleanToken.Update()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": cleanToken,
|
||||
})
|
||||
return
|
||||
}
|
||||
816
controller/user.go
Normal file
816
controller/user.go
Normal file
@@ -0,0 +1,816 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/i18n"
|
||||
"github.com/songquanpeng/one-api/common/random"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
)
|
||||
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func Login(c *gin.Context) {
|
||||
if !config.PasswordLoginEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "管理员关闭了密码登录",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
var loginRequest LoginRequest
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&loginRequest)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": i18n.Translate(c, "invalid_parameter"),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
username := loginRequest.Username
|
||||
password := loginRequest.Password
|
||||
if username == "" || password == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": i18n.Translate(c, "invalid_parameter"),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
err = user.ValidateAndFill()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
SetupLogin(&user, c)
|
||||
}
|
||||
|
||||
// setup session & cookies and then return user info
|
||||
func SetupLogin(user *model.User, c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
session.Set("id", user.Id)
|
||||
session.Set("username", user.Username)
|
||||
session.Set("role", user.Role)
|
||||
session.Set("status", user.Status)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "无法保存会话信息,请重试",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
cleanUser := model.User{
|
||||
Id: user.Id,
|
||||
Username: user.Username,
|
||||
DisplayName: user.DisplayName,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "",
|
||||
"success": true,
|
||||
"data": cleanUser,
|
||||
})
|
||||
}
|
||||
|
||||
func Logout(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
session.Clear()
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "",
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
func Register(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
if !config.RegisterEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "管理员关闭了新用户注册",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
if !config.PasswordRegisterEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
var user model.User
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.Translate(c, "invalid_parameter"),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := common.Validate.Struct(&user); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.Translate(c, "invalid_input"),
|
||||
})
|
||||
return
|
||||
}
|
||||
if config.EmailVerificationEnabled {
|
||||
if user.Email == "" || user.VerificationCode == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员开启了邮箱验证,请输入邮箱地址和验证码",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !common.VerifyCodeWithKey(user.Email, user.VerificationCode, common.EmailVerificationPurpose) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码错误或已过期",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
affCode := user.AffCode // this code is the inviter's code, not the user's own code
|
||||
inviterId, _ := model.GetUserIdByAffCode(affCode)
|
||||
cleanUser := model.User{
|
||||
Username: user.Username,
|
||||
Password: user.Password,
|
||||
DisplayName: user.Username,
|
||||
InviterId: inviterId,
|
||||
}
|
||||
if config.EmailVerificationEnabled {
|
||||
cleanUser.Email = user.Email
|
||||
}
|
||||
if err := cleanUser.Insert(ctx, inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetAllUsers(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
|
||||
order := c.DefaultQuery("order", "")
|
||||
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": users,
|
||||
})
|
||||
}
|
||||
|
||||
func SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
users, err := model.SearchUsers(keyword)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": users,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt(ctxkey.Role)
|
||||
if myRole <= user.Role && myRole != model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权获取同级或更高等级用户的信息",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": user,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetUserDashboard(c *gin.Context) {
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
now := time.Now()
|
||||
startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix()
|
||||
endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix()
|
||||
|
||||
dashboards, err := model.SearchLogsByDayAndModel(id, int(startOfDay), int(endOfDay))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法获取统计信息",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": dashboards,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GenerateAccessToken(c *gin.Context) {
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.AccessToken = random.GetUUID()
|
||||
|
||||
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "请重试,系统生成的 UUID 竟然重复了!",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": user.AccessToken,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetAffCode(c *gin.Context) {
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.AffCode == "" {
|
||||
user.AffCode = random.GetRandomString(4)
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": user.AffCode,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetSelf(c *gin.Context) {
|
||||
id := c.GetInt(ctxkey.Id)
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": user,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateUser(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
var updatedUser model.User
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&updatedUser)
|
||||
if err != nil || updatedUser.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.Translate(c, "invalid_parameter"),
|
||||
})
|
||||
return
|
||||
}
|
||||
if updatedUser.Password == "" {
|
||||
updatedUser.Password = "$I_LOVE_U" // make Validator happy :)
|
||||
}
|
||||
if err := common.Validate.Struct(&updatedUser); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.Translate(c, "invalid_input"),
|
||||
})
|
||||
return
|
||||
}
|
||||
originUser, err := model.GetUserById(updatedUser.Id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt(ctxkey.Role)
|
||||
if myRole <= originUser.Role && myRole != model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权更新同权限等级或更高权限等级的用户信息",
|
||||
})
|
||||
return
|
||||
}
|
||||
if myRole <= updatedUser.Role && myRole != model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
|
||||
})
|
||||
return
|
||||
}
|
||||
if updatedUser.Password == "$I_LOVE_U" {
|
||||
updatedUser.Password = "" // rollback to what it should be
|
||||
}
|
||||
updatePassword := updatedUser.Password != ""
|
||||
if err := updatedUser.Update(updatePassword); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if originUser.Quota != updatedUser.Quota {
|
||||
model.RecordLog(ctx, originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateSelf(c *gin.Context) {
|
||||
var user model.User
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.Translate(c, "invalid_parameter"),
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Password == "" {
|
||||
user.Password = "$I_LOVE_U" // make Validator happy :)
|
||||
}
|
||||
if err := common.Validate.Struct(&user); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "输入不合法 " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cleanUser := model.User{
|
||||
Id: c.GetInt(ctxkey.Id),
|
||||
Username: user.Username,
|
||||
Password: user.Password,
|
||||
DisplayName: user.DisplayName,
|
||||
}
|
||||
if user.Password == "$I_LOVE_U" {
|
||||
user.Password = "" // rollback to what it should be
|
||||
cleanUser.Password = ""
|
||||
}
|
||||
updatePassword := user.Password != ""
|
||||
if err := cleanUser.Update(updatePassword); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
originUser, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= originUser.Role {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权删除同权限等级或更高权限等级的用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
err = model.DeleteUserById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteSelf(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
|
||||
if user.Role == model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "不能删除超级管理员账户",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err := model.DeleteUserById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func CreateUser(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
var user model.User
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&user)
|
||||
if err != nil || user.Username == "" || user.Password == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.Translate(c, "invalid_parameter"),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := common.Validate.Struct(&user); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.Translate(c, "invalid_input"),
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.DisplayName == "" {
|
||||
user.DisplayName = user.Username
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
if user.Role >= myRole {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法创建权限大于等于自己的用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Even for admin users, we cannot fully trust them!
|
||||
cleanUser := model.User{
|
||||
Username: user.Username,
|
||||
Password: user.Password,
|
||||
DisplayName: user.DisplayName,
|
||||
}
|
||||
if err := cleanUser.Insert(ctx, 0); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type ManageRequest struct {
|
||||
Username string `json:"username"`
|
||||
Action string `json:"action"`
|
||||
}
|
||||
|
||||
// ManageUser Only admin user can do this
|
||||
func ManageUser(c *gin.Context) {
|
||||
var req ManageRequest
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&req)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": i18n.Translate(c, "invalid_parameter"),
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
Username: req.Username,
|
||||
}
|
||||
// Fill attributes
|
||||
model.DB.Where(&user).First(&user)
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= user.Role && myRole != model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权更新同权限等级或更高权限等级的用户信息",
|
||||
})
|
||||
return
|
||||
}
|
||||
switch req.Action {
|
||||
case "disable":
|
||||
user.Status = model.UserStatusDisabled
|
||||
if user.Role == model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法禁用超级管理员用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "enable":
|
||||
user.Status = model.UserStatusEnabled
|
||||
case "delete":
|
||||
if user.Role == model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法删除超级管理员用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := user.Delete(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "promote":
|
||||
if myRole != model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "普通管理员用户无法提升其他用户为管理员",
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Role >= model.RoleAdminUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该用户已经是管理员",
|
||||
})
|
||||
return
|
||||
}
|
||||
user.Role = model.RoleAdminUser
|
||||
case "demote":
|
||||
if user.Role == model.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法降级超级管理员用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Role == model.RoleCommonUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该用户已经是普通用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
user.Role = model.RoleCommonUser
|
||||
}
|
||||
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
clearUser := model.User{
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": clearUser,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func EmailBind(c *gin.Context) {
|
||||
email := c.Query("email")
|
||||
code := c.Query("code")
|
||||
if !common.VerifyCodeWithKey(email, code, common.EmailVerificationPurpose) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码错误或已过期",
|
||||
})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
user := model.User{
|
||||
Id: id,
|
||||
}
|
||||
err := user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.Email = email
|
||||
// no need to check if this email already taken, because we have used verification code to check it
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Role == model.RoleRootUser {
|
||||
config.RootUserEmail = email
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type topUpRequest struct {
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
func TopUp(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
req := topUpRequest{}
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
quota, err := model.Redeem(ctx, req.Key, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": quota,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type adminTopUpRequest struct {
|
||||
UserId int `json:"user_id"`
|
||||
Quota int `json:"quota"`
|
||||
Remark string `json:"remark"`
|
||||
}
|
||||
|
||||
func AdminTopUp(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
req := adminTopUpRequest{}
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
err = model.IncreaseUserQuota(req.UserId, int64(req.Quota))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.Remark == "" {
|
||||
req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
|
||||
}
|
||||
model.RecordTopupLog(ctx, req.UserId, req.Remark, req.Quota)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user