refactor: migrate group ratio and user usable groups logic to new setting package
- Replaced references to common.GroupRatio and common.UserUsableGroups with corresponding functions from the new setting package across multiple controllers and services. - Introduced new setting functions for managing group ratios and user usable groups, enhancing code organization and maintainability. - Updated related functions to ensure consistent behavior with the new setting package integration.
This commit is contained in:
@@ -3,13 +3,13 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
"one-api/setting"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetGroups(c *gin.Context) {
|
func GetGroups(c *gin.Context) {
|
||||||
groupNames := make([]string, 0)
|
groupNames := make([]string, 0)
|
||||||
for groupName, _ := range common.GroupRatio {
|
for groupName, _ := range setting.GetGroupRatioCopy() {
|
||||||
groupNames = append(groupNames, groupName)
|
groupNames = append(groupNames, groupName)
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -24,9 +24,9 @@ func GetUserGroups(c *gin.Context) {
|
|||||||
userGroup := ""
|
userGroup := ""
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
userGroup, _ = model.CacheGetUserGroup(userId)
|
userGroup, _ = model.CacheGetUserGroup(userId)
|
||||||
for groupName, _ := range common.GroupRatio {
|
for groupName, _ := range setting.GetGroupRatioCopy() {
|
||||||
// UserUsableGroups contains the groups that the user can use
|
// UserUsableGroups contains the groups that the user can use
|
||||||
userUsableGroups := common.GetUserUsableGroups(userGroup)
|
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
||||||
if _, ok := userUsableGroups[groupName]; ok {
|
if _, ok := userUsableGroups[groupName]; ok {
|
||||||
usableGroups[groupName] = userUsableGroups[groupName]
|
usableGroups[groupName] = userUsableGroups[groupName]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
"one-api/setting"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -83,7 +84,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "GroupRatio":
|
case "GroupRatio":
|
||||||
err = common.CheckGroupRatio(option.Value)
|
err = setting.CheckGroupRatio(option.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
"one-api/setting"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetPricing(c *gin.Context) {
|
func GetPricing(c *gin.Context) {
|
||||||
@@ -11,7 +12,7 @@ func GetPricing(c *gin.Context) {
|
|||||||
userId, exists := c.Get("id")
|
userId, exists := c.Get("id")
|
||||||
usableGroup := map[string]string{}
|
usableGroup := map[string]string{}
|
||||||
groupRatio := map[string]float64{}
|
groupRatio := map[string]float64{}
|
||||||
for s, f := range common.GroupRatio {
|
for s, f := range setting.GetGroupRatioCopy() {
|
||||||
groupRatio[s] = f
|
groupRatio[s] = f
|
||||||
}
|
}
|
||||||
var group string
|
var group string
|
||||||
@@ -22,9 +23,9 @@ func GetPricing(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usableGroup = common.GetUserUsableGroups(group)
|
usableGroup = setting.GetUserUsableGroups(group)
|
||||||
// check groupRatio contains usableGroup
|
// check groupRatio contains usableGroup
|
||||||
for group := range common.GroupRatio {
|
for group := range setting.GetGroupRatioCopy() {
|
||||||
if _, ok := usableGroup[group]; !ok {
|
if _, ok := usableGroup[group]; !ok {
|
||||||
delete(groupRatio, group)
|
delete(groupRatio, group)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -83,7 +84,7 @@ func Playground(c *gin.Context) {
|
|||||||
if group == "" {
|
if group == "" {
|
||||||
group = userGroup
|
group = userGroup
|
||||||
} else {
|
} else {
|
||||||
if !common.GroupInUserUsableGroups(group) && group != userGroup {
|
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/model"
|
"one-api/model"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -43,12 +44,12 @@ func Distribute() func(c *gin.Context) {
|
|||||||
tokenGroup := c.GetString("token_group")
|
tokenGroup := c.GetString("token_group")
|
||||||
if tokenGroup != "" {
|
if tokenGroup != "" {
|
||||||
// check common.UserUsableGroups[userGroup]
|
// check common.UserUsableGroups[userGroup]
|
||||||
if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// check group in common.GroupRatio
|
// check group in common.GroupRatio
|
||||||
if _, ok := common.GroupRatio[tokenGroup]; !ok {
|
if !setting.ContainsGroupRatio(tokenGroup) {
|
||||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -87,8 +87,8 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
|
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
|
||||||
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
||||||
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
|
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
|
||||||
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
|
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
||||||
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
|
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||||
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
|
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
|
||||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||||
common.OptionMap["ChatLink"] = common.ChatLink
|
common.OptionMap["ChatLink"] = common.ChatLink
|
||||||
@@ -313,9 +313,9 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
case "ModelRatio":
|
case "ModelRatio":
|
||||||
err = common.UpdateModelRatioByJSONString(value)
|
err = common.UpdateModelRatioByJSONString(value)
|
||||||
case "GroupRatio":
|
case "GroupRatio":
|
||||||
err = common.UpdateGroupRatioByJSONString(value)
|
err = setting.UpdateGroupRatioByJSONString(value)
|
||||||
case "UserUsableGroups":
|
case "UserUsableGroups":
|
||||||
err = common.UpdateUserUsableGroupsByJSONString(value)
|
err = setting.UpdateUserUsableGroupsByJSONString(value)
|
||||||
case "CompletionRatio":
|
case "CompletionRatio":
|
||||||
err = common.UpdateCompletionRatioByJSONString(value)
|
err = common.UpdateCompletionRatioByJSONString(value)
|
||||||
case "ModelPrice":
|
case "ModelPrice":
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelRatio := common.GetModelRatio(audioRequest.Model)
|
modelRatio := common.GetModelRatio(audioRequest.Model)
|
||||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||||
ratio := modelRatio * groupRatio
|
ratio := modelRatio * groupRatio
|
||||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|||||||
modelPrice = 0.0025 * modelRatio
|
modelPrice = 0.0025 * modelRatio
|
||||||
}
|
}
|
||||||
|
|
||||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||||
|
|
||||||
sizeRatio := 1.0
|
sizeRatio := 1.0
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|||||||
modelPrice = defaultPrice
|
modelPrice = defaultPrice
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
groupRatio := common.GetGroupRatio(group)
|
groupRatio := setting.GetGroupRatio(group)
|
||||||
ratio := modelPrice * groupRatio
|
ratio := modelPrice * groupRatio
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
userQuota, err := model.CacheGetUserQuota(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -474,7 +474,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
modelPrice = defaultPrice
|
modelPrice = defaultPrice
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
groupRatio := common.GetGroupRatio(group)
|
groupRatio := setting.GetGroupRatio(group)
|
||||||
ratio := modelPrice * groupRatio
|
ratio := modelPrice * groupRatio
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
userQuota, err := model.CacheGetUserQuota(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
}
|
}
|
||||||
relayInfo.UpstreamModelName = textRequest.Model
|
relayInfo.UpstreamModelName = textRequest.Model
|
||||||
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
|
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
|
||||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||||
|
|
||||||
var preConsumedQuota int
|
var preConsumedQuota int
|
||||||
var ratio float64
|
var ratio float64
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||||
@@ -57,7 +58,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
|||||||
|
|
||||||
relayInfo.UpstreamModelName = rerankRequest.Model
|
relayInfo.UpstreamModelName = rerankRequest.Model
|
||||||
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
|
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
|
||||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||||
|
|
||||||
var preConsumedQuota int
|
var preConsumedQuota int
|
||||||
var ratio float64
|
var ratio float64
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -48,7 +49,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 预扣
|
// 预扣
|
||||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||||
ratio := modelPrice * groupRatio
|
ratio := modelPrice * groupRatio
|
||||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
)
|
)
|
||||||
|
|
||||||
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
|
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
|
||||||
@@ -57,7 +58,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
|||||||
}
|
}
|
||||||
//relayInfo.UpstreamModelName = textRequest.Model
|
//relayInfo.UpstreamModelName = textRequest.Model
|
||||||
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
|
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
|
||||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||||
|
|
||||||
var preConsumedQuota int
|
var preConsumedQuota int
|
||||||
var ratio float64
|
var ratio float64
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/setting"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -36,7 +37,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
|||||||
completionRatio := common.GetCompletionRatio(modelName)
|
completionRatio := common.GetCompletionRatio(modelName)
|
||||||
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
|
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
|
||||||
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
|
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
|
||||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||||
modelRatio := common.GetModelRatio(modelName)
|
modelRatio := common.GetModelRatio(modelName)
|
||||||
|
|
||||||
ratio := groupRatio * modelRatio
|
ratio := groupRatio * modelRatio
|
||||||
|
|||||||
@@ -1,33 +1,47 @@
|
|||||||
package common
|
package setting
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"one-api/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
var GroupRatio = map[string]float64{
|
var groupRatio = map[string]float64{
|
||||||
"default": 1,
|
"default": 1,
|
||||||
"vip": 1,
|
"vip": 1,
|
||||||
"svip": 1,
|
"svip": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetGroupRatioCopy() map[string]float64 {
|
||||||
|
groupRatioCopy := make(map[string]float64)
|
||||||
|
for k, v := range groupRatio {
|
||||||
|
groupRatioCopy[k] = v
|
||||||
|
}
|
||||||
|
return groupRatioCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
func ContainsGroupRatio(name string) bool {
|
||||||
|
_, ok := groupRatio[name]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
func GroupRatio2JSONString() string {
|
func GroupRatio2JSONString() string {
|
||||||
jsonBytes, err := json.Marshal(GroupRatio)
|
jsonBytes, err := json.Marshal(groupRatio)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SysError("error marshalling model ratio: " + err.Error())
|
common.SysError("error marshalling model ratio: " + err.Error())
|
||||||
}
|
}
|
||||||
return string(jsonBytes)
|
return string(jsonBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateGroupRatioByJSONString(jsonStr string) error {
|
func UpdateGroupRatioByJSONString(jsonStr string) error {
|
||||||
GroupRatio = make(map[string]float64)
|
groupRatio = make(map[string]float64)
|
||||||
return json.Unmarshal([]byte(jsonStr), &GroupRatio)
|
return json.Unmarshal([]byte(jsonStr), &groupRatio)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetGroupRatio(name string) float64 {
|
func GetGroupRatio(name string) float64 {
|
||||||
ratio, ok := GroupRatio[name]
|
ratio, ok := groupRatio[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
SysError("group ratio not found: " + name)
|
common.SysError("group ratio not found: " + name)
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
return ratio
|
return ratio
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
package common
|
package setting
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"one-api/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
var UserUsableGroups = map[string]string{
|
var UserUsableGroups = map[string]string{
|
||||||
@@ -12,7 +13,7 @@ var UserUsableGroups = map[string]string{
|
|||||||
func UserUsableGroups2JSONString() string {
|
func UserUsableGroups2JSONString() string {
|
||||||
jsonBytes, err := json.Marshal(UserUsableGroups)
|
jsonBytes, err := json.Marshal(UserUsableGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SysError("error marshalling user groups: " + err.Error())
|
common.SysError("error marshalling user groups: " + err.Error())
|
||||||
}
|
}
|
||||||
return string(jsonBytes)
|
return string(jsonBytes)
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user