refactor: migrate group ratio and user usable groups logic to new setting package

- Replaced references to common.GroupRatio and common.UserUsableGroups with corresponding functions from the new setting package across multiple controllers and services.
- Introduced new setting functions for managing group ratios and user usable groups, enhancing code organization and maintainability.
- Updated related functions to ensure consistent behavior with the new setting package integration.
This commit is contained in:
CalciumIon
2024-12-25 19:31:12 +08:00
parent b3576f24ef
commit 4fc1fe318e
16 changed files with 57 additions and 34 deletions

View File

@@ -3,13 +3,13 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common"
"one-api/model" "one-api/model"
"one-api/setting"
) )
func GetGroups(c *gin.Context) { func GetGroups(c *gin.Context) {
groupNames := make([]string, 0) groupNames := make([]string, 0)
for groupName, _ := range common.GroupRatio { for groupName, _ := range setting.GetGroupRatioCopy() {
groupNames = append(groupNames, groupName) groupNames = append(groupNames, groupName)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -24,9 +24,9 @@ func GetUserGroups(c *gin.Context) {
userGroup := "" userGroup := ""
userId := c.GetInt("id") userId := c.GetInt("id")
userGroup, _ = model.CacheGetUserGroup(userId) userGroup, _ = model.CacheGetUserGroup(userId)
for groupName, _ := range common.GroupRatio { for groupName, _ := range setting.GetGroupRatioCopy() {
// UserUsableGroups contains the groups that the user can use // UserUsableGroups contains the groups that the user can use
userUsableGroups := common.GetUserUsableGroups(userGroup) userUsableGroups := setting.GetUserUsableGroups(userGroup)
if _, ok := userUsableGroups[groupName]; ok { if _, ok := userUsableGroups[groupName]; ok {
usableGroups[groupName] = userUsableGroups[groupName] usableGroups[groupName] = userUsableGroups[groupName]
} }

View File

@@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"one-api/setting"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -83,7 +84,7 @@ func UpdateOption(c *gin.Context) {
return return
} }
case "GroupRatio": case "GroupRatio":
err = common.CheckGroupRatio(option.Value) err = setting.CheckGroupRatio(option.Value)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"one-api/setting"
) )
func GetPricing(c *gin.Context) { func GetPricing(c *gin.Context) {
@@ -11,7 +12,7 @@ func GetPricing(c *gin.Context) {
userId, exists := c.Get("id") userId, exists := c.Get("id")
usableGroup := map[string]string{} usableGroup := map[string]string{}
groupRatio := map[string]float64{} groupRatio := map[string]float64{}
for s, f := range common.GroupRatio { for s, f := range setting.GetGroupRatioCopy() {
groupRatio[s] = f groupRatio[s] = f
} }
var group string var group string
@@ -22,9 +23,9 @@ func GetPricing(c *gin.Context) {
} }
} }
usableGroup = common.GetUserUsableGroups(group) usableGroup = setting.GetUserUsableGroups(group)
// check groupRatio contains usableGroup // check groupRatio contains usableGroup
for group := range common.GroupRatio { for group := range setting.GetGroupRatioCopy() {
if _, ok := usableGroup[group]; !ok { if _, ok := usableGroup[group]; !ok {
delete(groupRatio, group) delete(groupRatio, group)
} }

View File

@@ -17,6 +17,7 @@ import (
"one-api/relay/constant" "one-api/relay/constant"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"one-api/setting"
"strings" "strings"
) )
@@ -83,7 +84,7 @@ func Playground(c *gin.Context) {
if group == "" { if group == "" {
group = userGroup group = userGroup
} else { } else {
if !common.GroupInUserUsableGroups(group) && group != userGroup { if !setting.GroupInUserUsableGroups(group) && group != userGroup {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden) openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
return return
} }

View File

@@ -10,6 +10,7 @@ import (
"one-api/model" "one-api/model"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"one-api/setting"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -43,12 +44,12 @@ func Distribute() func(c *gin.Context) {
tokenGroup := c.GetString("token_group") tokenGroup := c.GetString("token_group")
if tokenGroup != "" { if tokenGroup != "" {
// check common.UserUsableGroups[userGroup] // check common.UserUsableGroups[userGroup]
if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok { if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup)) abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return return
} }
// check group in common.GroupRatio // check group in common.GroupRatio
if _, ok := common.GroupRatio[tokenGroup]; !ok { if !setting.ContainsGroupRatio(tokenGroup) {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup)) abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return return
} }

View File

@@ -87,8 +87,8 @@ func InitOptionMap() {
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString() common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink common.OptionMap["ChatLink"] = common.ChatLink
@@ -313,9 +313,9 @@ func updateOptionMap(key string, value string) (err error) {
case "ModelRatio": case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value) err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio": case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value) err = setting.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups": case "UserUsableGroups":
err = common.UpdateUserUsableGroupsByJSONString(value) err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio": case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value) err = common.UpdateCompletionRatioByJSONString(value)
case "ModelPrice": case "ModelPrice":

View File

@@ -74,7 +74,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
modelRatio := common.GetModelRatio(audioRequest.Model) modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := common.GetGroupRatio(relayInfo.Group) groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio) preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)

View File

@@ -99,7 +99,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
modelPrice = 0.0025 * modelRatio modelPrice = 0.0025 * modelRatio
} }
groupRatio := common.GetGroupRatio(relayInfo.Group) groupRatio := setting.GetGroupRatio(relayInfo.Group)
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
sizeRatio := 1.0 sizeRatio := 1.0

View File

@@ -168,7 +168,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
modelPrice = defaultPrice modelPrice = defaultPrice
} }
} }
groupRatio := common.GetGroupRatio(group) groupRatio := setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(userId)
if err != nil { if err != nil {
@@ -474,7 +474,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
modelPrice = defaultPrice modelPrice = defaultPrice
} }
} }
groupRatio := common.GetGroupRatio(group) groupRatio := setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(userId)
if err != nil { if err != nil {

View File

@@ -94,7 +94,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
relayInfo.UpstreamModelName = textRequest.Model relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false) modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group) groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int var preConsumedQuota int
var ratio float64 var ratio float64

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"one-api/setting"
) )
func getRerankPromptToken(rerankRequest dto.RerankRequest) int { func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -57,7 +58,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
relayInfo.UpstreamModelName = rerankRequest.Model relayInfo.UpstreamModelName = rerankRequest.Model
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false) modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group) groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int var preConsumedQuota int
var ratio float64 var ratio float64

View File

@@ -16,6 +16,7 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"one-api/setting"
) )
/* /*
@@ -48,7 +49,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
} }
// 预扣 // 预扣
groupRatio := common.GetGroupRatio(relayInfo.Group) groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelPrice * groupRatio ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
if err != nil { if err != nil {

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"one-api/setting"
) )
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) { //func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
@@ -57,7 +58,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
} }
//relayInfo.UpstreamModelName = textRequest.Model //relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false) modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
groupRatio := common.GetGroupRatio(relayInfo.Group) groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int var preConsumedQuota int
var ratio float64 var ratio float64

View File

@@ -9,6 +9,7 @@ import (
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/setting"
"strings" "strings"
"time" "time"
) )
@@ -36,7 +37,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
completionRatio := common.GetCompletionRatio(modelName) completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName) audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
groupRatio := common.GetGroupRatio(relayInfo.Group) groupRatio := setting.GetGroupRatio(relayInfo.Group)
modelRatio := common.GetModelRatio(modelName) modelRatio := common.GetModelRatio(modelName)
ratio := groupRatio * modelRatio ratio := groupRatio * modelRatio

View File

@@ -1,33 +1,47 @@
package common package setting
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"one-api/common"
) )
var GroupRatio = map[string]float64{ var groupRatio = map[string]float64{
"default": 1, "default": 1,
"vip": 1, "vip": 1,
"svip": 1, "svip": 1,
} }
func GetGroupRatioCopy() map[string]float64 {
groupRatioCopy := make(map[string]float64)
for k, v := range groupRatio {
groupRatioCopy[k] = v
}
return groupRatioCopy
}
func ContainsGroupRatio(name string) bool {
_, ok := groupRatio[name]
return ok
}
func GroupRatio2JSONString() string { func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio) jsonBytes, err := json.Marshal(groupRatio)
if err != nil { if err != nil {
SysError("error marshalling model ratio: " + err.Error()) common.SysError("error marshalling model ratio: " + err.Error())
} }
return string(jsonBytes) return string(jsonBytes)
} }
func UpdateGroupRatioByJSONString(jsonStr string) error { func UpdateGroupRatioByJSONString(jsonStr string) error {
GroupRatio = make(map[string]float64) groupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &GroupRatio) return json.Unmarshal([]byte(jsonStr), &groupRatio)
} }
func GetGroupRatio(name string) float64 { func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name] ratio, ok := groupRatio[name]
if !ok { if !ok {
SysError("group ratio not found: " + name) common.SysError("group ratio not found: " + name)
return 1 return 1
} }
return ratio return ratio

View File

@@ -1,7 +1,8 @@
package common package setting
import ( import (
"encoding/json" "encoding/json"
"one-api/common"
) )
var UserUsableGroups = map[string]string{ var UserUsableGroups = map[string]string{
@@ -12,7 +13,7 @@ var UserUsableGroups = map[string]string{
func UserUsableGroups2JSONString() string { func UserUsableGroups2JSONString() string {
jsonBytes, err := json.Marshal(UserUsableGroups) jsonBytes, err := json.Marshal(UserUsableGroups)
if err != nil { if err != nil {
SysError("error marshalling user groups: " + err.Error()) common.SysError("error marshalling user groups: " + err.Error())
} }
return string(jsonBytes) return string(jsonBytes)
} }