diff --git a/controller/group.go b/controller/group.go
index 2c725a4d..632b6cd5 100644
--- a/controller/group.go
+++ b/controller/group.go
@@ -1,10 +1,11 @@
package controller
import (
- "github.com/gin-gonic/gin"
"net/http"
"one-api/model"
"one-api/setting"
+
+ "github.com/gin-gonic/gin"
)
func GetGroups(c *gin.Context) {
@@ -34,6 +35,12 @@ func GetUserGroups(c *gin.Context) {
}
}
}
+ if setting.GroupInUserUsableGroups("auto") {
+ usableGroups["auto"] = map[string]interface{}{
+ "ratio": "自动",
+ "desc": setting.GetUsableGroupDescription("auto"),
+ }
+ }
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
diff --git a/controller/misc.go b/controller/misc.go
index 33a41302..1caaf640 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -9,9 +9,9 @@ import (
"one-api/middleware"
"one-api/model"
"one-api/setting"
+ "one-api/setting/console_setting"
"one-api/setting/operation_setting"
"one-api/setting/system_setting"
- "one-api/setting/console_setting"
"strings"
"github.com/gin-gonic/gin"
@@ -41,46 +41,47 @@ func GetStatus(c *gin.Context) {
cs := console_setting.GetConsoleSetting()
data := gin.H{
- "version": common.Version,
- "start_time": common.StartTime,
- "email_verification": common.EmailVerificationEnabled,
- "github_oauth": common.GitHubOAuthEnabled,
- "github_client_id": common.GitHubClientId,
- "linuxdo_oauth": common.LinuxDOOAuthEnabled,
- "linuxdo_client_id": common.LinuxDOClientId,
- "telegram_oauth": common.TelegramOAuthEnabled,
- "telegram_bot_name": common.TelegramBotName,
- "system_name": common.SystemName,
- "logo": common.Logo,
- "footer_html": common.Footer,
- "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
- "wechat_login": common.WeChatAuthEnabled,
- "server_address": setting.ServerAddress,
- "price": setting.Price,
- "min_topup": setting.MinTopUp,
- "turnstile_check": common.TurnstileCheckEnabled,
- "turnstile_site_key": common.TurnstileSiteKey,
- "top_up_link": common.TopUpLink,
- "docs_link": operation_setting.GetGeneralSetting().DocsLink,
- "quota_per_unit": common.QuotaPerUnit,
- "display_in_currency": common.DisplayInCurrencyEnabled,
- "enable_batch_update": common.BatchUpdateEnabled,
- "enable_drawing": common.DrawingEnabled,
- "enable_task": common.TaskEnabled,
- "enable_data_export": common.DataExportEnabled,
- "data_export_default_time": common.DataExportDefaultTime,
- "default_collapse_sidebar": common.DefaultCollapseSidebar,
- "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
- "mj_notify_enabled": setting.MjNotifyEnabled,
- "chats": setting.Chats,
- "demo_site_enabled": operation_setting.DemoSiteEnabled,
- "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
+ "version": common.Version,
+ "start_time": common.StartTime,
+ "email_verification": common.EmailVerificationEnabled,
+ "github_oauth": common.GitHubOAuthEnabled,
+ "github_client_id": common.GitHubClientId,
+ "linuxdo_oauth": common.LinuxDOOAuthEnabled,
+ "linuxdo_client_id": common.LinuxDOClientId,
+ "telegram_oauth": common.TelegramOAuthEnabled,
+ "telegram_bot_name": common.TelegramBotName,
+ "system_name": common.SystemName,
+ "logo": common.Logo,
+ "footer_html": common.Footer,
+ "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
+ "wechat_login": common.WeChatAuthEnabled,
+ "server_address": setting.ServerAddress,
+ "price": setting.Price,
+ "min_topup": setting.MinTopUp,
+ "turnstile_check": common.TurnstileCheckEnabled,
+ "turnstile_site_key": common.TurnstileSiteKey,
+ "top_up_link": common.TopUpLink,
+ "docs_link": operation_setting.GetGeneralSetting().DocsLink,
+ "quota_per_unit": common.QuotaPerUnit,
+ "display_in_currency": common.DisplayInCurrencyEnabled,
+ "enable_batch_update": common.BatchUpdateEnabled,
+ "enable_drawing": common.DrawingEnabled,
+ "enable_task": common.TaskEnabled,
+ "enable_data_export": common.DataExportEnabled,
+ "data_export_default_time": common.DataExportDefaultTime,
+ "default_collapse_sidebar": common.DefaultCollapseSidebar,
+ "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
+ "mj_notify_enabled": setting.MjNotifyEnabled,
+ "chats": setting.Chats,
+ "demo_site_enabled": operation_setting.DemoSiteEnabled,
+ "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
+ "default_use_auto_group": setting.DefaultUseAutoGroup,
// 面板启用开关
- "api_info_enabled": cs.ApiInfoEnabled,
- "uptime_kuma_enabled": cs.UptimeKumaEnabled,
- "announcements_enabled": cs.AnnouncementsEnabled,
- "faq_enabled": cs.FAQEnabled,
+ "api_info_enabled": cs.ApiInfoEnabled,
+ "uptime_kuma_enabled": cs.UptimeKumaEnabled,
+ "announcements_enabled": cs.AnnouncementsEnabled,
+ "faq_enabled": cs.FAQEnabled,
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
diff --git a/controller/model.go b/controller/model.go
index df7e59a6..134217a3 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -2,7 +2,6 @@ package controller
import (
"fmt"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
@@ -15,6 +14,9 @@ import (
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
+ "one-api/setting"
+
+ "github.com/gin-gonic/gin"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -179,7 +181,19 @@ func ListModels(c *gin.Context) {
if tokenGroup != "" {
group = tokenGroup
}
- models := model.GetGroupModels(group)
+ var models []string
+ if tokenGroup == "auto" {
+ for _, autoGroup := range setting.AutoGroups {
+ groupModels := model.GetGroupModels(autoGroup)
+ for _, g := range groupModels {
+ if !common.StringsContains(models, g) {
+ models = append(models, g)
+ }
+ }
+ }
+ } else {
+ models = model.GetGroupModels(group)
+ }
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
diff --git a/controller/playground.go b/controller/playground.go
index a2b54790..37a5c7b0 100644
--- a/controller/playground.go
+++ b/controller/playground.go
@@ -3,7 +3,6 @@ package controller
import (
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
@@ -13,6 +12,8 @@ import (
"one-api/service"
"one-api/setting"
"time"
+
+ "github.com/gin-gonic/gin"
)
func Playground(c *gin.Context) {
@@ -57,9 +58,9 @@ func Playground(c *gin.Context) {
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
- channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
+ channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
+ message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
diff --git a/controller/relay.go b/controller/relay.go
index 1a875dbc..c1c45114 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -259,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt,
}, nil
}
- channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
+ channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
@@ -388,7 +388,7 @@ func RelayTask(c *gin.Context) {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
- channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
+ channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
diff --git a/controller/user.go b/controller/user.go
index ecaf2583..e8ce3c3d 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -226,6 +226,9 @@ func Register(c *gin.Context) {
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
+ if setting.DefaultUseAutoGroup {
+ token.Group = "auto"
+ }
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 1bfe1821..5d1c3641 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -49,8 +49,10 @@ func Distribute() func(c *gin.Context) {
}
// check group in common.GroupRatio
if !setting.ContainsGroupRatio(tokenGroup) {
- abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
- return
+ if tokenGroup != "auto" {
+ abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
+ return
+ }
}
userGroup = tokenGroup
}
@@ -95,9 +97,14 @@ func Distribute() func(c *gin.Context) {
}
if shouldSelectChannel {
- channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
+ var selectGroup string
+ channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
+ showGroup := userGroup
+ if userGroup == "auto" {
+ showGroup = fmt.Sprintf("auto(%s)", selectGroup)
+ }
+ message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
diff --git a/model/cache.go b/model/cache.go
index e2f83e22..1d7d2f25 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -3,12 +3,16 @@ package model
import (
"errors"
"fmt"
+ "log"
"math/rand"
"one-api/common"
+ "one-api/setting"
"sort"
"strings"
"sync"
"time"
+
+ "github.com/gin-gonic/gin"
)
var group2model2channels map[string]map[string][]*Channel
@@ -75,7 +79,39 @@ func SyncChannelCache(frequency int) {
}
}
-func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
+func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
+ var channel *Channel
+ var err error
+ selectGroup := group
+ if group == "auto" {
+ if len(setting.AutoGroups) == 0 {
+ return nil, selectGroup, errors.New("auto groups is not enabled")
+ }
+ for _, autoGroup := range setting.AutoGroups {
+ log.Printf("autoGroup: %s", autoGroup)
+ channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
+ if channel == nil {
+ continue
+ } else {
+ c.Set("auto_group", autoGroup)
+ selectGroup = autoGroup
+ log.Printf("selectGroup: %s", selectGroup)
+ break
+ }
+ }
+ } else {
+ channel, err = getRandomSatisfiedChannel(group, model, retry)
+ if err != nil {
+ return nil, group, err
+ }
+ }
+ if channel == nil {
+ return nil, group, errors.New("channel not found")
+ }
+ return channel, selectGroup, nil
+}
+
+func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*"
}
diff --git a/model/option.go b/model/option.go
index d1689cb7..89ab8506 100644
--- a/model/option.go
+++ b/model/option.go
@@ -76,6 +76,8 @@ func InitOptionMap() {
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
+ common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
+ common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = ""
@@ -287,6 +289,10 @@ func updateOptionMap(key string, value string) (err error) {
setting.PayAddress = value
case "Chats":
err = setting.UpdateChatsByJsonString(value)
+ case "AutoGroups":
+ err = setting.UpdateAutoGroupsByJsonString(value)
+ case "DefaultUseAutoGroup":
+ setting.DefaultUseAutoGroup = value == "true"
case "CustomCallbackAddress":
setting.CustomCallbackAddress = value
case "EpayId":
diff --git a/relay/helper/price.go b/relay/helper/price.go
index 1b52bf37..6ecebac5 100644
--- a/relay/helper/price.go
+++ b/relay/helper/price.go
@@ -2,6 +2,7 @@ package helper
import (
"fmt"
+ "log"
"one-api/common"
constant2 "one-api/constant"
relaycommon "one-api/relay/common"
@@ -31,10 +32,19 @@ func (p PriceData) ToSetting() string {
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
groupRatio := setting.GetGroupRatio(info.Group)
+ var userGroupRatio float64
+ autoGroup, exists := c.Get("auto_group")
+ if exists {
+ groupRatio = setting.GetGroupRatio(autoGroup.(string))
+ log.Printf("final group ratio: %f", groupRatio)
+ info.Group = autoGroup.(string)
+ }
+ actualGroupRatio := groupRatio
userGroupRatio, ok := setting.GetGroupGroupRatio(info.UserGroup, info.Group)
if ok {
- groupRatio = userGroupRatio
+ actualGroupRatio = userGroupRatio
}
+ groupRatio = actualGroupRatio
var preConsumedQuota int
var modelRatio float64
var completionRatio float64
diff --git a/service/quota.go b/service/quota.go
index da3dd9b9..75b186ae 100644
--- a/service/quota.go
+++ b/service/quota.go
@@ -3,6 +3,7 @@ package service
import (
"errors"
"fmt"
+ "log"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
@@ -94,11 +95,20 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens
groupRatio := setting.GetGroupRatio(relayInfo.Group)
+ modelRatio, _ := operation_setting.GetModelRatio(modelName)
+
+ autoGroup, exists := ctx.Get("auto_group")
+ if exists {
+ groupRatio = setting.GetGroupRatio(autoGroup.(string))
+ log.Printf("final group ratio: %f", groupRatio)
+ relayInfo.Group = autoGroup.(string)
+ }
+
+ actualGroupRatio := groupRatio
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
- groupRatio = userGroupRatio
+ actualGroupRatio = userGroupRatio
}
- modelRatio, _ := operation_setting.GetModelRatio(modelName)
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -112,7 +122,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
ModelName: modelName,
UsePrice: relayInfo.UsePrice,
ModelRatio: modelRatio,
- GroupRatio: groupRatio,
+ GroupRatio: actualGroupRatio,
}
quota := calculateAudioQuota(quotaInfo)
@@ -149,6 +159,13 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
+ autoGroup, exists := ctx.Get("auto_group")
+ if exists {
+ groupRatio = setting.GetGroupRatio(autoGroup.(string))
+ log.Printf("final group ratio: %f", groupRatio)
+ relayInfo.Group = autoGroup.(string)
+ }
+
actualGroupRatio := groupRatio
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
@@ -290,6 +307,13 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
+ autoGroup, exists := ctx.Get("auto_group")
+ if exists {
+ groupRatio = setting.GetGroupRatio(autoGroup.(string))
+ log.Printf("final group ratio: %f", groupRatio)
+ relayInfo.Group = autoGroup.(string)
+ }
+
actualGroupRatio := groupRatio
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
diff --git a/setting/auto_group.go b/setting/auto_group.go
new file mode 100644
index 00000000..5a87ae56
--- /dev/null
+++ b/setting/auto_group.go
@@ -0,0 +1,31 @@
+package setting
+
+import "encoding/json"
+
+var AutoGroups = []string{
+ "default",
+}
+
+var DefaultUseAutoGroup = false
+
+func ContainsAutoGroup(group string) bool {
+ for _, autoGroup := range AutoGroups {
+ if autoGroup == group {
+ return true
+ }
+ }
+ return false
+}
+
+func UpdateAutoGroupsByJsonString(jsonString string) error {
+ AutoGroups = make([]string, 0)
+ return json.Unmarshal([]byte(jsonString), &AutoGroups)
+}
+
+func AutoGroups2JsonString() string {
+ jsonBytes, err := json.Marshal(AutoGroups)
+ if err != nil {
+ return "[]"
+ }
+ return string(jsonBytes)
+}
diff --git a/setting/user_usable_group.go b/setting/user_usable_group.go
index 7082b683..fdf2f723 100644
--- a/setting/user_usable_group.go
+++ b/setting/user_usable_group.go
@@ -50,3 +50,10 @@ func GroupInUserUsableGroups(groupName string) bool {
_, ok := userUsableGroups[groupName]
return ok
}
+
+func GetUsableGroupDescription(groupName string) string {
+ if desc, ok := userUsableGroups[groupName]; ok {
+ return desc
+ }
+ return groupName
+}
diff --git a/web/src/components/settings/OperationSetting.js b/web/src/components/settings/OperationSetting.js
index 55e328a3..7bd9bf62 100644
--- a/web/src/components/settings/OperationSetting.js
+++ b/web/src/components/settings/OperationSetting.js
@@ -31,6 +31,8 @@ const OperationSetting = () => {
ModelPrice: '',
GroupRatio: '',
GroupGroupRatio: '',
+ AutoGroups: '',
+ DefaultUseAutoGroup: false,
UserUsableGroups: '',
TopUpLink: '',
'general_setting.docs_link': '',
@@ -76,6 +78,7 @@ const OperationSetting = () => {
item.key === 'ModelRatio' ||
item.key === 'GroupRatio' ||
item.key === 'GroupGroupRatio' ||
+ item.key === 'AutoGroups' ||
item.key === 'UserUsableGroups' ||
item.key === 'CompletionRatio' ||
item.key === 'ModelPrice' ||
@@ -85,7 +88,8 @@ const OperationSetting = () => {
}
if (
item.key.endsWith('Enabled') ||
- ['DefaultCollapseSidebar'].includes(item.key)
+ ['DefaultCollapseSidebar'].includes(item.key) ||
+ ['DefaultUseAutoGroup'].includes(item.key)
) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
diff --git a/web/src/pages/Setting/Operation/GroupRatioSettings.js b/web/src/pages/Setting/Operation/GroupRatioSettings.js
index 6d212746..c0e1ed24 100644
--- a/web/src/pages/Setting/Operation/GroupRatioSettings.js
+++ b/web/src/pages/Setting/Operation/GroupRatioSettings.js
@@ -17,6 +17,8 @@ export default function GroupRatioSettings(props) {
GroupRatio: '',
UserUsableGroups: '',
GroupGroupRatio: '',
+ AutoGroups: '',
+ DefaultUseAutoGroup: false,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -167,6 +169,40 @@ export default function GroupRatioSettings(props) {
/>
+