Merge branch 'QuantumNous:main' into main
This commit is contained in:
@@ -2,12 +2,14 @@ package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
@@ -216,10 +218,14 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
key := c.Request.Header.Get("Authorization")
|
||||
parts := make([]string, 0)
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") {
|
||||
key = strings.TrimSpace(key[7:])
|
||||
}
|
||||
if key == "" || key == "midjourney-proxy" {
|
||||
key = c.Request.Header.Get("mj-api-secret")
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") {
|
||||
key = strings.TrimSpace(key[7:])
|
||||
}
|
||||
key = strings.TrimPrefix(key, "sk-")
|
||||
parts = strings.Split(key, "-")
|
||||
key = parts[0]
|
||||
@@ -240,13 +246,20 @@ func TokenAuth() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
allowIpsMap := token.GetIpLimitsMap()
|
||||
if len(allowIpsMap) != 0 {
|
||||
allowIps := token.GetIpLimits()
|
||||
if len(allowIps) > 0 {
|
||||
clientIp := c.ClientIP()
|
||||
if _, ok := allowIpsMap[clientIp]; !ok {
|
||||
logger.LogDebug(c, "Token has IP restrictions, checking client IP %s", clientIp)
|
||||
ip := net.ParseIP(clientIp)
|
||||
if ip == nil {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "无法解析客户端 IP 地址")
|
||||
return
|
||||
}
|
||||
if common.IsIpInCIDRList(ip, allowIps) == false {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
|
||||
return
|
||||
}
|
||||
logger.LogDebug(c, "Client IP %s passed the token IP restrictions check", clientIp)
|
||||
}
|
||||
|
||||
userCache, err := model.GetUserCache(token.UserId)
|
||||
@@ -307,7 +320,8 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
c.Set("token_group", token.Group)
|
||||
common.SetContextKey(c, constant.ContextKeyTokenGroup, token.Group)
|
||||
common.SetContextKey(c, constant.ContextKeyTokenCrossGroupRetry, token.CrossGroupRetry)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
|
||||
@@ -97,7 +97,12 @@ func Distribute() func(c *gin.Context) {
|
||||
common.SetContextKey(c, constant.ContextKeyUsingGroup, usingGroup)
|
||||
}
|
||||
}
|
||||
channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(c, usingGroup, modelRequest.Model, 0)
|
||||
channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(&service.RetryParam{
|
||||
Ctx: c,
|
||||
ModelName: modelRequest.Model,
|
||||
TokenGroup: usingGroup,
|
||||
Retry: common.GetPointer(0),
|
||||
})
|
||||
if err != nil {
|
||||
showGroup := usingGroup
|
||||
if usingGroup == "auto" {
|
||||
@@ -157,7 +162,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
|
||||
if mjErr != nil {
|
||||
return nil, false, fmt.Errorf(mjErr.Description)
|
||||
return nil, false, fmt.Errorf("%s", mjErr.Description)
|
||||
}
|
||||
if midjourneyModel == "" {
|
||||
if !success {
|
||||
|
||||
@@ -5,32 +5,69 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type readCloser struct {
|
||||
io.Reader
|
||||
closeFn func() error
|
||||
}
|
||||
|
||||
func (rc *readCloser) Close() error {
|
||||
if rc.closeFn != nil {
|
||||
return rc.closeFn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DecompressRequestMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.Body == nil || c.Request.Method == http.MethodGet {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
maxMB := constant.MaxRequestBodyMB
|
||||
if maxMB <= 0 {
|
||||
maxMB = 32
|
||||
}
|
||||
maxBytes := int64(maxMB) << 20
|
||||
|
||||
origBody := c.Request.Body
|
||||
wrapMaxBytes := func(body io.ReadCloser) io.ReadCloser {
|
||||
return http.MaxBytesReader(c.Writer, body, maxBytes)
|
||||
}
|
||||
|
||||
switch c.GetHeader("Content-Encoding") {
|
||||
case "gzip":
|
||||
gzipReader, err := gzip.NewReader(c.Request.Body)
|
||||
gzipReader, err := gzip.NewReader(origBody)
|
||||
if err != nil {
|
||||
_ = origBody.Close()
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer gzipReader.Close()
|
||||
|
||||
// Replace the request body with the decompressed data
|
||||
c.Request.Body = io.NopCloser(gzipReader)
|
||||
// Replace the request body with the decompressed data, and enforce a max size (post-decompression).
|
||||
c.Request.Body = wrapMaxBytes(&readCloser{
|
||||
Reader: gzipReader,
|
||||
closeFn: func() error {
|
||||
_ = gzipReader.Close()
|
||||
return origBody.Close()
|
||||
},
|
||||
})
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
case "br":
|
||||
reader := brotli.NewReader(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(reader)
|
||||
reader := brotli.NewReader(origBody)
|
||||
c.Request.Body = wrapMaxBytes(&readCloser{
|
||||
Reader: reader,
|
||||
closeFn: func() error {
|
||||
return origBody.Close()
|
||||
},
|
||||
})
|
||||
c.Request.Header.Del("Content-Encoding")
|
||||
default:
|
||||
// Even for uncompressed bodies, enforce a max size to avoid huge request allocations.
|
||||
c.Request.Body = wrapMaxBytes(origBody)
|
||||
}
|
||||
|
||||
// Continue processing the request
|
||||
|
||||
Reference in New Issue
Block a user