first commit: one-api base code + SAAS plan document
This commit is contained in:
167
middleware/auth.go
Normal file
167
middleware/auth.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/blacklist"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/network"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func authHelper(c *gin.Context, minRole int) {
|
||||
session := sessions.Default(c)
|
||||
username := session.Get("username")
|
||||
role := session.Get("role")
|
||||
id := session.Get("id")
|
||||
status := session.Get("status")
|
||||
if username == nil {
|
||||
// Check access token
|
||||
accessToken := c.Request.Header.Get("Authorization")
|
||||
if accessToken == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "无权进行此操作,未登录且未提供 access token",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
user := model.ValidateAccessToken(accessToken)
|
||||
if user != nil && user.Username != "" {
|
||||
// Token is valid
|
||||
username = user.Username
|
||||
role = user.Role
|
||||
id = user.Id
|
||||
status = user.Status
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权进行此操作,access token 无效",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已被封禁",
|
||||
})
|
||||
session := sessions.Default(c)
|
||||
session.Clear()
|
||||
_ = session.Save()
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if role.(int) < minRole {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权进行此操作,权限不足",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Set("username", username)
|
||||
c.Set("role", role)
|
||||
c.Set("id", id)
|
||||
c.Next()
|
||||
}
|
||||
|
||||
func UserAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
authHelper(c, model.RoleCommonUser)
|
||||
}
|
||||
}
|
||||
|
||||
func AdminAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
authHelper(c, model.RoleAdminUser)
|
||||
}
|
||||
}
|
||||
|
||||
func RootAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
authHelper(c, model.RoleRootUser)
|
||||
}
|
||||
}
|
||||
|
||||
func TokenAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
key := c.Request.Header.Get("Authorization")
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
key = strings.TrimPrefix(key, "sk-")
|
||||
parts := strings.Split(key, "-")
|
||||
key = parts[0]
|
||||
token, err := model.ValidateUserToken(key)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
if token.Subnet != nil && *token.Subnet != "" {
|
||||
if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
|
||||
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP()))
|
||||
return
|
||||
}
|
||||
}
|
||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if !userEnabled || blacklist.IsUserBanned(token.UserId) {
|
||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
return
|
||||
}
|
||||
requestModel, err := getRequestModel(c)
|
||||
if err != nil && shouldCheckModel(c) {
|
||||
abortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
c.Set(ctxkey.RequestModel, requestModel)
|
||||
if token.Models != nil && *token.Models != "" {
|
||||
c.Set(ctxkey.AvailableModels, *token.Models)
|
||||
if requestModel != "" && !isModelInList(requestModel, *token.Models) {
|
||||
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Set(ctxkey.Id, token.UserId)
|
||||
c.Set(ctxkey.TokenId, token.Id)
|
||||
c.Set(ctxkey.TokenName, token.Name)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set(ctxkey.SpecificChannelId, parts[1])
|
||||
} else {
|
||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// set channel id for proxy relay
|
||||
if channelId := c.Param("channelid"); channelId != "" {
|
||||
c.Set(ctxkey.SpecificChannelId, channelId)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func shouldCheckModel(c *gin.Context) bool {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
16
middleware/cache.go
Normal file
16
middleware/cache.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Cache() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.RequestURI == "/" {
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
} else {
|
||||
c.Header("Cache-Control", "max-age=604800") // one week
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
15
middleware/cors.go
Normal file
15
middleware/cors.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CORS() gin.HandlerFunc {
|
||||
config := cors.DefaultConfig()
|
||||
config.AllowAllOrigins = true
|
||||
config.AllowCredentials = true
|
||||
config.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
|
||||
config.AllowHeaders = []string{"*"}
|
||||
return cors.New(config)
|
||||
}
|
||||
102
middleware/distributor.go
Normal file
102
middleware/distributor.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
)
|
||||
|
||||
type ModelRequest struct {
|
||||
Model string `json:"model" form:"model"`
|
||||
}
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
c.Set(ctxkey.Group, userGroup)
|
||||
var requestModel string
|
||||
var channel *model.Channel
|
||||
channelId, ok := c.Get(ctxkey.SpecificChannelId)
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
return
|
||||
}
|
||||
channel, err = model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
return
|
||||
}
|
||||
if channel.Status != model.ChannelStatusEnabled {
|
||||
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
requestModel = c.GetString(ctxkey.RequestModel)
|
||||
var err error
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
|
||||
if channel != nil {
|
||||
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
}
|
||||
abortWithMessage(c, http.StatusServiceUnavailable, message)
|
||||
return
|
||||
}
|
||||
}
|
||||
logger.Debugf(ctx, "user id %d, user group: %s, request model: %s, using channel #%d", userId, userGroup, requestModel, channel.Id)
|
||||
SetupContextForSelectedChannel(c, channel, requestModel)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
||||
c.Set(ctxkey.Channel, channel.Type)
|
||||
c.Set(ctxkey.ChannelId, channel.Id)
|
||||
c.Set(ctxkey.ChannelName, channel.Name)
|
||||
if channel.SystemPrompt != nil && *channel.SystemPrompt != "" {
|
||||
c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt)
|
||||
}
|
||||
c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
|
||||
c.Set(ctxkey.OriginalModel, modelName) // for retry
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
|
||||
cfg, _ := channel.LoadConfig()
|
||||
// this is for backward compatibility
|
||||
if channel.Other != nil {
|
||||
switch channel.Type {
|
||||
case channeltype.Azure:
|
||||
if cfg.APIVersion == "" {
|
||||
cfg.APIVersion = *channel.Other
|
||||
}
|
||||
case channeltype.Xunfei:
|
||||
if cfg.APIVersion == "" {
|
||||
cfg.APIVersion = *channel.Other
|
||||
}
|
||||
case channeltype.Gemini:
|
||||
if cfg.APIVersion == "" {
|
||||
cfg.APIVersion = *channel.Other
|
||||
}
|
||||
case channeltype.AIProxyLibrary:
|
||||
if cfg.LibraryID == "" {
|
||||
cfg.LibraryID = *channel.Other
|
||||
}
|
||||
case channeltype.Ali:
|
||||
if cfg.Plugin == "" {
|
||||
cfg.Plugin = *channel.Other
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Set(ctxkey.Config, cfg)
|
||||
}
|
||||
27
middleware/gzip.go
Normal file
27
middleware/gzip.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func GzipDecodeMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.GetHeader("Content-Encoding") == "gzip" {
|
||||
gzipReader, err := gzip.NewReader(c.Request.Body)
|
||||
if err != nil {
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer gzipReader.Close()
|
||||
|
||||
// Replace the request body with the decompressed data
|
||||
c.Request.Body = io.NopCloser(gzipReader)
|
||||
}
|
||||
|
||||
// Continue processing the request
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
25
middleware/language.go
Normal file
25
middleware/language.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/i18n"
|
||||
)
|
||||
|
||||
func Language() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
lang := c.GetHeader("Accept-Language")
|
||||
if lang == "" {
|
||||
lang = "en"
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(lang), "zh") {
|
||||
lang = "zh-CN"
|
||||
} else {
|
||||
lang = "en"
|
||||
}
|
||||
c.Set(i18n.ContextKey, lang)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
25
middleware/logger.go
Normal file
25
middleware/logger.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
)
|
||||
|
||||
func SetUpLogger(server *gin.Engine) {
|
||||
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
var requestID string
|
||||
if param.Keys != nil {
|
||||
requestID = param.Keys[helper.RequestIdKey].(string)
|
||||
}
|
||||
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
|
||||
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||
requestID,
|
||||
param.StatusCode,
|
||||
param.Latency,
|
||||
param.ClientIP,
|
||||
param.Method,
|
||||
param.Path,
|
||||
)
|
||||
}))
|
||||
}
|
||||
111
middleware/rate-limit.go
Normal file
111
middleware/rate-limit.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
)
|
||||
|
||||
var timeFormat = "2006-01-02T15:04:05.000Z"
|
||||
|
||||
var inMemoryRateLimiter common.InMemoryRateLimiter
|
||||
|
||||
func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
|
||||
ctx := context.Background()
|
||||
rdb := common.RDB
|
||||
key := "rateLimit:" + mark + c.ClientIP()
|
||||
listLength, err := rdb.LLen(ctx, key).Result()
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if listLength < int64(maxRequestNum) {
|
||||
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
||||
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
|
||||
} else {
|
||||
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
|
||||
oldTime, err := time.Parse(timeFormat, oldTimeStr)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
nowTimeStr := time.Now().Format(timeFormat)
|
||||
nowTime, err := time.Parse(timeFormat, nowTimeStr)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// time.Since will return negative number!
|
||||
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
|
||||
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
|
||||
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
|
||||
c.Status(http.StatusTooManyRequests)
|
||||
c.Abort()
|
||||
return
|
||||
} else {
|
||||
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
||||
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
|
||||
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
|
||||
key := mark + c.ClientIP()
|
||||
if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
|
||||
c.Status(http.StatusTooManyRequests)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
|
||||
if maxRequestNum == 0 || config.DebugEnabled {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
if common.RedisEnabled {
|
||||
return func(c *gin.Context) {
|
||||
redisRateLimiter(c, maxRequestNum, duration, mark)
|
||||
}
|
||||
} else {
|
||||
// It's safe to call multi times.
|
||||
inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration)
|
||||
return func(c *gin.Context) {
|
||||
memoryRateLimiter(c, maxRequestNum, duration, mark)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GlobalWebRateLimit() func(c *gin.Context) {
|
||||
return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW")
|
||||
}
|
||||
|
||||
func GlobalAPIRateLimit() func(c *gin.Context) {
|
||||
return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA")
|
||||
}
|
||||
|
||||
func CriticalRateLimit() func(c *gin.Context) {
|
||||
return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT")
|
||||
}
|
||||
|
||||
func DownloadRateLimit() func(c *gin.Context) {
|
||||
return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW")
|
||||
}
|
||||
|
||||
func UploadRateLimit() func(c *gin.Context) {
|
||||
return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP")
|
||||
}
|
||||
33
middleware/recover.go
Normal file
33
middleware/recover.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
func RelayPanicRecover() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
ctx := c.Request.Context()
|
||||
logger.Errorf(ctx, fmt.Sprintf("panic detected: %v", err))
|
||||
logger.Errorf(ctx, fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
||||
logger.Errorf(ctx, fmt.Sprintf("request: %s %s", c.Request.Method, c.Request.URL.Path))
|
||||
body, _ := common.GetRequestBody(c)
|
||||
logger.Errorf(ctx, fmt.Sprintf("request body: %s", string(body)))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit an issue with the related log here: https://github.com/songquanpeng/one-api", err),
|
||||
"type": "one_api_panic",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
18
middleware/request-id.go
Normal file
18
middleware/request-id.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
)
|
||||
|
||||
func RequestId() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
id := helper.GenRequestID()
|
||||
c.Set(helper.RequestIdKey, id)
|
||||
ctx := helper.SetRequestID(c.Request.Context(), id)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Header(helper.RequestIdKey, id)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
81
middleware/turnstile-check.go
Normal file
81
middleware/turnstile-check.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type turnstileCheckResponse struct {
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
func TurnstileCheck() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if config.TurnstileCheckEnabled {
|
||||
session := sessions.Default(c)
|
||||
turnstileChecked := session.Get("turnstile")
|
||||
if turnstileChecked != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
response := c.Query("turnstile")
|
||||
if response == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Turnstile token 为空",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
|
||||
"secret": {config.TurnstileSecretKey},
|
||||
"response": {response},
|
||||
"remoteip": {c.ClientIP()},
|
||||
})
|
||||
if err != nil {
|
||||
logger.SysError(err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer rawRes.Body.Close()
|
||||
var res turnstileCheckResponse
|
||||
err = json.NewDecoder(rawRes.Body).Decode(&res)
|
||||
if err != nil {
|
||||
logger.SysError(err.Error())
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if !res.Success {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Turnstile 校验失败,请刷新重试!",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
session.Set("turnstile", true)
|
||||
err = session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "无法保存会话信息,请重试",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
60
middleware/utils.go
Normal file
60
middleware/utils.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
c.JSON(statusCode, gin.H{
|
||||
"error": gin.H{
|
||||
"message": helper.MessageWithRequestId(message, c.GetString(helper.RequestIdKey)),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
logger.Error(c.Request.Context(), message)
|
||||
}
|
||||
|
||||
func getRequestModel(c *gin.Context) (string, error) {
|
||||
var modelRequest ModelRequest
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e-2"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
return modelRequest.Model, nil
|
||||
}
|
||||
|
||||
func isModelInList(modelName string, models string) bool {
|
||||
modelList := strings.Split(models, ",")
|
||||
for _, model := range modelList {
|
||||
if modelName == model {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user