refactor: 调整模型匹配

This commit is contained in:
Xyfacai
2025-08-06 20:09:22 +08:00
parent 4445e5891f
commit 0c0caad827
5 changed files with 66 additions and 64 deletions

View File

@@ -4,7 +4,10 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/setting"
"one-api/setting/ratio_setting"
"strconv"
"strings"
@@ -234,6 +237,16 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
allowIpsMap := token.GetIpLimitsMap()
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
return
}
}
userCache, err := model.GetUserCache(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
@@ -247,6 +260,25 @@ func TokenAuth() func(c *gin.Context) {
userCache.WriteContext(c)
userGroup := userCache.Group
tokenGroup := token.Group
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
if tokenGroup != "auto" {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
}
userGroup = tokenGroup
}
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
err = SetupContextForToken(c, token, parts...)
if err != nil {
return
@@ -273,7 +305,6 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
} else {
c.Set("token_model_limit_enabled", false)
}
c.Set("allow_ips", token.GetIpLimitsMap())
c.Set("token_group", token.Group)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {

View File

@@ -10,7 +10,6 @@ import (
"one-api/model"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/ratio_setting"
"one-api/types"
"strconv"
@@ -27,14 +26,6 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
return
}
}
var channel *model.Channel
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -42,24 +33,6 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
if tokenGroup != "auto" {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
}
userGroup = tokenGroup
}
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
@@ -81,22 +54,21 @@ func Distribute() func(c *gin.Context) {
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
if modelLimitEnable {
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
} else {
tokenModelLimit = map[string]bool{}
}
if tokenModelLimit != nil {
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
} else {
if !ok {
// token model limit is empty, all models are not allowed
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
return
}
var tokenModelLimit map[string]bool
tokenModelLimit, ok = s.(map[string]bool)
if !ok {
tokenModelLimit = map[string]bool{}
}
matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-*
if _, ok := tokenModelLimit[matchName]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
}
if shouldSelectChannel {
@@ -105,6 +77,7 @@ func Distribute() func(c *gin.Context) {
return
}
var selectGroup string
userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
if err != nil {
showGroup := userGroup