diff --git a/controller/channel-test.go b/controller/channel-test.go index f7b73f6a..4ac618d6 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -17,6 +17,7 @@ import ( "one-api/relay" relaycommon "one-api/relay/common" "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" "strconv" "strings" @@ -72,18 +73,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } } - modelMapping := *channel.ModelMapping - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[testModel] != "" { - testModel = modelMap[testModel] - } - } - cache, err := model.GetUserCache(1) if err != nil { return err, nil @@ -97,7 +86,14 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr middleware.SetupContextForSelectedChannel(c, channel, testModel) - meta := relaycommon.GenRelayInfo(c) + info := relaycommon.GenRelayInfo(c) + + err = helper.ModelMappedHelper(c, info) + if err != nil { + return err, nil + } + testModel = info.UpstreamModelName + apiType, _ := constant.ChannelType2APIType(channel.Type) adaptor := relay.GetAdaptor(apiType) if adaptor == nil { @@ -105,12 +101,12 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } request := buildTestRequest(testModel) - meta.UpstreamModelName = testModel - common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta)) + info.OriginModelName = testModel + common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info)) - adaptor.Init(meta) + adaptor.Init(info) - convertedRequest, err := adaptor.ConvertRequest(c, meta, request) + convertedRequest, err := adaptor.ConvertRequest(c, info, request) if err != nil { return err, nil } @@ -120,7 +116,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } requestBody := bytes.NewBuffer(jsonData) c.Request.Body = io.NopCloser(requestBody) - resp, err := adaptor.DoRequest(c, meta, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return err, nil } @@ -132,7 +128,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err } } - usageA, respErr := adaptor.DoResponse(c, httpResp, meta) + usageA, respErr := adaptor.DoResponse(c, httpResp, info) if respErr != nil { return fmt.Errorf("%s", respErr.Error.Message), respErr } @@ -145,29 +141,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr if err != nil { return err, nil } - modelPrice, usePrice := common.GetModelPrice(testModel, false) - modelRatio, success := common.GetModelRatio(testModel) - if !success { - return fmt.Errorf("模型 %s 倍率未设置", testModel), nil + info.PromptTokens = usage.PromptTokens + priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens)) + if err != nil { + return err, nil } - completionRatio := common.GetCompletionRatio(testModel) - ratio := modelRatio quota := 0 - if !usePrice { - quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio)) - quota = int(math.Round(float64(quota) * ratio)) - if ratio != 0 && quota <= 0 { + if !priceData.UsePrice { + quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio)) + quota = int(math.Round(float64(quota) * priceData.ModelRatio)) + if priceData.ModelRatio != 0 && quota <= 0 { quota = 1 } } else { - quota = int(modelPrice * common.QuotaPerUnit) + quota = int(priceData.ModelPrice * common.QuotaPerUnit) } tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() consumedTime := float64(milliseconds) / 1000.0 - other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice) + other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, priceData.ModelPrice) model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", - quota, "模型测试", 0, quota, int(consumedTime), false, "default", other) + quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) return nil, nil } diff --git a/controller/misc.go b/controller/misc.go index 1ea0c133..fe6b986f 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -67,6 +67,7 @@ func GetStatus(c *gin.Context) { "mj_notify_enabled": setting.MjNotifyEnabled, "chats": setting.Chats, "demo_site_enabled": setting.DemoSiteEnabled, + "self_use_mode_enabled": setting.SelfUseModeEnabled, }, }) return diff --git a/controller/pricing.go b/controller/pricing.go index d7af5a4c..97f27490 100644 --- a/controller/pricing.go +++ b/controller/pricing.go @@ -2,7 +2,6 @@ package controller import ( "github.com/gin-gonic/gin" - "one-api/common" "one-api/model" "one-api/setting" ) @@ -40,7 +39,7 @@ func GetPricing(c *gin.Context) { } func ResetModelRatio(c *gin.Context) { - defaultStr := common.DefaultModelRatio2JSONString() + defaultStr := setting.DefaultModelRatio2JSONString() err := model.UpdateOption("ModelRatio", defaultStr) if err != nil { c.JSON(200, gin.H{ @@ -49,7 +48,7 @@ func ResetModelRatio(c *gin.Context) { }) return } - err = common.UpdateModelRatioByJSONString(defaultStr) + err = setting.UpdateModelRatioByJSONString(defaultStr) if err != nil { c.JSON(200, gin.H{ "success": false, diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 135e0005..7986bd49 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -51,7 +51,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max // 如果在时间窗口内已达到限制,拒绝请求 subTime := nowTime.Sub(oldTime).Seconds() if int64(subTime) < duration { - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) return false, nil } @@ -68,7 +68,7 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC now := time.Now().Format(timeFormat) rdb.LPush(ctx, key, now) rdb.LTrim(ctx, key, 0, int64(maxCount-1)) - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) } // Redis限流处理器 @@ -118,7 +118,7 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g // 内存限流处理器 func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { - inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) + inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) return func(c *gin.Context) { userId := strconv.Itoa(c.GetInt("id")) @@ -161,6 +161,7 @@ func ModelRequestRateLimit() func(c *gin.Context) { // 计算限流参数 duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) totalMaxCount := setting.ModelRequestRateLimitCount + successMaxCount := setting.ModelRequestRateLimitSuccessCount // 根据存储类型选择限流处理器 diff --git a/model/log.go b/model/log.go index ed7ec2c7..86850a55 100644 --- a/model/log.go +++ b/model/log.go @@ -2,12 +2,13 @@ package model import ( "fmt" - "github.com/gin-gonic/gin" "one-api/common" "os" "strings" "time" + "github.com/gin-gonic/gin" + "github.com/bytedance/gopkg/util/gopool" "gorm.io/gorm" ) @@ -18,7 +19,7 @@ type Log struct { CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` Type int `json:"type" gorm:"index:idx_created_at_type"` Content string `json:"content"` - Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` + Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"` TokenName string `json:"token_name" gorm:"index;default:''"` ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` Quota int `json:"quota" gorm:"default:0"` diff --git a/model/option.go b/model/option.go index 64d15ca8..b88832c9 100644 --- a/model/option.go +++ b/model/option.go @@ -87,15 +87,15 @@ func InitOptionMap() { common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) - common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) + common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) - common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() - common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString() + common.OptionMap["ModelRatio"] = setting.ModelRatio2JSONString() + common.OptionMap["ModelPrice"] = setting.ModelPrice2JSONString() common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() - common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() + common.OptionMap["CompletionRatio"] = setting.CompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["ChatLink"] = common.ChatLink common.OptionMap["ChatLink2"] = common.ChatLink2 @@ -111,6 +111,7 @@ func InitOptionMap() { common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled) common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled) common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled) + common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(setting.SelfUseModeEnabled) common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled) common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled) common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled) @@ -243,6 +244,8 @@ func updateOptionMap(key string, value string) (err error) { setting.CheckSensitiveEnabled = boolValue case "DemoSiteEnabled": setting.DemoSiteEnabled = boolValue + case "SelfUseModeEnabled": + setting.SelfUseModeEnabled = boolValue case "CheckSensitiveOnPromptEnabled": setting.CheckSensitiveOnPromptEnabled = boolValue case "ModelRequestRateLimitEnabled": @@ -325,7 +328,7 @@ func updateOptionMap(key string, value string) (err error) { common.QuotaForInvitee, _ = strconv.Atoi(value) case "QuotaRemindThreshold": common.QuotaRemindThreshold, _ = strconv.Atoi(value) - case "ShouldPreConsumedQuota": + case "PreConsumedQuota": common.PreConsumedQuota, _ = strconv.Atoi(value) case "ModelRequestRateLimitCount": setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value) @@ -340,15 +343,15 @@ func updateOptionMap(key string, value string) (err error) { case "DataExportDefaultTime": common.DataExportDefaultTime = value case "ModelRatio": - err = common.UpdateModelRatioByJSONString(value) + err = setting.UpdateModelRatioByJSONString(value) case "GroupRatio": err = setting.UpdateGroupRatioByJSONString(value) case "UserUsableGroups": err = setting.UpdateUserUsableGroupsByJSONString(value) case "CompletionRatio": - err = common.UpdateCompletionRatioByJSONString(value) + err = setting.UpdateCompletionRatioByJSONString(value) case "ModelPrice": - err = common.UpdateModelPriceByJSONString(value) + err = setting.UpdateModelPriceByJSONString(value) case "TopUpLink": common.TopUpLink = value case "ChatLink": diff --git a/model/pricing.go b/model/pricing.go index fc709ce4..2d0aa1b7 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "one-api/setting" "sync" "time" ) @@ -64,14 +65,14 @@ func updatePricing() { ModelName: model, EnableGroup: groups, } - modelPrice, findPrice := common.GetModelPrice(model, false) + modelPrice, findPrice := setting.GetModelPrice(model, false) if findPrice { pricing.ModelPrice = modelPrice pricing.QuotaType = 1 } else { - modelRatio, _ := common.GetModelRatio(model) + modelRatio, _ := setting.GetModelRatio(model) pricing.ModelRatio = modelRatio - pricing.CompletionRatio = common.GetCompletionRatio(model) + pricing.CompletionRatio = setting.GetCompletionRatio(model) pricing.QuotaType = 0 } pricingMap = append(pricingMap, pricing) diff --git a/relay/helper/price.go b/relay/helper/price.go index 31d4e9cf..51f64082 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -11,26 +11,33 @@ import ( type PriceData struct { ModelPrice float64 ModelRatio float64 + CompletionRatio float64 GroupRatio float64 UsePrice bool ShouldPreConsumedQuota int } func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { - modelPrice, usePrice := common.GetModelPrice(info.OriginModelName, false) + modelPrice, usePrice := setting.GetModelPrice(info.OriginModelName, false) groupRatio := setting.GetGroupRatio(info.Group) var preConsumedQuota int var modelRatio float64 + var completionRatio float64 if !usePrice { preConsumedTokens := common.PreConsumedQuota if maxTokens != 0 { preConsumedTokens = promptTokens + maxTokens } var success bool - modelRatio, success = common.GetModelRatio(info.OriginModelName) + modelRatio, success = setting.GetModelRatio(info.OriginModelName) if !success { - return PriceData{}, fmt.Errorf("model %s ratio or price not found, please contact admin", info.OriginModelName) + if info.UserId == 1 { + return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName) + } else { + return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置;Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName) + } } + completionRatio = setting.GetCompletionRatio(info.OriginModelName) ratio := modelRatio * groupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { @@ -39,6 +46,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens return PriceData{ ModelPrice: modelPrice, ModelRatio: modelRatio, + CompletionRatio: completionRatio, GroupRatio: groupRatio, UsePrice: usePrice, ShouldPreConsumedQuota: preConsumedQuota, diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 57de8d10..8baf033a 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -157,10 +157,10 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required") } modelName := service.CoverActionToModelName(constant.MjActionSwapFace) - modelPrice, success := common.GetModelPrice(modelName, true) + modelPrice, success := setting.GetModelPrice(modelName, true) // 如果没有配置价格,则使用默认价格 if !success { - defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName] + defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName] if !ok { modelPrice = 0.1 } else { @@ -463,10 +463,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) modelName := service.CoverActionToModelName(midjRequest.Action) - modelPrice, success := common.GetModelPrice(modelName, true) + modelPrice, success := setting.GetModelPrice(modelName, true) // 如果没有配置价格,则使用默认价格 if !success { - defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName] + defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName] if !ok { modelPrice = 0.1 } else { diff --git a/relay/relay-text.go b/relay/relay-text.go index eb331e25..bf6c5fd3 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -311,7 +311,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") - completionRatio := common.GetCompletionRatio(modelName) + completionRatio := setting.GetCompletionRatio(modelName) ratio := priceData.ModelRatio * priceData.GroupRatio modelRatio := priceData.ModelRatio groupRatio := priceData.GroupRatio diff --git a/relay/relay_task.go b/relay/relay_task.go index 591ad3bb..ab35d3e8 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -37,9 +37,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action) - modelPrice, success := common.GetModelPrice(modelName, true) + modelPrice, success := setting.GetModelPrice(modelName, true) if !success { - defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName] + defaultPrice, ok := setting.GetDefaultModelRatioMap()[modelName] if !ok { modelPrice = 0.1 } else { diff --git a/relay/websocket.go b/relay/websocket.go index 2dac60af..b0636057 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -39,7 +39,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi } } //relayInfo.UpstreamModelName = textRequest.Model - modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false) + modelPrice, getModelPriceSuccess := setting.GetModelPrice(relayInfo.UpstreamModelName, false) groupRatio := setting.GetGroupRatio(relayInfo.Group) var preConsumedQuota int @@ -65,7 +65,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi //if realtimeEvent.Session.MaxResponseOutputTokens != 0 { // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens) //} - modelRatio, _ = common.GetModelRatio(relayInfo.UpstreamModelName) + modelRatio, _ = setting.GetModelRatio(relayInfo.UpstreamModelName) ratio = modelRatio * groupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { diff --git a/service/quota.go b/service/quota.go index 9ce2858d..b3412c1e 100644 --- a/service/quota.go +++ b/service/quota.go @@ -38,9 +38,9 @@ func calculateAudioQuota(info QuotaInfo) int { return int(info.ModelPrice * common.QuotaPerUnit * info.GroupRatio) } - completionRatio := common.GetCompletionRatio(info.ModelName) - audioRatio := common.GetAudioRatio(info.ModelName) - audioCompletionRatio := common.GetAudioCompletionRatio(info.ModelName) + completionRatio := setting.GetCompletionRatio(info.ModelName) + audioRatio := setting.GetAudioRatio(info.ModelName) + audioCompletionRatio := setting.GetAudioCompletionRatio(info.ModelName) ratio := info.GroupRatio * info.ModelRatio quota := info.InputDetails.TextTokens + int(math.Round(float64(info.OutputDetails.TextTokens)*completionRatio)) @@ -75,7 +75,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag audioInputTokens := usage.InputTokenDetails.AudioTokens audioOutTokens := usage.OutputTokenDetails.AudioTokens groupRatio := setting.GetGroupRatio(relayInfo.Group) - modelRatio, _ := common.GetModelRatio(modelName) + modelRatio, _ := setting.GetModelRatio(modelName) quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -122,9 +122,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod audioOutTokens := usage.OutputTokenDetails.AudioTokens tokenName := ctx.GetString("token_name") - completionRatio := common.GetCompletionRatio(modelName) - audioRatio := common.GetAudioRatio(relayInfo.OriginModelName) - audioCompletionRatio := common.GetAudioCompletionRatio(modelName) + completionRatio := setting.GetCompletionRatio(modelName) + audioRatio := setting.GetAudioRatio(relayInfo.OriginModelName) + audioCompletionRatio := setting.GetAudioCompletionRatio(modelName) quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -184,9 +184,9 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, audioOutTokens := usage.CompletionTokenDetails.AudioTokens tokenName := ctx.GetString("token_name") - completionRatio := common.GetCompletionRatio(relayInfo.OriginModelName) - audioRatio := common.GetAudioRatio(relayInfo.OriginModelName) - audioCompletionRatio := common.GetAudioCompletionRatio(relayInfo.OriginModelName) + completionRatio := setting.GetCompletionRatio(relayInfo.OriginModelName) + audioRatio := setting.GetAudioRatio(relayInfo.OriginModelName) + audioCompletionRatio := setting.GetAudioCompletionRatio(relayInfo.OriginModelName) modelRatio := priceData.ModelRatio groupRatio := priceData.GroupRatio diff --git a/service/token_counter.go b/service/token_counter.go index aa62bc6e..e868beb4 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -10,6 +10,7 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/setting" "strings" "unicode/utf8" @@ -32,7 +33,7 @@ func InitTokenEncoders() { if err != nil { common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error())) } - for model, _ := range common.GetDefaultModelRatioMap() { + for model, _ := range setting.GetDefaultModelRatioMap() { if strings.HasPrefix(model, "gpt-3.5") { tokenEncoderMap[model] = cl100TokenEncoder } else if strings.HasPrefix(model, "gpt-4") { diff --git a/common/model-ratio.go b/setting/model-ratio.go similarity index 97% rename from common/model-ratio.go rename to setting/model-ratio.go index 03681172..0606f107 100644 --- a/common/model-ratio.go +++ b/setting/model-ratio.go @@ -1,7 +1,8 @@ -package common +package setting import ( "encoding/json" + "one-api/common" "strings" "sync" ) @@ -261,7 +262,7 @@ func ModelPrice2JSONString() string { GetModelPriceMap() jsonBytes, err := json.Marshal(modelPriceMap) if err != nil { - SysError("error marshalling model price: " + err.Error()) + common.SysError("error marshalling model price: " + err.Error()) } return string(jsonBytes) } @@ -285,7 +286,7 @@ func GetModelPrice(name string, printErr bool) (float64, bool) { price, ok := modelPriceMap[name] if !ok { if printErr { - SysError("model price not found: " + name) + common.SysError("model price not found: " + name) } return -1, false } @@ -305,7 +306,7 @@ func ModelRatio2JSONString() string { GetModelRatioMap() jsonBytes, err := json.Marshal(modelRatioMap) if err != nil { - SysError("error marshalling model ratio: " + err.Error()) + common.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -324,8 +325,8 @@ func GetModelRatio(name string) (float64, bool) { } ratio, ok := modelRatioMap[name] if !ok { - SysError("model ratio not found: " + name) - return 37.5, false + common.SysError("model ratio not found: " + name) + return 37.5, SelfUseModeEnabled } return ratio, true } @@ -333,7 +334,7 @@ func GetModelRatio(name string) (float64, bool) { func DefaultModelRatio2JSONString() string { jsonBytes, err := json.Marshal(defaultModelRatio) if err != nil { - SysError("error marshalling model ratio: " + err.Error()) + common.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -355,7 +356,7 @@ func CompletionRatio2JSONString() string { GetCompletionRatioMap() jsonBytes, err := json.Marshal(CompletionRatio) if err != nil { - SysError("error marshalling completion ratio: " + err.Error()) + common.SysError("error marshalling completion ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/operation_setting.go b/setting/operation_setting.go index 4940d0fc..d4275168 100644 --- a/setting/operation_setting.go +++ b/setting/operation_setting.go @@ -3,6 +3,7 @@ package setting import "strings" var DemoSiteEnabled = false +var SelfUseModeEnabled = false var AutomaticDisableKeywords = []string{ "Your credit balance is too low", diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 71914c4e..94cd6ba8 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -15,7 +15,7 @@ import { getQuotaPerUnit, renderGroup, renderNumberWithPoint, - renderQuota, renderQuotaWithPrompt + renderQuota, renderQuotaWithPrompt, stringToColor } from '../helpers/render'; import { Button, Divider, @@ -378,17 +378,15 @@ const ChannelsTable = () => { > {t('测试')} - - - + { const [enableTagMode, setEnableTagMode] = useState(false); const [showBatchSetTag, setShowBatchSetTag] = useState(false); const [batchSetTagValue, setBatchSetTagValue] = useState(''); + const [showModelTestModal, setShowModelTestModal] = useState(false); + const [currentTestChannel, setCurrentTestChannel] = useState(null); + const [modelSearchKeyword, setModelSearchKeyword] = useState(''); const removeRecord = (record) => { @@ -1289,6 +1290,77 @@ const ChannelsTable = () => { onChange={(v) => setBatchSetTagValue(v)} /> + + {/* 模型测试弹窗 */} + { + setShowModelTestModal(false); + setModelSearchKeyword(''); + }} + footer={null} + maskClosable={true} + centered={true} + width={600} + > +
+ {currentTestChannel && ( +
+ + {t('渠道')}: {currentTestChannel.name} + + + {/* 搜索框 */} + setModelSearchKeyword(value)} + style={{ marginBottom: '16px' }} + showClear + /> + +
+ {currentTestChannel.models.split(',') + .filter(model => model.toLowerCase().includes(modelSearchKeyword.toLowerCase())) + .map((model, index) => { + + return ( + + ); + })} +
+ + {/* 显示搜索结果数量 */} + {modelSearchKeyword && ( + + {t('找到')} {currentTestChannel.models.split(',').filter(model => + model.toLowerCase().includes(modelSearchKeyword.toLowerCase()) + ).length} {t('个模型')} + + )} +
+ )} +
+
); }; diff --git a/web/src/components/HeaderBar.js b/web/src/components/HeaderBar.js index c9105e71..68169ed2 100644 --- a/web/src/components/HeaderBar.js +++ b/web/src/components/HeaderBar.js @@ -21,15 +21,17 @@ import { IconUser, IconLanguage } from '@douyinfe/semi-icons'; -import { Avatar, Button, Dropdown, Layout, Nav, Switch } from '@douyinfe/semi-ui'; +import { Avatar, Button, Dropdown, Layout, Nav, Switch, Tag } from '@douyinfe/semi-ui'; import { stringToColor } from '../helpers/render'; import Text from '@douyinfe/semi-ui/lib/es/typography/text'; import { StyleContext } from '../context/Style/index.js'; +import { StatusContext } from '../context/Status/index.js'; const HeaderBar = () => { const { t, i18n } = useTranslation(); const [userState, userDispatch] = useContext(UserContext); const [styleState, styleDispatch] = useContext(StyleContext); + const [statusState, statusDispatch] = useContext(StatusContext); let navigate = useNavigate(); const [currentLang, setCurrentLang] = useState(i18n.language); @@ -40,6 +42,10 @@ const HeaderBar = () => { const isNewYear = (currentDate.getMonth() === 0 && currentDate.getDate() === 1); + // Check if self-use mode is enabled + const isSelfUseMode = statusState?.status?.self_use_mode_enabled || false; + const isDemoSiteMode = statusState?.status?.demo_site_enabled || false; + let buttons = [ { text: t('首页'), @@ -166,7 +172,7 @@ const HeaderBar = () => { onSelect={(key) => {}} header={styleState.isMobile?{ logo: ( - <> +
{ !styleState.showSider ?
), }:{ logo: ( logo ), - text: systemName, + text: ( +
+ {systemName} + {(isSelfUseMode || isDemoSiteMode) && ( + + {isSelfUseMode ? t('自用模式') : t('演示站点')} + + )} +
+ ), }} items={buttons} footer={ @@ -266,7 +311,8 @@ const HeaderBar = () => { icon={} /> { - !styleState.isMobile && ( + // Hide register option in self-use mode + !styleState.isMobile && !isSelfUseMode && ( { RetryTimes: 0, Chats: "[]", DemoSiteEnabled: false, + SelfUseModeEnabled: false, AutomaticDisableKeywords: '', }); diff --git a/web/src/components/OtherSetting.js b/web/src/components/OtherSetting.js index dad79fd1..e3295fb1 100644 --- a/web/src/components/OtherSetting.js +++ b/web/src/components/OtherSetting.js @@ -1,8 +1,10 @@ -import React, { useEffect, useRef, useState } from 'react'; -import { Banner, Button, Col, Form, Row } from '@douyinfe/semi-ui'; -import { API, showError, showSuccess } from '../helpers'; +import React, { useContext, useEffect, useRef, useState } from 'react'; +import { Banner, Button, Col, Form, Row, Modal, Space } from '@douyinfe/semi-ui'; +import { API, showError, showSuccess, timestamp2string } from '../helpers'; import { marked } from 'marked'; import { useTranslation } from 'react-i18next'; +import { StatusContext } from '../context/Status/index.js'; +import Text from '@douyinfe/semi-ui/lib/es/typography/text'; const OtherSetting = () => { const { t } = useTranslation(); @@ -16,6 +18,7 @@ const OtherSetting = () => { }); let [loading, setLoading] = useState(false); const [showUpdateModal, setShowUpdateModal] = useState(false); + const [statusState, statusDispatch] = useContext(StatusContext); const [updateData, setUpdateData] = useState({ tag_name: '', content: '', @@ -43,6 +46,7 @@ const OtherSetting = () => { HomePageContent: false, About: false, Footer: false, + CheckUpdate: false }); const handleInputChange = async (value, e) => { const name = e.target.id; @@ -145,23 +149,48 @@ const OtherSetting = () => { } }; - const openGitHubRelease = () => { - window.location = 'https://github.com/songquanpeng/one-api/releases/latest'; - }; - const checkUpdate = async () => { - const res = await API.get( - 'https://api.github.com/repos/songquanpeng/one-api/releases/latest', - ); - const { tag_name, body } = res.data; - if (tag_name === process.env.REACT_APP_VERSION) { - showSuccess(`已是最新版本:${tag_name}`); - } else { - setUpdateData({ - tag_name: tag_name, - content: marked.parse(body), - }); - setShowUpdateModal(true); + try { + setLoadingInput((loadingInput) => ({ ...loadingInput, CheckUpdate: true })); + // Use a CORS proxy to avoid direct cross-origin requests to GitHub API + // Option 1: Use a public CORS proxy service + // const proxyUrl = 'https://cors-anywhere.herokuapp.com/'; + // const res = await API.get( + // `${proxyUrl}https://api.github.com/repos/Calcium-Ion/new-api/releases/latest`, + // ); + + // Option 2: Use the JSON proxy approach which often works better with GitHub API + const res = await fetch( + 'https://api.github.com/repos/Calcium-Ion/new-api/releases/latest', + { + headers: { + 'Accept': 'application/json', + 'Content-Type': 'application/json', + // Adding User-Agent which is often required by GitHub API + 'User-Agent': 'new-api-update-checker' + } + } + ).then(response => response.json()); + + // Option 3: Use a local proxy endpoint + // Create a cached version of the response to avoid frequent GitHub API calls + // const res = await API.get('/api/status/github-latest-release'); + + const { tag_name, body } = res; + if (tag_name === statusState?.status?.version) { + showSuccess(`已是最新版本:${tag_name}`); + } else { + setUpdateData({ + tag_name: tag_name, + content: marked.parse(body), + }); + setShowUpdateModal(true); + } + } catch (error) { + console.error('Failed to check for updates:', error); + showError('检查更新失败,请稍后再试'); + } finally { + setLoadingInput((loadingInput) => ({ ...loadingInput, CheckUpdate: false })); } }; const getOptions = async () => { @@ -186,9 +215,41 @@ const OtherSetting = () => { getOptions(); }, []); + // Function to open GitHub release page + const openGitHubRelease = () => { + window.open(`https://github.com/Calcium-Ion/new-api/releases/tag/${updateData.tag_name}`, '_blank'); + }; + + const getStartTimeString = () => { + const timestamp = statusState?.status?.start_time; + return statusState.status ? timestamp2string(timestamp) : ''; + }; + return ( + {/* 版本信息 */} +
+ + + + + + {t('当前版本')}:{statusState?.status?.version || t('未知')} + + + + + + + + {t('启动时间')}:{getStartTimeString()} + + + +
{/* 通用设置 */}
{
- {/* setShowUpdateModal(false)}*/} - {/* onOpen={() => setShowUpdateModal(true)}*/} - {/* open={showUpdateModal}*/} - {/*>*/} - {/* 新版本:{updateData.tag_name}*/} - {/* */} - {/* */} - {/*
*/} - {/*
*/} - {/*
*/} - {/* */} - {/* */} - {/* {*/} - {/* setShowUpdateModal(false);*/} - {/* openGitHubRelease();*/} - {/* }}*/} - {/* />*/} - {/* */} - {/**/} + setShowUpdateModal(false)} + footer={[ + + ]} + > +
+
); }; diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js index 49a0784c..7e802914 100644 --- a/web/src/components/PersonalSetting.js +++ b/web/src/components/PersonalSetting.js @@ -69,7 +69,11 @@ const PersonalSetting = () => { const [models, setModels] = useState([]); const [openTransfer, setOpenTransfer] = useState(false); const [transferAmount, setTransferAmount] = useState(0); - const [isModelsExpanded, setIsModelsExpanded] = useState(false); + const [isModelsExpanded, setIsModelsExpanded] = useState(() => { + // Initialize from localStorage if available + const savedState = localStorage.getItem('modelsExpanded'); + return savedState ? JSON.parse(savedState) : false; + }); const MODELS_DISPLAY_COUNT = 10; // 默认显示的模型数量 const [notificationSettings, setNotificationSettings] = useState({ warningType: 'email', @@ -124,6 +128,11 @@ const PersonalSetting = () => { } }, [userState?.user?.setting]); + // Save models expanded state to localStorage whenever it changes + useEffect(() => { + localStorage.setItem('modelsExpanded', JSON.stringify(isModelsExpanded)); + }, [isModelsExpanded]); + const handleInputChange = (name, value) => { setInputs((inputs) => ({...inputs, [name]: value})); }; diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index dec74b06..5738d656 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -82,7 +82,7 @@ export const CHANNEL_OPTIONS = [ { value: 45, color: 'blue', - label: '火山方舟(豆包)' + label: '字节火山方舟、豆包、DeepSeek通用' }, { value: 25, color: 'green', label: 'Moonshot' }, { value: 19, color: 'blue', label: '360 智脑' }, diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index aa2fb2d5..3c7d368c 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1317,5 +1317,23 @@ "当前设置类型: ": "Current setting type: ", "固定价格值": "Fixed Price Value", "未设置倍率模型": "Models without ratio settings", - "模型倍率和补全倍率同时设置": "Both model ratio and completion ratio are set" + "模型倍率和补全倍率同时设置": "Both model ratio and completion ratio are set", + "自用模式": "Self-use mode", + "开启后不限制:必须设置模型倍率": "After enabling, no limit: must set model ratio", + "演示站点模式": "Demo site mode", + "当前版本": "Current version", + "Gemini设置": "Gemini settings", + "Gemini安全设置": "Gemini safety settings", + "default为默认设置,可单独设置每个分类的安全等级": "\"default\" is the default setting, and each category can be set separately", + "Gemini版本设置": "Gemini version settings", + "default为默认设置,可单独设置每个模型的版本": "\"default\" is the default setting, and each model can be set separately", + "Claude设置": "Claude settings", + "Claude请求头覆盖": "Claude request header override", + "示例": "Example", + "缺省 MaxTokens": "Default MaxTokens", + "启用Claude思考适配(-thinking后缀)": "Enable Claude thinking adaptation (-thinking suffix)", + "Claude思考适配 BudgetTokens = MaxTokens * BudgetTokens 百分比": "Claude thinking adaptation BudgetTokens = MaxTokens * BudgetTokens percentage", + "思考适配 BudgetTokens 百分比": "Thinking adaptation BudgetTokens percentage", + "0.1-1之间的小数": "Decimal between 0.1 and 1", + "模型相关设置": "Model related settings" } diff --git a/web/src/pages/Setting/Operation/SettingsGeneral.js b/web/src/pages/Setting/Operation/SettingsGeneral.js index 1c98d33e..e46e7db2 100644 --- a/web/src/pages/Setting/Operation/SettingsGeneral.js +++ b/web/src/pages/Setting/Operation/SettingsGeneral.js @@ -22,6 +22,7 @@ export default function GeneralSettings(props) { DisplayTokenStatEnabled: false, DefaultCollapseSidebar: false, DemoSiteEnabled: false, + SelfUseModeEnabled: false, }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -205,6 +206,22 @@ export default function GeneralSettings(props) { } /> + + + setInputs({ + ...inputs, + SelfUseModeEnabled: value + }) + } + /> +