diff --git a/constant/context_key.go b/constant/context_key.go index 4eaf3d00..32dd9617 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -11,7 +11,6 @@ const ( ContextKeyTokenKey ContextKey = "token_key" ContextKeyTokenId ContextKey = "token_id" ContextKeyTokenGroup ContextKey = "token_group" - ContextKeyTokenAllowIps ContextKey = "allow_ips" ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id" ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled" ContextKeyTokenModelLimit ContextKey = "token_model_limit" diff --git a/middleware/auth.go b/middleware/auth.go index 72900f83..5f6e5d43 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -4,7 +4,10 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/constant" "one-api/model" + "one-api/setting" + "one-api/setting/ratio_setting" "strconv" "strings" @@ -234,6 +237,16 @@ func TokenAuth() func(c *gin.Context) { abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) return } + + allowIpsMap := token.GetIpLimitsMap() + if len(allowIpsMap) != 0 { + clientIp := c.ClientIP() + if _, ok := allowIpsMap[clientIp]; !ok { + abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中") + return + } + } + userCache, err := model.GetUserCache(token.UserId) if err != nil { abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error()) @@ -247,6 +260,25 @@ func TokenAuth() func(c *gin.Context) { userCache.WriteContext(c) + userGroup := userCache.Group + tokenGroup := token.Group + if tokenGroup != "" { + // check common.UserUsableGroups[userGroup] + if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok { + abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup)) + return + } + // check group in common.GroupRatio + if !ratio_setting.ContainsGroupRatio(tokenGroup) { + if tokenGroup != "auto" { + abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup)) + return + } + } + userGroup = tokenGroup + } + common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup) + err = SetupContextForToken(c, token, parts...) if err != nil { return @@ -273,7 +305,6 @@ func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) e } else { c.Set("token_model_limit_enabled", false) } - c.Set("allow_ips", token.GetIpLimitsMap()) c.Set("token_group", token.Group) if len(parts) > 1 { if model.IsAdmin(token.UserId) { diff --git a/middleware/distributor.go b/middleware/distributor.go index c7a55f4c..5fae6322 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -10,7 +10,6 @@ import ( "one-api/model" relayconstant "one-api/relay/constant" "one-api/service" - "one-api/setting" "one-api/setting/ratio_setting" "one-api/types" "strconv" @@ -27,14 +26,6 @@ type ModelRequest struct { func Distribute() func(c *gin.Context) { return func(c *gin.Context) { - allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps) - if len(allowIpsMap) != 0 { - clientIp := c.ClientIP() - if _, ok := allowIpsMap[clientIp]; !ok { - abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中") - return - } - } var channel *model.Channel channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId) modelRequest, shouldSelectChannel, err := getModelRequest(c) @@ -42,24 +33,6 @@ func Distribute() func(c *gin.Context) { abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error()) return } - userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup) - tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup) - if tokenGroup != "" { - // check common.UserUsableGroups[userGroup] - if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok { - abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup)) - return - } - // check group in common.GroupRatio - if !ratio_setting.ContainsGroupRatio(tokenGroup) { - if tokenGroup != "auto" { - abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup)) - return - } - } - userGroup = tokenGroup - } - common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup) if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { @@ -81,22 +54,21 @@ func Distribute() func(c *gin.Context) { modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) if modelLimitEnable { s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit) - var tokenModelLimit map[string]bool - if ok { - tokenModelLimit = s.(map[string]bool) - } else { - tokenModelLimit = map[string]bool{} - } - if tokenModelLimit != nil { - if _, ok := tokenModelLimit[modelRequest.Model]; !ok { - abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model) - return - } - } else { + if !ok { // token model limit is empty, all models are not allowed abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型") return } + var tokenModelLimit map[string]bool + tokenModelLimit, ok = s.(map[string]bool) + if !ok { + tokenModelLimit = map[string]bool{} + } + matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-* + if _, ok := tokenModelLimit[matchName]; !ok { + abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model) + return + } } if shouldSelectChannel { @@ -105,6 +77,7 @@ func Distribute() func(c *gin.Context) { return } var selectGroup string + userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup) channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0) if err != nil { showGroup := userGroup diff --git a/model/channel_cache.go b/model/channel_cache.go index 6ca23cf9..90bd2ad1 100644 --- a/model/channel_cache.go +++ b/model/channel_cache.go @@ -7,6 +7,7 @@ import ( "one-api/common" "one-api/constant" "one-api/setting" + "one-api/setting/ratio_setting" "sort" "strings" "sync" @@ -128,12 +129,7 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, } func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { - if strings.HasPrefix(model, "gpt-4-gizmo") { - model = "gpt-4-gizmo-*" - } - if strings.HasPrefix(model, "gpt-4o-gizmo") { - model = "gpt-4o-gizmo-*" - } + model = ratio_setting.FormatMatchingModelName(model) // if memory cache is disabled, get channel directly from database if !common.MemoryCacheEnabled { diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index be6dd6b9..647cc1f4 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -335,12 +335,8 @@ func GetModelPrice(name string, printErr bool) (float64, bool) { modelPriceMapMutex.RLock() defer modelPriceMapMutex.RUnlock() - if strings.HasPrefix(name, "gpt-4-gizmo") { - name = "gpt-4-gizmo-*" - } - if strings.HasPrefix(name, "gpt-4o-gizmo") { - name = "gpt-4o-gizmo-*" - } + name = FormatMatchingModelName(name) + price, ok := modelPriceMap[name] if !ok { if printErr { @@ -374,11 +370,8 @@ func GetModelRatio(name string) (float64, bool, string) { modelRatioMapMutex.RLock() defer modelRatioMapMutex.RUnlock() - name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*") - name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*") - if strings.HasPrefix(name, "gpt-4-gizmo") { - name = "gpt-4-gizmo-*" - } + name = FormatMatchingModelName(name) + ratio, ok := modelRatioMap[name] if !ok { return 37.5, operation_setting.SelfUseModeEnabled, name @@ -429,12 +422,9 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { func GetCompletionRatio(name string) float64 { CompletionRatioMutex.RLock() defer CompletionRatioMutex.RUnlock() - if strings.HasPrefix(name, "gpt-4-gizmo") { - name = "gpt-4-gizmo-*" - } - if strings.HasPrefix(name, "gpt-4o-gizmo") { - name = "gpt-4o-gizmo-*" - } + + name = FormatMatchingModelName(name) + if strings.Contains(name, "/") { if ratio, ok := CompletionRatio[name]; ok { return ratio @@ -664,3 +654,16 @@ func GetCompletionRatioCopy() map[string]float64 { } return copyMap } + +// 转换模型名,减少渠道必须配置各种带参数模型 +func FormatMatchingModelName(name string) string { + name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*") + name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*") + if strings.HasPrefix(name, "gpt-4-gizmo") { + name = "gpt-4-gizmo-*" + } + if strings.HasPrefix(name, "gpt-4o-gizmo") { + name = "gpt-4o-gizmo-*" + } + return name +}