refactor: 调整模型匹配
This commit is contained in:
@@ -11,7 +11,6 @@ const (
|
|||||||
ContextKeyTokenKey ContextKey = "token_key"
|
ContextKeyTokenKey ContextKey = "token_key"
|
||||||
ContextKeyTokenId ContextKey = "token_id"
|
ContextKeyTokenId ContextKey = "token_id"
|
||||||
ContextKeyTokenGroup ContextKey = "token_group"
|
ContextKeyTokenGroup ContextKey = "token_group"
|
||||||
ContextKeyTokenAllowIps ContextKey = "allow_ips"
|
|
||||||
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
||||||
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
||||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||||
|
|||||||
@@ -4,7 +4,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
"one-api/setting"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -234,6 +237,16 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allowIpsMap := token.GetIpLimitsMap()
|
||||||
|
if len(allowIpsMap) != 0 {
|
||||||
|
clientIp := c.ClientIP()
|
||||||
|
if _, ok := allowIpsMap[clientIp]; !ok {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
userCache, err := model.GetUserCache(token.UserId)
|
userCache, err := model.GetUserCache(token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||||
@@ -247,6 +260,25 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
|
|
||||||
userCache.WriteContext(c)
|
userCache.WriteContext(c)
|
||||||
|
|
||||||
|
userGroup := userCache.Group
|
||||||
|
tokenGroup := token.Group
|
||||||
|
if tokenGroup != "" {
|
||||||
|
// check common.UserUsableGroups[userGroup]
|
||||||
|
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// check group in common.GroupRatio
|
||||||
|
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
||||||
|
if tokenGroup != "auto" {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userGroup = tokenGroup
|
||||||
|
}
|
||||||
|
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
||||||
|
|
||||||
err = SetupContextForToken(c, token, parts...)
|
err = SetupContextForToken(c, token, parts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@@ -273,7 +305,6 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e
|
|||||||
} else {
|
} else {
|
||||||
c.Set("token_model_limit_enabled", false)
|
c.Set("token_model_limit_enabled", false)
|
||||||
}
|
}
|
||||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
|
||||||
c.Set("token_group", token.Group)
|
c.Set("token_group", token.Group)
|
||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"one-api/model"
|
"one-api/model"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
|
||||||
"one-api/setting/ratio_setting"
|
"one-api/setting/ratio_setting"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -27,14 +26,6 @@ type ModelRequest struct {
|
|||||||
|
|
||||||
func Distribute() func(c *gin.Context) {
|
func Distribute() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
|
|
||||||
if len(allowIpsMap) != 0 {
|
|
||||||
clientIp := c.ClientIP()
|
|
||||||
if _, ok := allowIpsMap[clientIp]; !ok {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var channel *model.Channel
|
var channel *model.Channel
|
||||||
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
|
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
|
||||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||||
@@ -42,24 +33,6 @@ func Distribute() func(c *gin.Context) {
|
|||||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
|
||||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
|
||||||
if tokenGroup != "" {
|
|
||||||
// check common.UserUsableGroups[userGroup]
|
|
||||||
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// check group in common.GroupRatio
|
|
||||||
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
|
||||||
if tokenGroup != "auto" {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
userGroup = tokenGroup
|
|
||||||
}
|
|
||||||
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
|
||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -81,22 +54,21 @@ func Distribute() func(c *gin.Context) {
|
|||||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||||
if modelLimitEnable {
|
if modelLimitEnable {
|
||||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||||
var tokenModelLimit map[string]bool
|
if !ok {
|
||||||
if ok {
|
|
||||||
tokenModelLimit = s.(map[string]bool)
|
|
||||||
} else {
|
|
||||||
tokenModelLimit = map[string]bool{}
|
|
||||||
}
|
|
||||||
if tokenModelLimit != nil {
|
|
||||||
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
|
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// token model limit is empty, all models are not allowed
|
// token model limit is empty, all models are not allowed
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
var tokenModelLimit map[string]bool
|
||||||
|
tokenModelLimit, ok = s.(map[string]bool)
|
||||||
|
if !ok {
|
||||||
|
tokenModelLimit = map[string]bool{}
|
||||||
|
}
|
||||||
|
matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-*
|
||||||
|
if _, ok := tokenModelLimit[matchName]; !ok {
|
||||||
|
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if shouldSelectChannel {
|
if shouldSelectChannel {
|
||||||
@@ -105,6 +77,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
var selectGroup string
|
var selectGroup string
|
||||||
|
userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
||||||
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
showGroup := userGroup
|
showGroup := userGroup
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -128,12 +129,7 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
model = ratio_setting.FormatMatchingModelName(model)
|
||||||
model = "gpt-4-gizmo-*"
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(model, "gpt-4o-gizmo") {
|
|
||||||
model = "gpt-4o-gizmo-*"
|
|
||||||
}
|
|
||||||
|
|
||||||
// if memory cache is disabled, get channel directly from database
|
// if memory cache is disabled, get channel directly from database
|
||||||
if !common.MemoryCacheEnabled {
|
if !common.MemoryCacheEnabled {
|
||||||
|
|||||||
@@ -335,12 +335,8 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
|
|||||||
modelPriceMapMutex.RLock()
|
modelPriceMapMutex.RLock()
|
||||||
defer modelPriceMapMutex.RUnlock()
|
defer modelPriceMapMutex.RUnlock()
|
||||||
|
|
||||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
name = FormatMatchingModelName(name)
|
||||||
name = "gpt-4-gizmo-*"
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(name, "gpt-4o-gizmo") {
|
|
||||||
name = "gpt-4o-gizmo-*"
|
|
||||||
}
|
|
||||||
price, ok := modelPriceMap[name]
|
price, ok := modelPriceMap[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
if printErr {
|
if printErr {
|
||||||
@@ -374,11 +370,8 @@ func GetModelRatio(name string) (float64, bool, string) {
|
|||||||
modelRatioMapMutex.RLock()
|
modelRatioMapMutex.RLock()
|
||||||
defer modelRatioMapMutex.RUnlock()
|
defer modelRatioMapMutex.RUnlock()
|
||||||
|
|
||||||
name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*")
|
name = FormatMatchingModelName(name)
|
||||||
name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
|
|
||||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
|
||||||
name = "gpt-4-gizmo-*"
|
|
||||||
}
|
|
||||||
ratio, ok := modelRatioMap[name]
|
ratio, ok := modelRatioMap[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
return 37.5, operation_setting.SelfUseModeEnabled, name
|
return 37.5, operation_setting.SelfUseModeEnabled, name
|
||||||
@@ -429,12 +422,9 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
|
|||||||
func GetCompletionRatio(name string) float64 {
|
func GetCompletionRatio(name string) float64 {
|
||||||
CompletionRatioMutex.RLock()
|
CompletionRatioMutex.RLock()
|
||||||
defer CompletionRatioMutex.RUnlock()
|
defer CompletionRatioMutex.RUnlock()
|
||||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
|
||||||
name = "gpt-4-gizmo-*"
|
name = FormatMatchingModelName(name)
|
||||||
}
|
|
||||||
if strings.HasPrefix(name, "gpt-4o-gizmo") {
|
|
||||||
name = "gpt-4o-gizmo-*"
|
|
||||||
}
|
|
||||||
if strings.Contains(name, "/") {
|
if strings.Contains(name, "/") {
|
||||||
if ratio, ok := CompletionRatio[name]; ok {
|
if ratio, ok := CompletionRatio[name]; ok {
|
||||||
return ratio
|
return ratio
|
||||||
@@ -664,3 +654,16 @@ func GetCompletionRatioCopy() map[string]float64 {
|
|||||||
}
|
}
|
||||||
return copyMap
|
return copyMap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 转换模型名,减少渠道必须配置各种带参数模型
|
||||||
|
func FormatMatchingModelName(name string) string {
|
||||||
|
name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*")
|
||||||
|
name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
|
||||||
|
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||||
|
name = "gpt-4-gizmo-*"
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(name, "gpt-4o-gizmo") {
|
||||||
|
name = "gpt-4o-gizmo-*"
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user