refactor: adjust model matching

Xyfacai
2025-08-06 20:09:22 +08:00
parent 4445e5891f
commit 0c0caad827
5 changed files with 66 additions and 64 deletions

View File

@@ -11,7 +11,6 @@ const (
ContextKeyTokenKey ContextKey = "token_key"
ContextKeyTokenId ContextKey = "token_id"
ContextKeyTokenGroup ContextKey = "token_group"
ContextKeyTokenAllowIps ContextKey = "allow_ips"
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
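Note: ContextKeyTokenAllowIps is removed because TokenAuth (next file) now enforces the IP allow-list directly via token.GetIpLimitsMap() instead of stashing the map in the request context.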

View File

@@ -4,7 +4,10 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/setting"
"one-api/setting/ratio_setting"
"strconv"
"strings"
@@ -234,6 +237,16 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
allowIpsMap := token.GetIpLimitsMap()
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
return
}
}
userCache, err := model.GetUserCache(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
@@ -247,6 +260,25 @@ func TokenAuth() func(c *gin.Context) {
userCache.WriteContext(c)
userGroup := userCache.Group
tokenGroup := token.Group
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
if tokenGroup != "auto" {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
}
userGroup = tokenGroup
}
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
err = SetupContextForToken(c, token, parts...)
if err != nil {
return
@@ -273,7 +305,6 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
} else {
c.Set("token_model_limit_enabled", false)
}
c.Set("allow_ips", token.GetIpLimitsMap())
c.Set("token_group", token.Group)
if len(parts) > 1 {
if model.IsAdmin(token.UserId) {

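This file moves two gates from the Distribute middleware into TokenAuth itself: the IP allow-list check and the token-group validation (usable-group membership plus a group-ratio existence check, with "auto" exempt), publishing the resolved group via ContextKeyUsingGroup. A minimal sketch of the allow-list semantics, assuming token.GetIpLimitsMap() returns a map keyed by client IP (the element type below is a placeholder; only the key set matters):

package main

import "fmt"

// ipAllowed mirrors the guard above: an empty map means the token carries
// no IP restriction; otherwise the client IP must appear as a key.
func ipAllowed(allowIps map[string]any, clientIp string) bool {
	if len(allowIps) == 0 {
		return true // no allow-list configured on the token
	}
	_, ok := allowIps[clientIp]
	return ok
}

func main() {
	allow := map[string]any{"10.0.0.8": nil}
	fmt.Println(ipAllowed(allow, "10.0.0.8"))    // true
	fmt.Println(ipAllowed(allow, "203.0.113.5")) // false
	fmt.Println(ipAllowed(nil, "203.0.113.5"))   // true: no list means unrestricted
}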
View File

@@ -10,7 +10,6 @@ import (
"one-api/model"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"one-api/setting/ratio_setting"
"one-api/types"
"strconv"
@@ -27,14 +26,6 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
if len(allowIpsMap) != 0 {
clientIp := c.ClientIP()
if _, ok := allowIpsMap[clientIp]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
return
}
}
var channel *model.Channel
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
modelRequest, shouldSelectChannel, err := getModelRequest(c)
@@ -42,24 +33,6 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
if tokenGroup != "auto" {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
}
userGroup = tokenGroup
}
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
@@ -81,22 +54,21 @@ func Distribute() func(c *gin.Context) {
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
if modelLimitEnable {
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
} else {
tokenModelLimit = map[string]bool{}
}
if tokenModelLimit != nil {
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
} else {
if !ok {
// token model limit is empty; no models are allowed
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
return
}
var tokenModelLimit map[string]bool
tokenModelLimit, ok = s.(map[string]bool)
if !ok {
tokenModelLimit = map[string]bool{}
}
matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-*
if _, ok := tokenModelLimit[matchName]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
}
if shouldSelectChannel {
@@ -105,6 +77,7 @@ func Distribute() func(c *gin.Context) {
return
}
var selectGroup string
userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
if err != nil {
showGroup := userGroup

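The rewritten model-limit check first aborts when no limit map was stored at all, then looks up the normalized model name rather than the raw requested one, so a single wildcard entry such as gpt-4-gizmo-* now covers every concrete gizmo model. A sketch of that lookup, using a simplified stand-in for ratio_setting.FormatMatchingModelName (the real helper, added in the last file of this commit, also collapses Gemini thinking-budget variants):

package main

import (
	"fmt"
	"strings"
)

// formatMatchingModelName is a stand-in covering only the gizmo prefixes.
func formatMatchingModelName(name string) string {
	if strings.HasPrefix(name, "gpt-4-gizmo") {
		return "gpt-4-gizmo-*"
	}
	if strings.HasPrefix(name, "gpt-4o-gizmo") {
		return "gpt-4o-gizmo-*"
	}
	return name
}

func main() {
	// A token whose model limit lists only the wildcard entry.
	tokenModelLimit := map[string]bool{"gpt-4-gizmo-*": true}

	for _, requested := range []string{"gpt-4-gizmo-g-abc", "gpt-4o-mini"} {
		matchName := formatMatchingModelName(requested)
		_, allowed := tokenModelLimit[matchName]
		fmt.Printf("%s -> %s allowed=%v\n", requested, matchName, allowed)
	}
}

Since the lookup uses only the normalized name, limit entries for the affected model families have to be stored in their wildcard form.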
View File

@@ -7,6 +7,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/setting"
"one-api/setting/ratio_setting"
"sort"
"strings"
"sync"
@@ -128,12 +129,7 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string,
}
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*"
}
if strings.HasPrefix(model, "gpt-4o-gizmo") {
model = "gpt-4o-gizmo-*"
}
model = ratio_setting.FormatMatchingModelName(model)
// if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled {

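With getRandomSatisfiedChannel delegating to the same ratio_setting.FormatMatchingModelName helper, channel selection and the price/ratio lookups in the next file normalize model names through one code path and can no longer drift apart.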
View File

@@ -335,12 +335,8 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
if strings.HasPrefix(name, "gpt-4o-gizmo") {
name = "gpt-4o-gizmo-*"
}
name = FormatMatchingModelName(name)
price, ok := modelPriceMap[name]
if !ok {
if printErr {
@@ -374,11 +370,8 @@ func GetModelRatio(name string) (float64, bool, string) {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*")
name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
name = FormatMatchingModelName(name)
ratio, ok := modelRatioMap[name]
if !ok {
return 37.5, operation_setting.SelfUseModeEnabled, name
@@ -429,12 +422,9 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
func GetCompletionRatio(name string) float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
if strings.HasPrefix(name, "gpt-4o-gizmo") {
name = "gpt-4o-gizmo-*"
}
name = FormatMatchingModelName(name)
if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
@@ -664,3 +654,16 @@ func GetCompletionRatioCopy() map[string]float64 {
}
return copyMap
}
// Normalize the model name so channels don't have to configure every parameterized model variant
func FormatMatchingModelName(name string) string {
name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*")
name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
if strings.HasPrefix(name, "gpt-4o-gizmo") {
name = "gpt-4o-gizmo-*"
}
return name
}
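
A hypothetical table test sketching the helper's intended mapping; the Gemini thinking-budget cases are omitted because handleThinkingBudgetModel is defined outside this diff, and the pass-through row assumes that helper leaves non-Gemini names untouched:

package ratio_setting

import "testing"

func TestFormatMatchingModelName(t *testing.T) {
	cases := map[string]string{
		"gpt-4-gizmo-g-abc123": "gpt-4-gizmo-*",
		"gpt-4o-gizmo-g-xyz":   "gpt-4o-gizmo-*",
		"gpt-4o-mini":          "gpt-4o-mini", // assumed pass-through
	}
	for in, want := range cases {
		if got := FormatMatchingModelName(in); got != want {
			t.Errorf("FormatMatchingModelName(%q) = %q, want %q", in, got, want)
		}
	}
}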