refactor: migrate group ratio and user usable groups logic to new setting package
- Replaced references to common.GroupRatio and common.UserUsableGroups with corresponding functions from the new setting package across multiple controllers and services. - Introduced new setting functions for managing group ratios and user usable groups, enhancing code organization and maintainability. - Updated related functions to ensure consistent behavior with the new setting package integration.
This commit is contained in:
@@ -3,13 +3,13 @@ package controller
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
groupNames := make([]string, 0)
|
||||
for groupName, _ := range common.GroupRatio {
|
||||
for groupName, _ := range setting.GetGroupRatioCopy() {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -24,9 +24,9 @@ func GetUserGroups(c *gin.Context) {
|
||||
userGroup := ""
|
||||
userId := c.GetInt("id")
|
||||
userGroup, _ = model.CacheGetUserGroup(userId)
|
||||
for groupName, _ := range common.GroupRatio {
|
||||
for groupName, _ := range setting.GetGroupRatioCopy() {
|
||||
// UserUsableGroups contains the groups that the user can use
|
||||
userUsableGroups := common.GetUserUsableGroups(userGroup)
|
||||
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
||||
if _, ok := userUsableGroups[groupName]; ok {
|
||||
usableGroups[groupName] = userUsableGroups[groupName]
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -83,7 +84,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
err = common.CheckGroupRatio(option.Value)
|
||||
err = setting.CheckGroupRatio(option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
func GetPricing(c *gin.Context) {
|
||||
@@ -11,7 +12,7 @@ func GetPricing(c *gin.Context) {
|
||||
userId, exists := c.Get("id")
|
||||
usableGroup := map[string]string{}
|
||||
groupRatio := map[string]float64{}
|
||||
for s, f := range common.GroupRatio {
|
||||
for s, f := range setting.GetGroupRatioCopy() {
|
||||
groupRatio[s] = f
|
||||
}
|
||||
var group string
|
||||
@@ -22,9 +23,9 @@ func GetPricing(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
usableGroup = common.GetUserUsableGroups(group)
|
||||
usableGroup = setting.GetUserUsableGroups(group)
|
||||
// check groupRatio contains usableGroup
|
||||
for group := range common.GroupRatio {
|
||||
for group := range setting.GetGroupRatioCopy() {
|
||||
if _, ok := usableGroup[group]; !ok {
|
||||
delete(groupRatio, group)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -83,7 +84,7 @@ func Playground(c *gin.Context) {
|
||||
if group == "" {
|
||||
group = userGroup
|
||||
} else {
|
||||
if !common.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/model"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -43,12 +44,12 @@ func Distribute() func(c *gin.Context) {
|
||||
tokenGroup := c.GetString("token_group")
|
||||
if tokenGroup != "" {
|
||||
// check common.UserUsableGroups[userGroup]
|
||||
if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
||||
return
|
||||
}
|
||||
// check group in common.GroupRatio
|
||||
if _, ok := common.GroupRatio[tokenGroup]; !ok {
|
||||
if !setting.ContainsGroupRatio(tokenGroup) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -87,8 +87,8 @@ func InitOptionMap() {
|
||||
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
|
||||
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
|
||||
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
|
||||
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
common.OptionMap["ChatLink"] = common.ChatLink
|
||||
@@ -313,9 +313,9 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "ModelRatio":
|
||||
err = common.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
err = common.UpdateGroupRatioByJSONString(value)
|
||||
err = setting.UpdateGroupRatioByJSONString(value)
|
||||
case "UserUsableGroups":
|
||||
err = common.UpdateUserUsableGroupsByJSONString(value)
|
||||
err = setting.UpdateUserUsableGroupsByJSONString(value)
|
||||
case "CompletionRatio":
|
||||
err = common.UpdateCompletionRatioByJSONString(value)
|
||||
case "ModelPrice":
|
||||
|
||||
@@ -74,7 +74,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
|
||||
modelRatio := common.GetModelRatio(audioRequest.Model)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
|
||||
@@ -99,7 +99,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
modelPrice = 0.0025 * modelRatio
|
||||
}
|
||||
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
|
||||
sizeRatio := 1.0
|
||||
|
||||
@@ -168,7 +168,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
groupRatio := setting.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
@@ -474,7 +474,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
groupRatio := setting.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
|
||||
@@ -94,7 +94,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
relayInfo.UpstreamModelName = textRequest.Model
|
||||
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||
@@ -57,7 +58,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
||||
|
||||
relayInfo.UpstreamModelName = rerankRequest.Model
|
||||
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
/*
|
||||
@@ -48,7 +49,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
|
||||
// 预扣
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
|
||||
@@ -57,7 +58,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
||||
}
|
||||
//relayInfo.UpstreamModelName = textRequest.Model
|
||||
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -36,7 +37,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
completionRatio := common.GetCompletionRatio(modelName)
|
||||
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
|
||||
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
modelRatio := common.GetModelRatio(modelName)
|
||||
|
||||
ratio := groupRatio * modelRatio
|
||||
|
||||
@@ -1,33 +1,47 @@
|
||||
package common
|
||||
package setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var GroupRatio = map[string]float64{
|
||||
var groupRatio = map[string]float64{
|
||||
"default": 1,
|
||||
"vip": 1,
|
||||
"svip": 1,
|
||||
}
|
||||
|
||||
func GetGroupRatioCopy() map[string]float64 {
|
||||
groupRatioCopy := make(map[string]float64)
|
||||
for k, v := range groupRatio {
|
||||
groupRatioCopy[k] = v
|
||||
}
|
||||
return groupRatioCopy
|
||||
}
|
||||
|
||||
func ContainsGroupRatio(name string) bool {
|
||||
_, ok := groupRatio[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func GroupRatio2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(GroupRatio)
|
||||
jsonBytes, err := json.Marshal(groupRatio)
|
||||
if err != nil {
|
||||
SysError("error marshalling model ratio: " + err.Error())
|
||||
common.SysError("error marshalling model ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateGroupRatioByJSONString(jsonStr string) error {
|
||||
GroupRatio = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &GroupRatio)
|
||||
groupRatio = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &groupRatio)
|
||||
}
|
||||
|
||||
func GetGroupRatio(name string) float64 {
|
||||
ratio, ok := GroupRatio[name]
|
||||
ratio, ok := groupRatio[name]
|
||||
if !ok {
|
||||
SysError("group ratio not found: " + name)
|
||||
common.SysError("group ratio not found: " + name)
|
||||
return 1
|
||||
}
|
||||
return ratio
|
||||
@@ -1,7 +1,8 @@
|
||||
package common
|
||||
package setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var UserUsableGroups = map[string]string{
|
||||
@@ -12,7 +13,7 @@ var UserUsableGroups = map[string]string{
|
||||
func UserUsableGroups2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(UserUsableGroups)
|
||||
if err != nil {
|
||||
SysError("error marshalling user groups: " + err.Error())
|
||||
common.SysError("error marshalling user groups: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
Reference in New Issue
Block a user