From 4fc1fe318ed7b1256d7a69ad40a42024a7b2fccd Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Wed, 25 Dec 2024 19:31:12 +0800 Subject: [PATCH] refactor: migrate group ratio and user usable groups logic to new setting package - Replaced references to common.GroupRatio and common.UserUsableGroups with corresponding functions from the new setting package across multiple controllers and services. - Introduced new setting functions for managing group ratios and user usable groups, enhancing code organization and maintainability. - Updated related functions to ensure consistent behavior with the new setting package integration. --- controller/group.go | 8 ++--- controller/option.go | 3 +- controller/pricing.go | 7 +++-- controller/relay.go | 3 +- middleware/distributor.go | 5 ++-- model/option.go | 8 ++--- relay/relay-audio.go | 2 +- relay/relay-image.go | 2 +- relay/relay-mj.go | 4 +-- relay/relay-text.go | 2 +- relay/relay_rerank.go | 3 +- relay/relay_task.go | 3 +- relay/websocket.go | 3 +- service/quota.go | 3 +- .../group-ratio.go => setting/group_ratio.go | 30 ++++++++++++++----- .../user_usable_group.go | 5 ++-- 16 files changed, 57 insertions(+), 34 deletions(-) rename common/group-ratio.go => setting/group_ratio.go (51%) rename common/user_groups.go => setting/user_usable_group.go (91%) diff --git a/controller/group.go b/controller/group.go index 9af07af0..c5fde769 100644 --- a/controller/group.go +++ b/controller/group.go @@ -3,13 +3,13 @@ package controller import ( "github.com/gin-gonic/gin" "net/http" - "one-api/common" "one-api/model" + "one-api/setting" ) func GetGroups(c *gin.Context) { groupNames := make([]string, 0) - for groupName, _ := range common.GroupRatio { + for groupName, _ := range setting.GetGroupRatioCopy() { groupNames = append(groupNames, groupName) } c.JSON(http.StatusOK, gin.H{ @@ -24,9 +24,9 @@ func GetUserGroups(c *gin.Context) { userGroup := "" userId := c.GetInt("id") userGroup, _ = model.CacheGetUserGroup(userId) - for groupName, _ := range common.GroupRatio { + for groupName, _ := range setting.GetGroupRatioCopy() { // UserUsableGroups contains the groups that the user can use - userUsableGroups := common.GetUserUsableGroups(userGroup) + userUsableGroups := setting.GetUserUsableGroups(userGroup) if _, ok := userUsableGroups[groupName]; ok { usableGroups[groupName] = userUsableGroups[groupName] } diff --git a/controller/option.go b/controller/option.go index b6165a88..c82fbd7e 100644 --- a/controller/option.go +++ b/controller/option.go @@ -5,6 +5,7 @@ import ( "net/http" "one-api/common" "one-api/model" + "one-api/setting" "strings" "github.com/gin-gonic/gin" @@ -83,7 +84,7 @@ func UpdateOption(c *gin.Context) { return } case "GroupRatio": - err = common.CheckGroupRatio(option.Value) + err = setting.CheckGroupRatio(option.Value) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/pricing.go b/controller/pricing.go index 9862bb2f..36caff9d 100644 --- a/controller/pricing.go +++ b/controller/pricing.go @@ -4,6 +4,7 @@ import ( "github.com/gin-gonic/gin" "one-api/common" "one-api/model" + "one-api/setting" ) func GetPricing(c *gin.Context) { @@ -11,7 +12,7 @@ func GetPricing(c *gin.Context) { userId, exists := c.Get("id") usableGroup := map[string]string{} groupRatio := map[string]float64{} - for s, f := range common.GroupRatio { + for s, f := range setting.GetGroupRatioCopy() { groupRatio[s] = f } var group string @@ -22,9 +23,9 @@ func GetPricing(c *gin.Context) { } } - usableGroup = common.GetUserUsableGroups(group) + usableGroup = setting.GetUserUsableGroups(group) // check groupRatio contains usableGroup - for group := range common.GroupRatio { + for group := range setting.GetGroupRatioCopy() { if _, ok := usableGroup[group]; !ok { delete(groupRatio, group) } diff --git a/controller/relay.go b/controller/relay.go index c2e7523d..5581d5bb 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -17,6 +17,7 @@ import ( "one-api/relay/constant" relayconstant "one-api/relay/constant" "one-api/service" + "one-api/setting" "strings" ) @@ -83,7 +84,7 @@ func Playground(c *gin.Context) { if group == "" { group = userGroup } else { - if !common.GroupInUserUsableGroups(group) && group != userGroup { + if !setting.GroupInUserUsableGroups(group) && group != userGroup { openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden) return } diff --git a/middleware/distributor.go b/middleware/distributor.go index 0af6f82e..0d5a8cac 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -10,6 +10,7 @@ import ( "one-api/model" relayconstant "one-api/relay/constant" "one-api/service" + "one-api/setting" "strconv" "strings" "time" @@ -43,12 +44,12 @@ func Distribute() func(c *gin.Context) { tokenGroup := c.GetString("token_group") if tokenGroup != "" { // check common.UserUsableGroups[userGroup] - if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok { + if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok { abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup)) return } // check group in common.GroupRatio - if _, ok := common.GroupRatio[tokenGroup]; !ok { + if !setting.ContainsGroupRatio(tokenGroup) { abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup)) return } diff --git a/model/option.go b/model/option.go index fd6afa03..1daf40fb 100644 --- a/model/option.go +++ b/model/option.go @@ -87,8 +87,8 @@ func InitOptionMap() { common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() - common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() - common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString() + common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() + common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["ChatLink"] = common.ChatLink @@ -313,9 +313,9 @@ func updateOptionMap(key string, value string) (err error) { case "ModelRatio": err = common.UpdateModelRatioByJSONString(value) case "GroupRatio": - err = common.UpdateGroupRatioByJSONString(value) + err = setting.UpdateGroupRatioByJSONString(value) case "UserUsableGroups": - err = common.UpdateUserUsableGroupsByJSONString(value) + err = setting.UpdateUserUsableGroupsByJSONString(value) case "CompletionRatio": err = common.UpdateCompletionRatioByJSONString(value) case "ModelPrice": diff --git a/relay/relay-audio.go b/relay/relay-audio.go index ff0df524..c9f54f82 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -74,7 +74,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } modelRatio := common.GetModelRatio(audioRequest.Model) - groupRatio := common.GetGroupRatio(relayInfo.Group) + groupRatio := setting.GetGroupRatio(relayInfo.Group) ratio := modelRatio * groupRatio preConsumedQuota := int(float64(preConsumedTokens) * ratio) userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) diff --git a/relay/relay-image.go b/relay/relay-image.go index d4b0b700..5ec71611 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -99,7 +99,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { modelPrice = 0.0025 * modelRatio } - groupRatio := common.GetGroupRatio(relayInfo.Group) + groupRatio := setting.GetGroupRatio(relayInfo.Group) userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) sizeRatio := 1.0 diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 3ff309a0..8bc5c93a 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -168,7 +168,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { modelPrice = defaultPrice } } - groupRatio := common.GetGroupRatio(group) + groupRatio := setting.GetGroupRatio(group) ratio := modelPrice * groupRatio userQuota, err := model.CacheGetUserQuota(userId) if err != nil { @@ -474,7 +474,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons modelPrice = defaultPrice } } - groupRatio := common.GetGroupRatio(group) + groupRatio := setting.GetGroupRatio(group) ratio := modelPrice * groupRatio userQuota, err := model.CacheGetUserQuota(userId) if err != nil { diff --git a/relay/relay-text.go b/relay/relay-text.go index a23d0542..9a9e6974 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -94,7 +94,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } relayInfo.UpstreamModelName = textRequest.Model modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false) - groupRatio := common.GetGroupRatio(relayInfo.Group) + groupRatio := setting.GetGroupRatio(relayInfo.Group) var preConsumedQuota int var ratio float64 diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index a627b780..e53e37d4 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -10,6 +10,7 @@ import ( "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" + "one-api/setting" ) func getRerankPromptToken(rerankRequest dto.RerankRequest) int { @@ -57,7 +58,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith relayInfo.UpstreamModelName = rerankRequest.Model modelPrice, success := common.GetModelPrice(rerankRequest.Model, false) - groupRatio := common.GetGroupRatio(relayInfo.Group) + groupRatio := setting.GetGroupRatio(relayInfo.Group) var preConsumedQuota int var ratio float64 diff --git a/relay/relay_task.go b/relay/relay_task.go index 5e5a5843..7b694a81 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -16,6 +16,7 @@ import ( relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/service" + "one-api/setting" ) /* @@ -48,7 +49,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } // 预扣 - groupRatio := common.GetGroupRatio(relayInfo.Group) + groupRatio := setting.GetGroupRatio(relayInfo.Group) ratio := modelPrice * groupRatio userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) if err != nil { diff --git a/relay/websocket.go b/relay/websocket.go index 247169e3..c05e70a9 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -10,6 +10,7 @@ import ( "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" + "one-api/setting" ) //func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) { @@ -57,7 +58,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi } //relayInfo.UpstreamModelName = textRequest.Model modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false) - groupRatio := common.GetGroupRatio(relayInfo.Group) + groupRatio := setting.GetGroupRatio(relayInfo.Group) var preConsumedQuota int var ratio float64 diff --git a/service/quota.go b/service/quota.go index dc908cd6..2e0cd4fb 100644 --- a/service/quota.go +++ b/service/quota.go @@ -9,6 +9,7 @@ import ( "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" + "one-api/setting" "strings" "time" ) @@ -36,7 +37,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag completionRatio := common.GetCompletionRatio(modelName) audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) audioCompletionRatio := common.GetAudioCompletionRatio(modelName) - groupRatio := common.GetGroupRatio(relayInfo.Group) + groupRatio := setting.GetGroupRatio(relayInfo.Group) modelRatio := common.GetModelRatio(modelName) ratio := groupRatio * modelRatio diff --git a/common/group-ratio.go b/setting/group_ratio.go similarity index 51% rename from common/group-ratio.go rename to setting/group_ratio.go index 673b447d..e715d0a8 100644 --- a/common/group-ratio.go +++ b/setting/group_ratio.go @@ -1,33 +1,47 @@ -package common +package setting import ( "encoding/json" "errors" + "one-api/common" ) -var GroupRatio = map[string]float64{ +var groupRatio = map[string]float64{ "default": 1, "vip": 1, "svip": 1, } +func GetGroupRatioCopy() map[string]float64 { + groupRatioCopy := make(map[string]float64) + for k, v := range groupRatio { + groupRatioCopy[k] = v + } + return groupRatioCopy +} + +func ContainsGroupRatio(name string) bool { + _, ok := groupRatio[name] + return ok +} + func GroupRatio2JSONString() string { - jsonBytes, err := json.Marshal(GroupRatio) + jsonBytes, err := json.Marshal(groupRatio) if err != nil { - SysError("error marshalling model ratio: " + err.Error()) + common.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } func UpdateGroupRatioByJSONString(jsonStr string) error { - GroupRatio = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &GroupRatio) + groupRatio = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &groupRatio) } func GetGroupRatio(name string) float64 { - ratio, ok := GroupRatio[name] + ratio, ok := groupRatio[name] if !ok { - SysError("group ratio not found: " + name) + common.SysError("group ratio not found: " + name) return 1 } return ratio diff --git a/common/user_groups.go b/setting/user_usable_group.go similarity index 91% rename from common/user_groups.go rename to setting/user_usable_group.go index 41b53b3b..6135022e 100644 --- a/common/user_groups.go +++ b/setting/user_usable_group.go @@ -1,7 +1,8 @@ -package common +package setting import ( "encoding/json" + "one-api/common" ) var UserUsableGroups = map[string]string{ @@ -12,7 +13,7 @@ var UserUsableGroups = map[string]string{ func UserUsableGroups2JSONString() string { jsonBytes, err := json.Marshal(UserUsableGroups) if err != nil { - SysError("error marshalling user groups: " + err.Error()) + common.SysError("error marshalling user groups: " + err.Error()) } return string(jsonBytes) }