diff --git a/controller/channel-test.go b/controller/channel-test.go index 52f8a7ef..d162d8cf 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -165,8 +165,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() consumedTime := float64(milliseconds) / 1000.0 - other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, - usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.UserGroupRatio) + other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio, + usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) @@ -312,7 +312,7 @@ func testAllChannels(notify bool) error { channel.UpdateResponseTime(milliseconds) time.Sleep(common.RequestInterval) } - + if notify { service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成") } diff --git a/controller/group.go b/controller/group.go index 2c725a4d..632b6cd5 100644 --- a/controller/group.go +++ b/controller/group.go @@ -1,10 +1,11 @@ package controller import ( - "github.com/gin-gonic/gin" "net/http" "one-api/model" "one-api/setting" + + "github.com/gin-gonic/gin" ) func GetGroups(c *gin.Context) { @@ -34,6 +35,12 @@ func GetUserGroups(c *gin.Context) { } } } + if setting.GroupInUserUsableGroups("auto") { + usableGroups["auto"] = map[string]interface{}{ + "ratio": "自动", + "desc": setting.GetUsableGroupDescription("auto"), + } + } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/controller/misc.go b/controller/misc.go index 33a41302..1caaf640 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -9,9 +9,9 @@ import ( "one-api/middleware" "one-api/model" "one-api/setting" + "one-api/setting/console_setting" "one-api/setting/operation_setting" "one-api/setting/system_setting" - "one-api/setting/console_setting" "strings" "github.com/gin-gonic/gin" @@ -41,46 +41,47 @@ func GetStatus(c *gin.Context) { cs := console_setting.GetConsoleSetting() data := gin.H{ - "version": common.Version, - "start_time": common.StartTime, - "email_verification": common.EmailVerificationEnabled, - "github_oauth": common.GitHubOAuthEnabled, - "github_client_id": common.GitHubClientId, - "linuxdo_oauth": common.LinuxDOOAuthEnabled, - "linuxdo_client_id": common.LinuxDOClientId, - "telegram_oauth": common.TelegramOAuthEnabled, - "telegram_bot_name": common.TelegramBotName, - "system_name": common.SystemName, - "logo": common.Logo, - "footer_html": common.Footer, - "wechat_qrcode": common.WeChatAccountQRCodeImageURL, - "wechat_login": common.WeChatAuthEnabled, - "server_address": setting.ServerAddress, - "price": setting.Price, - "min_topup": setting.MinTopUp, - "turnstile_check": common.TurnstileCheckEnabled, - "turnstile_site_key": common.TurnstileSiteKey, - "top_up_link": common.TopUpLink, - "docs_link": operation_setting.GetGeneralSetting().DocsLink, - "quota_per_unit": common.QuotaPerUnit, - "display_in_currency": common.DisplayInCurrencyEnabled, - "enable_batch_update": common.BatchUpdateEnabled, - "enable_drawing": common.DrawingEnabled, - "enable_task": common.TaskEnabled, - "enable_data_export": common.DataExportEnabled, - "data_export_default_time": common.DataExportDefaultTime, - "default_collapse_sidebar": common.DefaultCollapseSidebar, - "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "", - "mj_notify_enabled": setting.MjNotifyEnabled, - "chats": setting.Chats, - "demo_site_enabled": operation_setting.DemoSiteEnabled, - "self_use_mode_enabled": operation_setting.SelfUseModeEnabled, + "version": common.Version, + "start_time": common.StartTime, + "email_verification": common.EmailVerificationEnabled, + "github_oauth": common.GitHubOAuthEnabled, + "github_client_id": common.GitHubClientId, + "linuxdo_oauth": common.LinuxDOOAuthEnabled, + "linuxdo_client_id": common.LinuxDOClientId, + "telegram_oauth": common.TelegramOAuthEnabled, + "telegram_bot_name": common.TelegramBotName, + "system_name": common.SystemName, + "logo": common.Logo, + "footer_html": common.Footer, + "wechat_qrcode": common.WeChatAccountQRCodeImageURL, + "wechat_login": common.WeChatAuthEnabled, + "server_address": setting.ServerAddress, + "price": setting.Price, + "min_topup": setting.MinTopUp, + "turnstile_check": common.TurnstileCheckEnabled, + "turnstile_site_key": common.TurnstileSiteKey, + "top_up_link": common.TopUpLink, + "docs_link": operation_setting.GetGeneralSetting().DocsLink, + "quota_per_unit": common.QuotaPerUnit, + "display_in_currency": common.DisplayInCurrencyEnabled, + "enable_batch_update": common.BatchUpdateEnabled, + "enable_drawing": common.DrawingEnabled, + "enable_task": common.TaskEnabled, + "enable_data_export": common.DataExportEnabled, + "data_export_default_time": common.DataExportDefaultTime, + "default_collapse_sidebar": common.DefaultCollapseSidebar, + "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "", + "mj_notify_enabled": setting.MjNotifyEnabled, + "chats": setting.Chats, + "demo_site_enabled": operation_setting.DemoSiteEnabled, + "self_use_mode_enabled": operation_setting.SelfUseModeEnabled, + "default_use_auto_group": setting.DefaultUseAutoGroup, // 面板启用开关 - "api_info_enabled": cs.ApiInfoEnabled, - "uptime_kuma_enabled": cs.UptimeKumaEnabled, - "announcements_enabled": cs.AnnouncementsEnabled, - "faq_enabled": cs.FAQEnabled, + "api_info_enabled": cs.ApiInfoEnabled, + "uptime_kuma_enabled": cs.UptimeKumaEnabled, + "announcements_enabled": cs.AnnouncementsEnabled, + "faq_enabled": cs.FAQEnabled, "oidc_enabled": system_setting.GetOIDCSettings().Enabled, "oidc_client_id": system_setting.GetOIDCSettings().ClientId, diff --git a/controller/model.go b/controller/model.go index df7e59a6..134217a3 100644 --- a/controller/model.go +++ b/controller/model.go @@ -2,7 +2,6 @@ package controller import ( "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/constant" @@ -15,6 +14,9 @@ import ( "one-api/relay/channel/moonshot" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + "one-api/setting" + + "github.com/gin-gonic/gin" ) // https://platform.openai.com/docs/api-reference/models/list @@ -179,7 +181,19 @@ func ListModels(c *gin.Context) { if tokenGroup != "" { group = tokenGroup } - models := model.GetGroupModels(group) + var models []string + if tokenGroup == "auto" { + for _, autoGroup := range setting.AutoGroups { + groupModels := model.GetGroupModels(autoGroup) + for _, g := range groupModels { + if !common.StringsContains(models, g) { + models = append(models, g) + } + } + } + } else { + models = model.GetGroupModels(group) + } for _, s := range models { if _, ok := openAIModelsMap[s]; ok { userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s]) diff --git a/controller/playground.go b/controller/playground.go index a2b54790..37a5c7b0 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -3,7 +3,6 @@ package controller import ( "errors" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/constant" @@ -13,6 +12,8 @@ import ( "one-api/service" "one-api/setting" "time" + + "github.com/gin-gonic/gin" ) func Playground(c *gin.Context) { @@ -57,9 +58,9 @@ func Playground(c *gin.Context) { c.Set("group", group) } c.Set("token_name", "playground-"+group) - channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0) + channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0) if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model) + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model) openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError) return } diff --git a/controller/relay.go b/controller/relay.go index 1a875dbc..c1c45114 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -259,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m AutoBan: &autoBanInt, }, nil } - channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount) + channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) if err != nil { return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error())) } @@ -388,7 +388,7 @@ func RelayTask(c *gin.Context) { retryTimes = 0 } for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { - channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i) + channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i) if err != nil { common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) break diff --git a/controller/user.go b/controller/user.go index ecaf2583..e8ce3c3d 100644 --- a/controller/user.go +++ b/controller/user.go @@ -226,6 +226,9 @@ func Register(c *gin.Context) { UnlimitedQuota: true, ModelLimitsEnabled: false, } + if setting.DefaultUseAutoGroup { + token.Group = "auto" + } if err := token.Insert(); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/middleware/distributor.go b/middleware/distributor.go index 1bfe1821..5d1c3641 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -49,8 +49,10 @@ func Distribute() func(c *gin.Context) { } // check group in common.GroupRatio if !setting.ContainsGroupRatio(tokenGroup) { - abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup)) - return + if tokenGroup != "auto" { + abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup)) + return + } } userGroup = tokenGroup } @@ -95,9 +97,14 @@ func Distribute() func(c *gin.Context) { } if shouldSelectChannel { - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0) + var selectGroup string + channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0) if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + showGroup := userGroup + if userGroup == "auto" { + showGroup = fmt.Sprintf("auto(%s)", selectGroup) + } + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model) // 如果错误,但是渠道不为空,说明是数据库一致性问题 if channel != nil { common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) diff --git a/model/cache.go b/model/cache.go index e2f83e22..3e5eb4c4 100644 --- a/model/cache.go +++ b/model/cache.go @@ -5,10 +5,13 @@ import ( "fmt" "math/rand" "one-api/common" + "one-api/setting" "sort" "strings" "sync" "time" + + "github.com/gin-gonic/gin" ) var group2model2channels map[string]map[string][]*Channel @@ -75,7 +78,43 @@ func SyncChannelCache(frequency int) { } } -func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { +func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) { + var channel *Channel + var err error + selectGroup := group + if group == "auto" { + if len(setting.AutoGroups) == 0 { + return nil, selectGroup, errors.New("auto groups is not enabled") + } + for _, autoGroup := range setting.AutoGroups { + if common.DebugEnabled { + println("autoGroup:", autoGroup) + } + channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry) + if channel == nil { + continue + } else { + c.Set("auto_group", autoGroup) + selectGroup = autoGroup + if common.DebugEnabled { + println("selectGroup:", selectGroup) + } + break + } + } + } else { + channel, err = getRandomSatisfiedChannel(group, model, retry) + if err != nil { + return nil, group, err + } + } + if channel == nil { + return nil, group, errors.New("channel not found") + } + return channel, selectGroup, nil +} + +func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { if strings.HasPrefix(model, "gpt-4-gizmo") { model = "gpt-4-gizmo-*" } diff --git a/model/option.go b/model/option.go index d1689cb7..1391b203 100644 --- a/model/option.go +++ b/model/option.go @@ -76,6 +76,8 @@ func InitOptionMap() { common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp) common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() common.OptionMap["Chats"] = setting.Chats2JsonString() + common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString() + common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup) common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientSecret"] = "" common.OptionMap["TelegramBotToken"] = "" @@ -192,7 +194,7 @@ func updateOptionMap(key string, value string) (err error) { common.ImageDownloadPermission = intValue } } - if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" { + if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" { boolValue := value == "true" switch key { case "PasswordRegisterEnabled": @@ -261,6 +263,8 @@ func updateOptionMap(key string, value string) (err error) { common.SMTPSSLEnabled = boolValue case "WorkerAllowHttpImageRequestEnabled": setting.WorkerAllowHttpImageRequestEnabled = boolValue + case "DefaultUseAutoGroup": + setting.DefaultUseAutoGroup = boolValue } } switch key { @@ -287,6 +291,8 @@ func updateOptionMap(key string, value string) (err error) { setting.PayAddress = value case "Chats": err = setting.UpdateChatsByJsonString(value) + case "AutoGroups": + err = setting.UpdateAutoGroupsByJsonString(value) case "CustomCallbackAddress": setting.CustomCallbackAddress = value case "EpayId": diff --git a/relay/helper/price.go b/relay/helper/price.go index 1b52bf37..326790b4 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -11,6 +11,11 @@ import ( "github.com/gin-gonic/gin" ) +type GroupRatioInfo struct { + GroupRatio float64 + GroupSpecialRatio float64 +} + type PriceData struct { ModelPrice float64 ModelRatio float64 @@ -18,23 +23,50 @@ type PriceData struct { CacheRatio float64 CacheCreationRatio float64 ImageRatio float64 - GroupRatio float64 - UserGroupRatio float64 UsePrice bool ShouldPreConsumedQuota int + GroupRatioInfo GroupRatioInfo } func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) +} + +// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present +func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo { + groupRatioInfo := GroupRatioInfo{ + GroupRatio: 1.0, // default ratio + GroupSpecialRatio: 1.0, // default user group ratio + } + + // check auto group + autoGroup, exists := ctx.Get("auto_group") + if exists { + if common.DebugEnabled { + println(fmt.Sprintf("final group: %s", autoGroup)) + } + relayInfo.Group = autoGroup.(string) + } + + // check user group special ratio + userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group) + if ok { + // user group special ratio + groupRatioInfo.GroupSpecialRatio = userGroupRatio + groupRatioInfo.GroupRatio = userGroupRatio + } else { + // normal group ratio + groupRatioInfo.GroupRatio = setting.GetGroupRatio(relayInfo.Group) + } + + return groupRatioInfo } func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false) - groupRatio := setting.GetGroupRatio(info.Group) - userGroupRatio, ok := setting.GetGroupGroupRatio(info.UserGroup, info.Group) - if ok { - groupRatio = userGroupRatio - } + + groupRatioInfo := HandleGroupRatio(c, info) + var preConsumedQuota int var modelRatio float64 var completionRatio float64 @@ -64,18 +96,17 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName) cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName) imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName) - ratio := modelRatio * groupRatio + ratio := modelRatio * groupRatioInfo.GroupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { - preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) } priceData := PriceData{ ModelPrice: modelPrice, ModelRatio: modelRatio, CompletionRatio: completionRatio, - GroupRatio: groupRatio, - UserGroupRatio: userGroupRatio, + GroupRatioInfo: groupRatioInfo, UsePrice: usePrice, CacheRatio: cacheRatio, ImageRatio: imageRatio, diff --git a/relay/relay-image.go b/relay/relay-image.go index dc63cce8..197a8af6 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -162,7 +162,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { // reset model price priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N) - quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit) + quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit) userQuota, err = model.GetUserQuota(relayInfo.UserId, false) if err != nil { return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) diff --git a/relay/relay-text.go b/relay/relay-text.go index 3aa382e8..c94e0f50 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -361,9 +361,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, cacheRatio := priceData.CacheRatio imageRatio := priceData.ImageRatio modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatio + groupRatio := priceData.GroupRatioInfo.GroupRatio modelPrice := priceData.ModelPrice - userGroupRatio := priceData.UserGroupRatio // Convert values to decimal for precise calculation dPromptTokens := decimal.NewFromInt(int64(promptTokens)) @@ -511,7 +510,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, if extraContent != "" { logContent += ", " + extraContent } - other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio) + other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) if imageTokens != 0 { other["image"] = true other["image_ratio"] = imageRatio diff --git a/relay/websocket.go b/relay/websocket.go index c815eb71..571f3a82 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -6,12 +6,10 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" - "one-api/setting" - "one-api/setting/operation_setting" ) func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) { @@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi //isModelMapped = true } } - //relayInfo.UpstreamModelName = textRequest.Model - modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false) - groupRatio := setting.GetGroupRatio(relayInfo.Group) - var preConsumedQuota int - var ratio float64 - var modelRatio float64 - //err := service.SensitiveWordsCheck(textRequest) - - //if constant.ShouldCheckPromptSensitive() { - // err = checkRequestSensitive(textRequest, relayInfo) - // if err != nil { - // return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest) - // } - //} - - //promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo) - //// count messages token error 计算promptTokens错误 - //if err != nil { - // return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) - //} - // - if !getModelPriceSuccess { - preConsumedTokens := common.PreConsumedQuota - //if realtimeEvent.Session.MaxResponseOutputTokens != 0 { - // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens) - //} - modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName) - ratio = modelRatio * groupRatio - preConsumedQuota = int(float64(preConsumedTokens) * ratio) - } else { - preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) - relayInfo.UsePrice = true + priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) } // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { return openaiErr } @@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi return openaiErr } service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota, - userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") + userQuota, priceData, "") return nil } diff --git a/service/quota.go b/service/quota.go index da3dd9b9..0fb9e67c 100644 --- a/service/quota.go +++ b/service/quota.go @@ -3,6 +3,7 @@ package service import ( "errors" "fmt" + "log" "one-api/common" constant2 "one-api/constant" "one-api/dto" @@ -94,11 +95,20 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag audioInputTokens := usage.InputTokenDetails.AudioTokens audioOutTokens := usage.OutputTokenDetails.AudioTokens groupRatio := setting.GetGroupRatio(relayInfo.Group) + modelRatio, _ := operation_setting.GetModelRatio(modelName) + + autoGroup, exists := ctx.Get("auto_group") + if exists { + groupRatio = setting.GetGroupRatio(autoGroup.(string)) + log.Printf("final group ratio: %f", groupRatio) + relayInfo.Group = autoGroup.(string) + } + + actualGroupRatio := groupRatio userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group) if ok { - groupRatio = userGroupRatio + actualGroupRatio = userGroupRatio } - modelRatio, _ := operation_setting.GetModelRatio(modelName) quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -112,7 +122,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag ModelName: modelName, UsePrice: relayInfo.UsePrice, ModelRatio: modelRatio, - GroupRatio: groupRatio, + GroupRatio: actualGroupRatio, } quota := calculateAudioQuota(quotaInfo) @@ -134,8 +144,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag } func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, - usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, - modelPrice float64, usePrice bool, extraContent string) { + usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.InputTokenDetails.TextTokens @@ -149,11 +158,11 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName)) audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName)) - actualGroupRatio := groupRatio - userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group) - if ok { - actualGroupRatio = userGroupRatio - } + modelRatio := priceData.ModelRatio + groupRatio := priceData.GroupRatioInfo.GroupRatio + modelPrice := priceData.ModelPrice + usePrice := priceData.UsePrice + quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ TextTokens: textInputTokens, @@ -166,7 +175,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod ModelName: modelName, UsePrice: usePrice, ModelRatio: modelRatio, - GroupRatio: actualGroupRatio, + GroupRatio: groupRatio, } quota := calculateAudioQuota(quotaInfo) @@ -198,7 +207,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod logContent += ", " + extraContent } other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } @@ -214,9 +223,8 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, tokenName := ctx.GetString("token_name") completionRatio := priceData.CompletionRatio modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatio + groupRatio := priceData.GroupRatioInfo.GroupRatio modelPrice := priceData.ModelPrice - userGroupRatio := priceData.UserGroupRatio cacheRatio := priceData.CacheRatio cacheTokens := usage.PromptTokensDetails.CachedTokens @@ -265,7 +273,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, - cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, userGroupRatio) + cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } @@ -286,16 +294,10 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)) modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatio + groupRatio := priceData.GroupRatioInfo.GroupRatio modelPrice := priceData.ModelPrice usePrice := priceData.UsePrice - actualGroupRatio := groupRatio - userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group) - if ok { - actualGroupRatio = userGroupRatio - } - quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ TextTokens: textInputTokens, @@ -308,7 +310,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, ModelName: relayInfo.OriginModelName, UsePrice: usePrice, ModelRatio: modelRatio, - GroupRatio: actualGroupRatio, + GroupRatio: groupRatio, } quota := calculateAudioQuota(quotaInfo) @@ -348,7 +350,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, logContent += ", " + extraContent } other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } diff --git a/setting/auto_group.go b/setting/auto_group.go new file mode 100644 index 00000000..5a87ae56 --- /dev/null +++ b/setting/auto_group.go @@ -0,0 +1,31 @@ +package setting + +import "encoding/json" + +var AutoGroups = []string{ + "default", +} + +var DefaultUseAutoGroup = false + +func ContainsAutoGroup(group string) bool { + for _, autoGroup := range AutoGroups { + if autoGroup == group { + return true + } + } + return false +} + +func UpdateAutoGroupsByJsonString(jsonString string) error { + AutoGroups = make([]string, 0) + return json.Unmarshal([]byte(jsonString), &AutoGroups) +} + +func AutoGroups2JsonString() string { + jsonBytes, err := json.Marshal(AutoGroups) + if err != nil { + return "[]" + } + return string(jsonBytes) +} diff --git a/setting/user_usable_group.go b/setting/user_usable_group.go index 7082b683..fdf2f723 100644 --- a/setting/user_usable_group.go +++ b/setting/user_usable_group.go @@ -50,3 +50,10 @@ func GroupInUserUsableGroups(groupName string) bool { _, ok := userUsableGroups[groupName] return ok } + +func GetUsableGroupDescription(groupName string) string { + if desc, ok := userUsableGroups[groupName]; ok { + return desc + } + return groupName +} diff --git a/web/src/components/settings/OperationSetting.js b/web/src/components/settings/OperationSetting.js index 55e328a3..7bd9bf62 100644 --- a/web/src/components/settings/OperationSetting.js +++ b/web/src/components/settings/OperationSetting.js @@ -31,6 +31,8 @@ const OperationSetting = () => { ModelPrice: '', GroupRatio: '', GroupGroupRatio: '', + AutoGroups: '', + DefaultUseAutoGroup: false, UserUsableGroups: '', TopUpLink: '', 'general_setting.docs_link': '', @@ -76,6 +78,7 @@ const OperationSetting = () => { item.key === 'ModelRatio' || item.key === 'GroupRatio' || item.key === 'GroupGroupRatio' || + item.key === 'AutoGroups' || item.key === 'UserUsableGroups' || item.key === 'CompletionRatio' || item.key === 'ModelPrice' || @@ -85,7 +88,8 @@ const OperationSetting = () => { } if ( item.key.endsWith('Enabled') || - ['DefaultCollapseSidebar'].includes(item.key) + ['DefaultCollapseSidebar'].includes(item.key) || + ['DefaultUseAutoGroup'].includes(item.key) ) { newInputs[item.key] = item.value === 'true' ? true : false; } else { diff --git a/web/src/pages/Setting/Operation/GroupRatioSettings.js b/web/src/pages/Setting/Operation/GroupRatioSettings.js index 6d212746..4a51a98c 100644 --- a/web/src/pages/Setting/Operation/GroupRatioSettings.js +++ b/web/src/pages/Setting/Operation/GroupRatioSettings.js @@ -17,6 +17,8 @@ export default function GroupRatioSettings(props) { GroupRatio: '', UserUsableGroups: '', GroupGroupRatio: '', + AutoGroups: '', + DefaultUseAutoGroup: false, }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -167,6 +169,59 @@ export default function GroupRatioSettings(props) { /> + + + { + if (!value || value.trim() === '') { + return true; // Allow empty values + } + + // First check if it's valid JSON + try { + const parsed = JSON.parse(value); + + // Check if it's an array + if (!Array.isArray(parsed)) { + return false; + } + + // Check if every element is a string + return parsed.every(item => typeof item === 'string'); + } catch (error) { + return false; + } + }, + message: t('必须是有效的 JSON 字符串数组,例如:["g1","g2"]'), + }, + ]} + onChange={(value) => + setInputs({ ...inputs, AutoGroups: value }) + } + /> + + + + + + setInputs({ ...inputs, DefaultUseAutoGroup: value }) + } + /> + + diff --git a/web/src/pages/Token/EditToken.js b/web/src/pages/Token/EditToken.js index 71f611bd..782562a3 100644 --- a/web/src/pages/Token/EditToken.js +++ b/web/src/pages/Token/EditToken.js @@ -1,4 +1,4 @@ -import React, { useEffect, useState } from 'react'; +import React, { useEffect, useState, useContext } from 'react'; import { useNavigate } from 'react-router-dom'; import { API, @@ -7,7 +7,7 @@ import { showSuccess, timestamp2string, renderGroupOption, - renderQuotaWithPrompt + renderQuotaWithPrompt, } from '../../helpers'; import { AutoComplete, @@ -37,11 +37,13 @@ import { IconPlusCircle, } from '@douyinfe/semi-icons'; import { useTranslation } from 'react-i18next'; +import { StatusContext } from '../../context/Status'; const { Text, Title } = Typography; const EditToken = (props) => { const { t } = useTranslation(); + const [statusState, statusDispatch] = useContext(StatusContext); const [isEdit, setIsEdit] = useState(false); const [loading, setLoading] = useState(isEdit); const originInputs = { @@ -119,7 +121,19 @@ const EditToken = (props) => { value: group, ratio: info.ratio, })); + if (statusState?.status?.default_use_auto_group) { + // if contain auto, add it to the first position + if (localGroupOptions.some((group) => group.value === 'auto')) { + // 排序 + localGroupOptions.sort((a, b) => (a.value === 'auto' ? -1 : 1)); + } else { + localGroupOptions.unshift({ label: t('自动选择'), value: 'auto' }); + } + } setGroups(localGroupOptions); + if (statusState?.status?.default_use_auto_group) { + setInputs({ ...inputs, group: 'auto' }); + } } else { showError(t(message)); } @@ -268,32 +282,37 @@ const EditToken = (props) => { placement={isEdit ? 'right' : 'left'} title={ - {isEdit ? - {t('更新')} : - {t('新建')} - } - + {isEdit ? ( + <Tag color='blue' shape='circle'> + {t('更新')} + </Tag> + ) : ( + <Tag color='green' shape='circle'> + {t('新建')} + </Tag> + )} + <Title heading={4} className='m-0'> {isEdit ? t('更新令牌信息') : t('创建新的令牌')} } headerStyle={{ borderBottom: '1px solid var(--semi-color-border)', - padding: '24px' + padding: '24px', }} bodyStyle={{ backgroundColor: 'var(--semi-color-bg-0)', - padding: '0' + padding: '0', }} visible={props.visiable} width={isMobile() ? '100%' : 600} footer={ -
+
- -
-
-
-
+ +
+
+
+
-
- +
+
-
- {t('额度设置')} -
{t('设置令牌可用额度和数量')}
+
+ + {t('额度设置')} + +
+ {t('设置令牌可用额度和数量')} +
-
+
-
+
{t('额度')} - {renderQuotaWithPrompt(remain_quota)} + + {renderQuotaWithPrompt(remain_quota)} +
handleInputChange('remain_quota', value)} value={remain_quota} - autoComplete="new-password" - type="number" - size="large" - className="w-full !rounded-lg" + autoComplete='new-password' + type='number' + size='large' + className='w-full !rounded-lg' prefix={} data={[ { value: 500000, label: '1$' }, @@ -460,16 +517,18 @@ const EditToken = (props) => { {!isEdit && (
- {t('新建数量')} + + {t('新建数量')} + handleTokenCountChange(value)} onSelect={(value) => handleTokenCountChange(value)} value={tokenCount.toString()} - autoComplete="off" - type="number" - className="w-full !rounded-lg" - size="large" + autoComplete='off' + type='number' + className='w-full !rounded-lg' + size='large' prefix={} data={[ { value: 10, label: t('10个') }, @@ -482,12 +541,12 @@ const EditToken = (props) => {
)} -
+
@@ -495,92 +554,137 @@ const EditToken = (props) => {
- -
-
-
-
+ +
+
+
+
-
- +
+
-
- {t('访问限制')} -
{t('设置令牌的访问限制')}
+
+ + {t('访问限制')} + +
+ {t('设置令牌的访问限制')} +
-
+
- {t('IP白名单')} + + {t('IP白名单')} +