diff --git a/controller/channel-test.go b/controller/channel-test.go
index 52f8a7ef..d162d8cf 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -165,8 +165,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
- other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
- usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.UserGroupRatio)
+ other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
+ usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
@@ -312,7 +312,7 @@ func testAllChannels(notify bool) error {
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}
-
+
if notify {
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
}
diff --git a/controller/group.go b/controller/group.go
index 2c725a4d..632b6cd5 100644
--- a/controller/group.go
+++ b/controller/group.go
@@ -1,10 +1,11 @@
package controller
import (
- "github.com/gin-gonic/gin"
"net/http"
"one-api/model"
"one-api/setting"
+
+ "github.com/gin-gonic/gin"
)
func GetGroups(c *gin.Context) {
@@ -34,6 +35,12 @@ func GetUserGroups(c *gin.Context) {
}
}
}
+ if setting.GroupInUserUsableGroups("auto") {
+ usableGroups["auto"] = map[string]interface{}{
+ "ratio": "自动",
+ "desc": setting.GetUsableGroupDescription("auto"),
+ }
+ }
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
diff --git a/controller/misc.go b/controller/misc.go
index 33a41302..1caaf640 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -9,9 +9,9 @@ import (
"one-api/middleware"
"one-api/model"
"one-api/setting"
+ "one-api/setting/console_setting"
"one-api/setting/operation_setting"
"one-api/setting/system_setting"
- "one-api/setting/console_setting"
"strings"
"github.com/gin-gonic/gin"
@@ -41,46 +41,47 @@ func GetStatus(c *gin.Context) {
cs := console_setting.GetConsoleSetting()
data := gin.H{
- "version": common.Version,
- "start_time": common.StartTime,
- "email_verification": common.EmailVerificationEnabled,
- "github_oauth": common.GitHubOAuthEnabled,
- "github_client_id": common.GitHubClientId,
- "linuxdo_oauth": common.LinuxDOOAuthEnabled,
- "linuxdo_client_id": common.LinuxDOClientId,
- "telegram_oauth": common.TelegramOAuthEnabled,
- "telegram_bot_name": common.TelegramBotName,
- "system_name": common.SystemName,
- "logo": common.Logo,
- "footer_html": common.Footer,
- "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
- "wechat_login": common.WeChatAuthEnabled,
- "server_address": setting.ServerAddress,
- "price": setting.Price,
- "min_topup": setting.MinTopUp,
- "turnstile_check": common.TurnstileCheckEnabled,
- "turnstile_site_key": common.TurnstileSiteKey,
- "top_up_link": common.TopUpLink,
- "docs_link": operation_setting.GetGeneralSetting().DocsLink,
- "quota_per_unit": common.QuotaPerUnit,
- "display_in_currency": common.DisplayInCurrencyEnabled,
- "enable_batch_update": common.BatchUpdateEnabled,
- "enable_drawing": common.DrawingEnabled,
- "enable_task": common.TaskEnabled,
- "enable_data_export": common.DataExportEnabled,
- "data_export_default_time": common.DataExportDefaultTime,
- "default_collapse_sidebar": common.DefaultCollapseSidebar,
- "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
- "mj_notify_enabled": setting.MjNotifyEnabled,
- "chats": setting.Chats,
- "demo_site_enabled": operation_setting.DemoSiteEnabled,
- "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
+ "version": common.Version,
+ "start_time": common.StartTime,
+ "email_verification": common.EmailVerificationEnabled,
+ "github_oauth": common.GitHubOAuthEnabled,
+ "github_client_id": common.GitHubClientId,
+ "linuxdo_oauth": common.LinuxDOOAuthEnabled,
+ "linuxdo_client_id": common.LinuxDOClientId,
+ "telegram_oauth": common.TelegramOAuthEnabled,
+ "telegram_bot_name": common.TelegramBotName,
+ "system_name": common.SystemName,
+ "logo": common.Logo,
+ "footer_html": common.Footer,
+ "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
+ "wechat_login": common.WeChatAuthEnabled,
+ "server_address": setting.ServerAddress,
+ "price": setting.Price,
+ "min_topup": setting.MinTopUp,
+ "turnstile_check": common.TurnstileCheckEnabled,
+ "turnstile_site_key": common.TurnstileSiteKey,
+ "top_up_link": common.TopUpLink,
+ "docs_link": operation_setting.GetGeneralSetting().DocsLink,
+ "quota_per_unit": common.QuotaPerUnit,
+ "display_in_currency": common.DisplayInCurrencyEnabled,
+ "enable_batch_update": common.BatchUpdateEnabled,
+ "enable_drawing": common.DrawingEnabled,
+ "enable_task": common.TaskEnabled,
+ "enable_data_export": common.DataExportEnabled,
+ "data_export_default_time": common.DataExportDefaultTime,
+ "default_collapse_sidebar": common.DefaultCollapseSidebar,
+ "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
+ "mj_notify_enabled": setting.MjNotifyEnabled,
+ "chats": setting.Chats,
+ "demo_site_enabled": operation_setting.DemoSiteEnabled,
+ "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
+ "default_use_auto_group": setting.DefaultUseAutoGroup,
// 面板启用开关
- "api_info_enabled": cs.ApiInfoEnabled,
- "uptime_kuma_enabled": cs.UptimeKumaEnabled,
- "announcements_enabled": cs.AnnouncementsEnabled,
- "faq_enabled": cs.FAQEnabled,
+ "api_info_enabled": cs.ApiInfoEnabled,
+ "uptime_kuma_enabled": cs.UptimeKumaEnabled,
+ "announcements_enabled": cs.AnnouncementsEnabled,
+ "faq_enabled": cs.FAQEnabled,
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
diff --git a/controller/model.go b/controller/model.go
index df7e59a6..134217a3 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -2,7 +2,6 @@ package controller
import (
"fmt"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
@@ -15,6 +14,9 @@ import (
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
+ "one-api/setting"
+
+ "github.com/gin-gonic/gin"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -179,7 +181,19 @@ func ListModels(c *gin.Context) {
if tokenGroup != "" {
group = tokenGroup
}
- models := model.GetGroupModels(group)
+ var models []string
+ if tokenGroup == "auto" {
+ for _, autoGroup := range setting.AutoGroups {
+ groupModels := model.GetGroupModels(autoGroup)
+ for _, g := range groupModels {
+ if !common.StringsContains(models, g) {
+ models = append(models, g)
+ }
+ }
+ }
+ } else {
+ models = model.GetGroupModels(group)
+ }
for _, s := range models {
if _, ok := openAIModelsMap[s]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
diff --git a/controller/playground.go b/controller/playground.go
index a2b54790..37a5c7b0 100644
--- a/controller/playground.go
+++ b/controller/playground.go
@@ -3,7 +3,6 @@ package controller
import (
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
@@ -13,6 +12,8 @@ import (
"one-api/service"
"one-api/setting"
"time"
+
+ "github.com/gin-gonic/gin"
)
func Playground(c *gin.Context) {
@@ -57,9 +58,9 @@ func Playground(c *gin.Context) {
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
- channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
+ channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
+ message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
diff --git a/controller/relay.go b/controller/relay.go
index 1a875dbc..c1c45114 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -259,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt,
}, nil
}
- channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
+ channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
@@ -388,7 +388,7 @@ func RelayTask(c *gin.Context) {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
- channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
+ channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
diff --git a/controller/user.go b/controller/user.go
index ecaf2583..e8ce3c3d 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -226,6 +226,9 @@ func Register(c *gin.Context) {
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
+ if setting.DefaultUseAutoGroup {
+ token.Group = "auto"
+ }
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 1bfe1821..5d1c3641 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -49,8 +49,10 @@ func Distribute() func(c *gin.Context) {
}
// check group in common.GroupRatio
if !setting.ContainsGroupRatio(tokenGroup) {
- abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
- return
+ if tokenGroup != "auto" {
+ abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
+ return
+ }
}
userGroup = tokenGroup
}
@@ -95,9 +97,14 @@ func Distribute() func(c *gin.Context) {
}
if shouldSelectChannel {
- channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
+ var selectGroup string
+ channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
+ showGroup := userGroup
+ if userGroup == "auto" {
+ showGroup = fmt.Sprintf("auto(%s)", selectGroup)
+ }
+ message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
diff --git a/model/cache.go b/model/cache.go
index e2f83e22..3e5eb4c4 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -5,10 +5,13 @@ import (
"fmt"
"math/rand"
"one-api/common"
+ "one-api/setting"
"sort"
"strings"
"sync"
"time"
+
+ "github.com/gin-gonic/gin"
)
var group2model2channels map[string]map[string][]*Channel
@@ -75,7 +78,43 @@ func SyncChannelCache(frequency int) {
}
}
-func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
+func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
+ var channel *Channel
+ var err error
+ selectGroup := group
+ if group == "auto" {
+ if len(setting.AutoGroups) == 0 {
+ return nil, selectGroup, errors.New("auto groups is not enabled")
+ }
+ for _, autoGroup := range setting.AutoGroups {
+ if common.DebugEnabled {
+ println("autoGroup:", autoGroup)
+ }
+ channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
+ if channel == nil {
+ continue
+ } else {
+ c.Set("auto_group", autoGroup)
+ selectGroup = autoGroup
+ if common.DebugEnabled {
+ println("selectGroup:", selectGroup)
+ }
+ break
+ }
+ }
+ } else {
+ channel, err = getRandomSatisfiedChannel(group, model, retry)
+ if err != nil {
+ return nil, group, err
+ }
+ }
+ if channel == nil {
+ return nil, group, errors.New("channel not found")
+ }
+ return channel, selectGroup, nil
+}
+
+func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
if strings.HasPrefix(model, "gpt-4-gizmo") {
model = "gpt-4-gizmo-*"
}
diff --git a/model/option.go b/model/option.go
index d1689cb7..1391b203 100644
--- a/model/option.go
+++ b/model/option.go
@@ -76,6 +76,8 @@ func InitOptionMap() {
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
+ common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
+ common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = ""
@@ -192,7 +194,7 @@ func updateOptionMap(key string, value string) (err error) {
common.ImageDownloadPermission = intValue
}
}
- if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
+ if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
@@ -261,6 +263,8 @@ func updateOptionMap(key string, value string) (err error) {
common.SMTPSSLEnabled = boolValue
case "WorkerAllowHttpImageRequestEnabled":
setting.WorkerAllowHttpImageRequestEnabled = boolValue
+ case "DefaultUseAutoGroup":
+ setting.DefaultUseAutoGroup = boolValue
}
}
switch key {
@@ -287,6 +291,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.PayAddress = value
case "Chats":
err = setting.UpdateChatsByJsonString(value)
+ case "AutoGroups":
+ err = setting.UpdateAutoGroupsByJsonString(value)
case "CustomCallbackAddress":
setting.CustomCallbackAddress = value
case "EpayId":
diff --git a/relay/helper/price.go b/relay/helper/price.go
index 1b52bf37..326790b4 100644
--- a/relay/helper/price.go
+++ b/relay/helper/price.go
@@ -11,6 +11,11 @@ import (
"github.com/gin-gonic/gin"
)
+type GroupRatioInfo struct {
+ GroupRatio float64
+ GroupSpecialRatio float64
+}
+
type PriceData struct {
ModelPrice float64
ModelRatio float64
@@ -18,23 +23,50 @@ type PriceData struct {
CacheRatio float64
CacheCreationRatio float64
ImageRatio float64
- GroupRatio float64
- UserGroupRatio float64
UsePrice bool
ShouldPreConsumedQuota int
+ GroupRatioInfo GroupRatioInfo
}
func (p PriceData) ToSetting() string {
- return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
+ return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
+}
+
+// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present
+func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
+ groupRatioInfo := GroupRatioInfo{
+ GroupRatio: 1.0, // default ratio
+ GroupSpecialRatio: 1.0, // default user group ratio
+ }
+
+ // check auto group
+ autoGroup, exists := ctx.Get("auto_group")
+ if exists {
+ if common.DebugEnabled {
+ println(fmt.Sprintf("final group: %s", autoGroup))
+ }
+ relayInfo.Group = autoGroup.(string)
+ }
+
+ // check user group special ratio
+ userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
+ if ok {
+ // user group special ratio
+ groupRatioInfo.GroupSpecialRatio = userGroupRatio
+ groupRatioInfo.GroupRatio = userGroupRatio
+ } else {
+ // normal group ratio
+ groupRatioInfo.GroupRatio = setting.GetGroupRatio(relayInfo.Group)
+ }
+
+ return groupRatioInfo
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
- groupRatio := setting.GetGroupRatio(info.Group)
- userGroupRatio, ok := setting.GetGroupGroupRatio(info.UserGroup, info.Group)
- if ok {
- groupRatio = userGroupRatio
- }
+
+ groupRatioInfo := HandleGroupRatio(c, info)
+
var preConsumedQuota int
var modelRatio float64
var completionRatio float64
@@ -64,18 +96,17 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
- ratio := modelRatio * groupRatio
+ ratio := modelRatio * groupRatioInfo.GroupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+ preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
}
priceData := PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
- GroupRatio: groupRatio,
- UserGroupRatio: userGroupRatio,
+ GroupRatioInfo: groupRatioInfo,
UsePrice: usePrice,
CacheRatio: cacheRatio,
ImageRatio: imageRatio,
diff --git a/relay/relay-image.go b/relay/relay-image.go
index dc63cce8..197a8af6 100644
--- a/relay/relay-image.go
+++ b/relay/relay-image.go
@@ -162,7 +162,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
// reset model price
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
- quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
+ quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
diff --git a/relay/relay-text.go b/relay/relay-text.go
index 3aa382e8..c94e0f50 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -361,9 +361,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
cacheRatio := priceData.CacheRatio
imageRatio := priceData.ImageRatio
modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
- userGroupRatio := priceData.UserGroupRatio
// Convert values to decimal for precise calculation
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
@@ -511,7 +510,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
if extraContent != "" {
logContent += ", " + extraContent
}
- other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
+ other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
if imageTokens != 0 {
other["image"] = true
other["image_ratio"] = imageRatio
diff --git a/relay/websocket.go b/relay/websocket.go
index c815eb71..571f3a82 100644
--- a/relay/websocket.go
+++ b/relay/websocket.go
@@ -6,12 +6,10 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
- "one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
"one-api/service"
- "one-api/setting"
- "one-api/setting/operation_setting"
)
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
//isModelMapped = true
}
}
- //relayInfo.UpstreamModelName = textRequest.Model
- modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
- var preConsumedQuota int
- var ratio float64
- var modelRatio float64
- //err := service.SensitiveWordsCheck(textRequest)
-
- //if constant.ShouldCheckPromptSensitive() {
- // err = checkRequestSensitive(textRequest, relayInfo)
- // if err != nil {
- // return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
- // }
- //}
-
- //promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
- //// count messages token error 计算promptTokens错误
- //if err != nil {
- // return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
- //}
- //
- if !getModelPriceSuccess {
- preConsumedTokens := common.PreConsumedQuota
- //if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
- // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
- //}
- modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
- ratio = modelRatio * groupRatio
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
- } else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
- relayInfo.UsePrice = true
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
+ if err != nil {
+ return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
// pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
@@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
return openaiErr
}
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
- userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+ userQuota, priceData, "")
return nil
}
diff --git a/service/quota.go b/service/quota.go
index da3dd9b9..0fb9e67c 100644
--- a/service/quota.go
+++ b/service/quota.go
@@ -3,6 +3,7 @@ package service
import (
"errors"
"fmt"
+ "log"
"one-api/common"
constant2 "one-api/constant"
"one-api/dto"
@@ -94,11 +95,20 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens
groupRatio := setting.GetGroupRatio(relayInfo.Group)
+ modelRatio, _ := operation_setting.GetModelRatio(modelName)
+
+ autoGroup, exists := ctx.Get("auto_group")
+ if exists {
+ groupRatio = setting.GetGroupRatio(autoGroup.(string))
+ log.Printf("final group ratio: %f", groupRatio)
+ relayInfo.Group = autoGroup.(string)
+ }
+
+ actualGroupRatio := groupRatio
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
if ok {
- groupRatio = userGroupRatio
+ actualGroupRatio = userGroupRatio
}
- modelRatio, _ := operation_setting.GetModelRatio(modelName)
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -112,7 +122,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
ModelName: modelName,
UsePrice: relayInfo.UsePrice,
ModelRatio: modelRatio,
- GroupRatio: groupRatio,
+ GroupRatio: actualGroupRatio,
}
quota := calculateAudioQuota(quotaInfo)
@@ -134,8 +144,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
}
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
- usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
- modelPrice float64, usePrice bool, extraContent string) {
+ usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.InputTokenDetails.TextTokens
@@ -149,11 +158,11 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
- actualGroupRatio := groupRatio
- userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
- if ok {
- actualGroupRatio = userGroupRatio
- }
+ modelRatio := priceData.ModelRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
+ modelPrice := priceData.ModelPrice
+ usePrice := priceData.UsePrice
+
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
TextTokens: textInputTokens,
@@ -166,7 +175,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
ModelName: modelName,
UsePrice: usePrice,
ModelRatio: modelRatio,
- GroupRatio: actualGroupRatio,
+ GroupRatio: groupRatio,
}
quota := calculateAudioQuota(quotaInfo)
@@ -198,7 +207,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
logContent += ", " + extraContent
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
- completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio)
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
@@ -214,9 +223,8 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
tokenName := ctx.GetString("token_name")
completionRatio := priceData.CompletionRatio
modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
- userGroupRatio := priceData.UserGroupRatio
cacheRatio := priceData.CacheRatio
cacheTokens := usage.PromptTokensDetails.CachedTokens
@@ -265,7 +273,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
- cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, userGroupRatio)
+ cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
@@ -286,16 +294,10 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
- actualGroupRatio := groupRatio
- userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
- if ok {
- actualGroupRatio = userGroupRatio
- }
-
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
TextTokens: textInputTokens,
@@ -308,7 +310,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
ModelName: relayInfo.OriginModelName,
UsePrice: usePrice,
ModelRatio: modelRatio,
- GroupRatio: actualGroupRatio,
+ GroupRatio: groupRatio,
}
quota := calculateAudioQuota(quotaInfo)
@@ -348,7 +350,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
logContent += ", " + extraContent
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
- completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio)
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
diff --git a/setting/auto_group.go b/setting/auto_group.go
new file mode 100644
index 00000000..5a87ae56
--- /dev/null
+++ b/setting/auto_group.go
@@ -0,0 +1,31 @@
+package setting
+
+import "encoding/json"
+
+var AutoGroups = []string{
+ "default",
+}
+
+var DefaultUseAutoGroup = false
+
+func ContainsAutoGroup(group string) bool {
+ for _, autoGroup := range AutoGroups {
+ if autoGroup == group {
+ return true
+ }
+ }
+ return false
+}
+
+func UpdateAutoGroupsByJsonString(jsonString string) error {
+ AutoGroups = make([]string, 0)
+ return json.Unmarshal([]byte(jsonString), &AutoGroups)
+}
+
+func AutoGroups2JsonString() string {
+ jsonBytes, err := json.Marshal(AutoGroups)
+ if err != nil {
+ return "[]"
+ }
+ return string(jsonBytes)
+}
diff --git a/setting/user_usable_group.go b/setting/user_usable_group.go
index 7082b683..fdf2f723 100644
--- a/setting/user_usable_group.go
+++ b/setting/user_usable_group.go
@@ -50,3 +50,10 @@ func GroupInUserUsableGroups(groupName string) bool {
_, ok := userUsableGroups[groupName]
return ok
}
+
+func GetUsableGroupDescription(groupName string) string {
+ if desc, ok := userUsableGroups[groupName]; ok {
+ return desc
+ }
+ return groupName
+}
diff --git a/web/src/components/settings/OperationSetting.js b/web/src/components/settings/OperationSetting.js
index 55e328a3..7bd9bf62 100644
--- a/web/src/components/settings/OperationSetting.js
+++ b/web/src/components/settings/OperationSetting.js
@@ -31,6 +31,8 @@ const OperationSetting = () => {
ModelPrice: '',
GroupRatio: '',
GroupGroupRatio: '',
+ AutoGroups: '',
+ DefaultUseAutoGroup: false,
UserUsableGroups: '',
TopUpLink: '',
'general_setting.docs_link': '',
@@ -76,6 +78,7 @@ const OperationSetting = () => {
item.key === 'ModelRatio' ||
item.key === 'GroupRatio' ||
item.key === 'GroupGroupRatio' ||
+ item.key === 'AutoGroups' ||
item.key === 'UserUsableGroups' ||
item.key === 'CompletionRatio' ||
item.key === 'ModelPrice' ||
@@ -85,7 +88,8 @@ const OperationSetting = () => {
}
if (
item.key.endsWith('Enabled') ||
- ['DefaultCollapseSidebar'].includes(item.key)
+ ['DefaultCollapseSidebar'].includes(item.key) ||
+ ['DefaultUseAutoGroup'].includes(item.key)
) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
diff --git a/web/src/pages/Setting/Operation/GroupRatioSettings.js b/web/src/pages/Setting/Operation/GroupRatioSettings.js
index 6d212746..4a51a98c 100644
--- a/web/src/pages/Setting/Operation/GroupRatioSettings.js
+++ b/web/src/pages/Setting/Operation/GroupRatioSettings.js
@@ -17,6 +17,8 @@ export default function GroupRatioSettings(props) {
GroupRatio: '',
UserUsableGroups: '',
GroupGroupRatio: '',
+ AutoGroups: '',
+ DefaultUseAutoGroup: false,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -167,6 +169,59 @@ export default function GroupRatioSettings(props) {
/>
+