diff --git a/README.en.md b/README.en.md index 23fdbe1f..4709bc5b 100644 --- a/README.en.md +++ b/README.en.md @@ -107,7 +107,7 @@ For detailed configuration instructions, please refer to [Installation Guide-Env - `GEMINI_VISION_MAX_IMAGE_NUM`: Maximum number of images for Gemini models, default is `16` - `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default is `20` - `CRYPTO_SECRET`: Encryption key used for encrypting database content -- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2024-12-01-preview` +- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview` - `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes - `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2` diff --git a/README.md b/README.md index 67af9916..a807b07d 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do - `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认 `16` - `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位MB,默认 `20` - `CRYPTO_SECRET`:加密密钥,用于加密数据库内容 -- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2024-12-01-preview` +- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,默认 `2025-04-01-preview` - `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟 - `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2` diff --git a/common/constants.go b/common/constants.go index dd4f3b04..bee00506 100644 --- a/common/constants.go +++ b/common/constants.go @@ -240,6 +240,7 @@ const ( ChannelTypeBaiduV2 = 46 ChannelTypeXinference = 47 ChannelTypeXai = 48 + ChannelTypeCoze = 49 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -294,4 +295,5 @@ var ChannelBaseURLs = []string{ "https://qianfan.baidubce.com", //46 "", //47 "https://api.x.ai", //48 + "https://api.coze.cn", //49 } diff --git a/constant/azure.go b/constant/azure.go new file mode 100644 index 00000000..d84040ce --- 
/dev/null +++ b/constant/azure.go @@ -0,0 +1,5 @@ +package constant + +import "time" + +var AzureNoRemoveDotTime = time.Date(2025, time.May, 10, 0, 0, 0, 0, time.UTC).Unix() diff --git a/constant/env.go b/constant/env.go index fae48625..612f3e8b 100644 --- a/constant/env.go +++ b/constant/env.go @@ -31,7 +31,7 @@ func InitEnv() { GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) - AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview") + AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview") GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2) NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 41f8d8f7..2bda0fd2 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -108,6 +108,13 @@ type DeepSeekUsageResponse struct { } `json:"balance_infos"` } +type OpenRouterCreditResponse struct { + Data struct { + TotalCredits float64 `json:"total_credits"` + TotalUsage float64 `json:"total_usage"` + } `json:"data"` +} + // GetAuthHeader get auth header func GetAuthHeader(token string) http.Header { h := http.Header{} @@ -281,6 +288,22 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { return response.TotalAvailable, nil } +func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) { + url := "https://openrouter.ai/api/v1/credits" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := OpenRouterCreditResponse{} + err = json.Unmarshal(body, 
&response) + if err != nil { + return 0, err + } + balance := response.Data.TotalCredits - response.Data.TotalUsage + channel.UpdateBalance(balance) + return balance, nil +} + func updateChannelBalance(channel *model.Channel) (float64, error) { baseURL := common.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() == "" { @@ -307,6 +330,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) { return updateChannelSiliconFlowBalance(channel) case common.ChannelTypeDeepSeek: return updateChannelDeepSeekBalance(channel) + case common.ChannelTypeOpenRouter: + return updateChannelOpenRouterBalance(channel) default: return 0, errors.New("尚未实现") } diff --git a/controller/option.go b/controller/option.go index 81ef463c..250f16bb 100644 --- a/controller/option.go +++ b/controller/option.go @@ -110,6 +110,15 @@ func UpdateOption(c *gin.Context) { }) return } + case "ModelRequestRateLimitGroup": + err = setting.CheckModelRequestRateLimitGroup(option.Value) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } } err = model.UpdateOption(option.Key, option.Value) diff --git a/controller/user.go b/controller/user.go index e194f531..fd53e743 100644 --- a/controller/user.go +++ b/controller/user.go @@ -592,7 +592,14 @@ func UpdateSelf(c *gin.Context) { user.Password = "" // rollback to what it should be cleanUser.Password = "" } - updatePassword := user.Password != "" + updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } if err := cleanUser.Update(updatePassword); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -608,6 +615,23 @@ func UpdateSelf(c *gin.Context) { return } +func checkUpdatePassword(originalPassword string, newPassword string, userId int) (updatePassword bool, err error) { + var currentUser *model.User + currentUser, err = 
model.GetUserById(userId, true) + if err != nil { + return + } + if !common.ValidatePasswordAndHash(originalPassword, currentUser.Password) { + err = fmt.Errorf("原密码错误") + return + } + if newPassword == "" { + return + } + updatePassword = true + return +} + func DeleteUser(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { diff --git a/dto/dalle.go b/dto/dalle.go index 562d5f1a..44104d33 100644 --- a/dto/dalle.go +++ b/dto/dalle.go @@ -12,6 +12,8 @@ type ImageRequest struct { Style string `json:"style,omitempty"` User string `json:"user,omitempty"` ExtraFields json.RawMessage `json:"extra_fields,omitempty"` + Background string `json:"background,omitempty"` + Moderation string `json:"moderation,omitempty"` } type ImageResponse struct { diff --git a/dto/openai_response.go b/dto/openai_response.go index 1508d1f6..790d4df8 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -195,28 +195,28 @@ type OutputTokenDetails struct { } type OpenAIResponsesResponse struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int `json:"created_at"` - Status string `json:"status"` - Error *OpenAIError `json:"error,omitempty"` - IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` - Instructions string `json:"instructions"` - MaxOutputTokens int `json:"max_output_tokens"` - Model string `json:"model"` - Output []ResponsesOutput `json:"output"` - ParallelToolCalls bool `json:"parallel_tool_calls"` - PreviousResponseID string `json:"previous_response_id"` - Reasoning *Reasoning `json:"reasoning"` - Store bool `json:"store"` - Temperature float64 `json:"temperature"` - ToolChoice string `json:"tool_choice"` - Tools []interface{} `json:"tools"` - TopP float64 `json:"top_p"` - Truncation string `json:"truncation"` - Usage *Usage `json:"usage"` - User json.RawMessage `json:"user"` - Metadata json.RawMessage `json:"metadata"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int 
`json:"created_at"` + Status string `json:"status"` + Error *OpenAIError `json:"error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` + Instructions string `json:"instructions"` + MaxOutputTokens int `json:"max_output_tokens"` + Model string `json:"model"` + Output []ResponsesOutput `json:"output"` + ParallelToolCalls bool `json:"parallel_tool_calls"` + PreviousResponseID string `json:"previous_response_id"` + Reasoning *Reasoning `json:"reasoning"` + Store bool `json:"store"` + Temperature float64 `json:"temperature"` + ToolChoice string `json:"tool_choice"` + Tools []ResponsesToolsCall `json:"tools"` + TopP float64 `json:"top_p"` + Truncation string `json:"truncation"` + Usage *Usage `json:"usage"` + User json.RawMessage `json:"user"` + Metadata json.RawMessage `json:"metadata"` } type IncompleteDetails struct { @@ -238,8 +238,12 @@ type ResponsesOutputContent struct { } const ( - BuildInTools_WebSearch = "web_search_preview" - BuildInTools_FileSearch = "file_search" + BuildInToolWebSearchPreview = "web_search_preview" + BuildInToolFileSearch = "file_search" +) + +const ( + BuildInCallWebSearchCall = "web_search_call" ) const ( @@ -250,6 +254,7 @@ const ( // ResponsesStreamResponse 用于处理 /v1/responses 流式响应 type ResponsesStreamResponse struct { Type string `json:"type"` - Response *OpenAIResponsesResponse `json:"response"` + Response *OpenAIResponsesResponse `json:"response,omitempty"` Delta string `json:"delta,omitempty"` + Item *ResponsesOutput `json:"item,omitempty"` } diff --git a/middleware/distributor.go b/middleware/distributor.go index 51fd8fd1..e7db6d77 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -185,7 +185,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e") } else if strings.HasPrefix(c.Request.URL.Path, 
"/v1/images/edits") { - modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "gpt-image-1") + modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1") } if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { relayMode := relayconstant.RelayModeAudioSpeech @@ -213,6 +213,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("channel_type", channel.Type) + c.Set("channel_create_time", channel.CreatedTime) c.Set("channel_setting", channel.GetSetting()) c.Set("param_override", channel.GetParamOverride()) if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { @@ -239,5 +240,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("api_version", channel.Other) case common.ChannelTypeMokaAI: c.Set("api_version", channel.Other) + case common.ChannelTypeCoze: + c.Set("bot_id", channel.Other) } } diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 581dc451..34caa59b 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/common/limiter" + "one-api/constant" "one-api/setting" "strconv" "time" @@ -93,25 +94,27 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g } //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 - totalKey := fmt.Sprintf("rateLimit:%s", userId) - // 初始化 - tb := limiter.New(ctx, rdb) - allowed, err = tb.Allow( - ctx, - totalKey, - limiter.WithCapacity(int64(totalMaxCount)*duration), - limiter.WithRate(int64(totalMaxCount)), - limiter.WithRequested(duration), - ) + if totalMaxCount > 0 { + totalKey := fmt.Sprintf("rateLimit:%s", userId) + // 初始化 + tb := limiter.New(ctx, rdb) + allowed, err = tb.Allow( + ctx, + totalKey, + limiter.WithCapacity(int64(totalMaxCount)*duration), + limiter.WithRate(int64(totalMaxCount)), + 
limiter.WithRequested(duration), + ) - if err != nil { - fmt.Println("检查总请求数限制失败:", err.Error()) - abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") - return - } + if err != nil { + fmt.Println("检查总请求数限制失败:", err.Error()) + abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") + return + } - if !allowed { - abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) + if !allowed { + abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) + } } // 4. 处理请求 @@ -173,6 +176,19 @@ func ModelRequestRateLimit() func(c *gin.Context) { totalMaxCount := setting.ModelRequestRateLimitCount successMaxCount := setting.ModelRequestRateLimitSuccessCount + // 获取分组 + group := c.GetString("token_group") + if group == "" { + group = c.GetString(constant.ContextKeyUserGroup) + } + + //获取分组的限流配置 + groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) + if found { + totalMaxCount = groupTotalCount + successMaxCount = groupSuccessCount + } + // 根据存储类型选择并执行限流处理器 if common.RedisEnabled { redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) diff --git a/model/option.go b/model/option.go index d575742f..d892b120 100644 --- a/model/option.go +++ b/model/option.go @@ -67,6 +67,7 @@ func InitOptionMap() { common.OptionMap["ServerAddress"] = "" common.OptionMap["WorkerUrl"] = setting.WorkerUrl common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey + common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(setting.WorkerAllowHttpImageRequestEnabled) common.OptionMap["PayAddress"] = "" common.OptionMap["CustomCallbackAddress"] = "" common.OptionMap["EpayId"] = "" @@ -92,6 +93,7 @@ func InitOptionMap() { 
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) + common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString() common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() @@ -256,6 +258,8 @@ func updateOptionMap(key string, value string) (err error) { setting.StopOnSensitiveEnabled = boolValue case "SMTPSSLEnabled": common.SMTPSSLEnabled = boolValue + case "WorkerAllowHttpImageRequestEnabled": + setting.WorkerAllowHttpImageRequestEnabled = boolValue } } switch key { @@ -338,6 +342,8 @@ func updateOptionMap(key string, value string) (err error) { setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) case "ModelRequestRateLimitSuccessCount": setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) + case "ModelRequestRateLimitGroup": + err = setting.UpdateModelRequestRateLimitGroupByJSONString(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) case "DataExportInterval": diff --git a/model/user.go b/model/user.go index 0aea2ff5..1a3372aa 100644 --- a/model/user.go +++ b/model/user.go @@ -18,6 +18,7 @@ type User struct { Id int `json:"id"` Username string `json:"username" gorm:"unique;index" validate:"max=12"` Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"` + OriginalPassword string `json:"original_password" gorm:"-:all"` // this field is only for Password change verification, don't save it to database! 
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"` Role int `json:"role" gorm:"type:int;default:1"` // admin, common Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 8e34fd80..ab632d22 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -33,6 +33,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl) case constant.RelayModeImagesGenerations: fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) + case constant.RelayModeCompletions: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl) default: fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) } diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 8b2ca889..03eff9cf 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -1,16 +1,23 @@ package channel import ( + "context" "errors" "fmt" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "io" "net/http" common2 "one-api/common" "one-api/relay/common" "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" + "one-api/setting/operation_setting" + "sync" + "time" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) { @@ -55,6 +62,9 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod if err != nil { return nil, fmt.Errorf("get request url failed: %w", err) } + if common2.DebugEnabled { + println("fullRequestURL:", fullRequestURL) + } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { return nil, 
fmt.Errorf("new request failed: %w", err) @@ -105,7 +115,62 @@ } else { client = service.GetHttpClient() } + // 流式请求 ping 保活 + var stopPinger func() + generalSettings := operation_setting.GetGeneralSetting() + pingEnabled := generalSettings.PingIntervalEnabled + var pingerWg sync.WaitGroup + if info.IsStream { + helper.SetEventStreamHeaders(c) + pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second + var pingerCtx context.Context + pingerCtx, stopPinger = context.WithCancel(c.Request.Context()) + + if pingEnabled { + pingerWg.Add(1) + gopool.Go(func() { + defer pingerWg.Done() + if pingInterval <= 0 { + pingInterval = helper.DefaultPingInterval + } + + ticker := time.NewTicker(pingInterval) + defer ticker.Stop() + var pingMutex sync.Mutex + if common2.DebugEnabled { + println("SSE ping goroutine started") + } + + for { + select { + case <-ticker.C: + pingMutex.Lock() + err2 := helper.PingData(c) + pingMutex.Unlock() + if err2 != nil { + common2.LogError(c, "SSE ping error: "+err2.Error()) + return + } + if common2.DebugEnabled { + println("SSE ping data sent.") + } + case <-pingerCtx.Done(): + if common2.DebugEnabled { + println("SSE ping goroutine stopped.") + } + return + } + } + }) + } + } + resp, err := client.Do(req) + // request结束后停止ping(stopPinger非nil时总是取消,避免context泄漏) + if stopPinger != nil { + stopPinger() + pingerWg.Wait() + } if err != nil { return nil, err } diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go new file mode 100644 index 00000000..80441a51 --- /dev/null +++ b/relay/channel/coze/adaptor.go @@ -0,0 +1,132 @@ +package coze + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + "one-api/relay/common" + "time" + + "github.com/gin-gonic/gin" +) + +type Adaptor struct { +} + +// ConvertAudioRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + return nil, errors.New("not implemented") +} + +// ConvertClaudeRequest implements channel.Adaptor. +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *common.RelayInfo, request *dto.ClaudeRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertEmbeddingRequest implements channel.Adaptor. +func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *common.RelayInfo, request dto.EmbeddingRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertImageRequest implements channel.Adaptor. +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *common.RelayInfo, request dto.ImageRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertOpenAIRequest implements channel.Adaptor. +func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return convertCozeChatRequest(c, *request), nil +} + +// ConvertOpenAIResponsesRequest implements channel.Adaptor. +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *common.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// ConvertRerankRequest implements channel.Adaptor. +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, errors.New("not implemented") +} + +// DoRequest implements channel.Adaptor. 
+func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) { + if info.IsStream { + return channel.DoApiRequest(a, c, info, requestBody) + } + // 首先发送创建消息请求,成功后再发送获取消息请求 + // 发送创建消息请求 + resp, err := channel.DoApiRequest(a, c, info, requestBody) + if err != nil { + return nil, err + } + // 解析 resp + var cozeResponse CozeChatResponse + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + err = json.Unmarshal(respBody, &cozeResponse) + if err != nil { + return nil, err + } + if cozeResponse.Code != 0 { + return nil, errors.New(cozeResponse.Msg) + } + c.Set("coze_conversation_id", cozeResponse.Data.ConversationId) + c.Set("coze_chat_id", cozeResponse.Data.Id) + // 轮询检查消息是否完成 + for { + err, isComplete := checkIfChatComplete(a, c, info) + if err != nil { + return nil, err + } else { + if isComplete { + break + } + } + time.Sleep(time.Second * 1) + } + // 发送获取消息请求 + return getChatDetail(a, c, info) +} + +// DoResponse implements channel.Adaptor. +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = cozeChatStreamHandler(c, resp, info) + } else { + err, usage = cozeChatHandler(c, resp, info) + } + return +} + +// GetChannelName implements channel.Adaptor. +func (a *Adaptor) GetChannelName() string { + return ChannelName +} + +// GetModelList implements channel.Adaptor. +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +// GetRequestURL implements channel.Adaptor. +func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) { + return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil +} + +// Init implements channel.Adaptor. +func (a *Adaptor) Init(info *common.RelayInfo) { + +} + +// SetupRequestHeader implements channel.Adaptor.
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *common.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Set("Authorization", "Bearer "+info.ApiKey) + return nil +} diff --git a/relay/channel/coze/constants.go b/relay/channel/coze/constants.go new file mode 100644 index 00000000..873ffe24 --- /dev/null +++ b/relay/channel/coze/constants.go @@ -0,0 +1,30 @@ +package coze + +var ModelList = []string{ + "moonshot-v1-8k", + "moonshot-v1-32k", + "moonshot-v1-128k", + "Baichuan4", + "abab6.5s-chat-pro", + "glm-4-0520", + "qwen-max", + "deepseek-r1", + "deepseek-v3", + "deepseek-r1-distill-qwen-32b", + "deepseek-r1-distill-qwen-7b", + "step-1v-8k", + "step-1.5v-mini", + "Doubao-pro-32k", + "Doubao-pro-256k", + "Doubao-lite-128k", + "Doubao-lite-32k", + "Doubao-vision-lite-32k", + "Doubao-vision-pro-32k", + "Doubao-1.5-pro-vision-32k", + "Doubao-1.5-lite-32k", + "Doubao-1.5-pro-32k", + "Doubao-1.5-thinking-pro", + "Doubao-1.5-pro-256k", +} + +var ChannelName = "coze" diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go new file mode 100644 index 00000000..4e9afa23 --- /dev/null +++ b/relay/channel/coze/dto.go @@ -0,0 +1,78 @@ +package coze + +import "encoding/json" + +type CozeError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type CozeEnterMessage struct { + Role string `json:"role"` + Type string `json:"type,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + MetaData json.RawMessage `json:"meta_data,omitempty"` + ContentType string `json:"content_type,omitempty"` +} + +type CozeChatRequest struct { + BotId string `json:"bot_id"` + UserId string `json:"user_id"` + AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"` + Stream bool `json:"stream,omitempty"` + CustomVariables json.RawMessage `json:"custom_variables,omitempty"` + AutoSaveHistory bool `json:"auto_save_history,omitempty"` + MetaData json.RawMessage `json:"meta_data,omitempty"` + 
ExtraParams json.RawMessage `json:"extra_params,omitempty"` + ShortcutCommand json.RawMessage `json:"shortcut_command,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` +} + +type CozeChatResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data CozeChatResponseData `json:"data"` +} + +type CozeChatResponseData struct { + Id string `json:"id"` + ConversationId string `json:"conversation_id"` + BotId string `json:"bot_id"` + CreatedAt int64 `json:"created_at"` + LastError CozeError `json:"last_error"` + Status string `json:"status"` + Usage CozeChatUsage `json:"usage"` +} + +type CozeChatUsage struct { + TokenCount int `json:"token_count"` + OutputCount int `json:"output_count"` + InputCount int `json:"input_count"` +} + +type CozeChatDetailResponse struct { + Data []CozeChatV3MessageDetail `json:"data"` + Code int `json:"code"` + Msg string `json:"msg"` + Detail CozeResponseDetail `json:"detail"` +} + +type CozeChatV3MessageDetail struct { + Id string `json:"id"` + Role string `json:"role"` + Type string `json:"type"` + BotId string `json:"bot_id"` + ChatId string `json:"chat_id"` + Content json.RawMessage `json:"content"` + MetaData json.RawMessage `json:"meta_data"` + CreatedAt int64 `json:"created_at"` + SectionId string `json:"section_id"` + UpdatedAt int64 `json:"updated_at"` + ContentType string `json:"content_type"` + ConversationId string `json:"conversation_id"` + ReasoningContent string `json:"reasoning_content"` +} + +type CozeResponseDetail struct { + Logid string `json:"logid"` +} diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go new file mode 100644 index 00000000..6db40213 --- /dev/null +++ b/relay/channel/coze/relay-coze.go @@ -0,0 +1,300 @@ +package coze + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "strings" + + 
"github.com/gin-gonic/gin" +) + +func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest { + var messages []CozeEnterMessage + // 将 request的messages的role为user的content转换为CozeMessage + for _, message := range request.Messages { + if message.Role == "user" { + messages = append(messages, CozeEnterMessage{ + Role: "user", + Content: message.Content, + // TODO: support more content type + ContentType: "text", + }) + } + } + user := request.User + if user == "" { + user = helper.GetResponseID(c) + } + cozeRequest := &CozeChatRequest{ + BotId: c.GetString("bot_id"), + UserId: user, + AdditionalMessages: messages, + Stream: request.Stream, + } + return cozeRequest +} + +func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + // convert coze response to openai response + var response dto.TextResponse + var cozeResponse CozeChatDetailResponse + response.Model = info.UpstreamModelName + err = json.Unmarshal(responseBody, &cozeResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if cozeResponse.Code != 0 { + return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil + } + // 从上下文获取 usage + var usage dto.Usage + usage.PromptTokens = c.GetInt("coze_input_count") + usage.CompletionTokens = c.GetInt("coze_output_count") + usage.TotalTokens = c.GetInt("coze_token_count") + response.Usage = usage + response.Id = helper.GetResponseID(c) + + var responseContent 
json.RawMessage + for _, data := range cozeResponse.Data { + if data.Type == "answer" { + responseContent = data.Content + response.Created = data.CreatedAt + } + } + // 添加 response.Choices + response.Choices = []dto.OpenAITextResponseChoice{ + { + Index: 0, + Message: dto.Message{Role: "assistant", Content: responseContent}, + FinishReason: "stop", + }, + } + jsonResponse, err := json.Marshal(response) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + + return nil, &usage +} + +func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + helper.SetEventStreamHeaders(c) + id := helper.GetResponseID(c) + var responseText string + + var currentEvent string + var currentData string + var usage dto.Usage + + for scanner.Scan() { + line := scanner.Text() + + if line == "" { + if currentEvent != "" && currentData != "" { + // handle last event + handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + currentEvent = "" + currentData = "" + } + continue + } + + if strings.HasPrefix(line, "event:") { + currentEvent = strings.TrimSpace(line[6:]) + continue + } + + if strings.HasPrefix(line, "data:") { + currentData = strings.TrimSpace(line[5:]) + continue + } + } + + // Last event + if currentEvent != "" && currentData != "" { + handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + } + + if err := scanner.Err(); err != nil { + return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil + } + helper.Done(c) + + if usage.TotalTokens == 0 { + usage.PromptTokens = info.PromptTokens + usage.CompletionTokens, _ = 
service.CountTextToken("gpt-3.5-turbo", responseText) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + + return nil, &usage +} + +func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) { + switch event { + case "conversation.chat.completed": + // 将 data 解析为 CozeChatResponseData + var chatData CozeChatResponseData + err := json.Unmarshal([]byte(data), &chatData) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + usage.PromptTokens = chatData.Usage.InputCount + usage.CompletionTokens = chatData.Usage.OutputCount + usage.TotalTokens = chatData.Usage.TokenCount + + finishReason := "stop" + stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason) + helper.ObjectData(c, stopResponse) + + case "conversation.message.delta": + // 将 data 解析为 CozeChatV3MessageDetail + var messageData CozeChatV3MessageDetail + err := json.Unmarshal([]byte(data), &messageData) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + var content string + err = json.Unmarshal(messageData.Content, &content) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + *responseText += content + + openaiResponse := dto.ChatCompletionsStreamResponse{ + Id: id, + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: info.UpstreamModelName, + } + + choice := dto.ChatCompletionsStreamResponseChoice{ + Index: 0, + } + choice.Delta.SetContentString(content) + openaiResponse.Choices = append(openaiResponse.Choices, choice) + + helper.ObjectData(c, openaiResponse) + + case "error": + var errorData CozeError + err := json.Unmarshal([]byte(data), &errorData) + if err != nil { + common.SysError("error_unmarshalling_stream_response: " + err.Error()) + return + } + + 
common.SysError(fmt.Sprintf("stream event error: %v, %v", errorData.Code, errorData.Message)) + } +} + +func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) { + requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl) + + requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") + // 将 conversationId和chatId作为参数发送get请求 + req, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return err, false + } + err = a.SetupRequestHeader(c, &req.Header, info) + if err != nil { + return err, false + } + + resp, err := doRequest(req, info) // 调用 doRequest + if err != nil { + return err, false + } + if resp == nil { // 确保在 doRequest 失败时 resp 不为 nil 导致 panic + return fmt.Errorf("resp is nil"), false + } + defer resp.Body.Close() // 确保响应体被关闭 + + // 解析 resp 到 CozeChatResponse + var cozeResponse CozeChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read response body failed: %w", err), false + } + err = json.Unmarshal(responseBody, &cozeResponse) + if err != nil { + return fmt.Errorf("unmarshal response body failed: %w", err), false + } + if cozeResponse.Data.Status == "completed" { + // 在上下文设置 usage + c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount) + c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount) + c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount) + return nil, true + } else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" { + return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false + } else { + return nil, false + } +} + +func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) { + requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl) + + requestURL = requestURL + "?conversation_id=" + 
c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") + req, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + return nil, fmt.Errorf("new request failed: %w", err) + } + err = a.SetupRequestHeader(c, &req.Header, info) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + resp, err := doRequest(req, info) + if err != nil { + return nil, fmt.Errorf("do request failed: %w", err) + } + return resp, nil +} + +func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) { + var client *http.Client + var err error // 声明 err 变量 + if proxyURL, ok := info.ChannelSetting["proxy"]; ok { + client, err = service.NewProxyHttpClient(proxyURL.(string)) + if err != nil { + return nil, fmt.Errorf("new proxy http client failed: %w", err) + } + } else { + client = service.GetHttpClient() + } + resp, err := client.Do(req) + if err != nil { // 增加对 client.Do(req) 返回错误的检查 + return nil, fmt.Errorf("client.Do failed: %w", err) + } + // _ = resp.Body.Close() + return resp, nil +} diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index dbe65528..ae9a3b7b 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -391,6 +391,7 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac } // 删除所有的title字段 delete(v, "title") + delete(v, "$schema") // 如果type不为object和array,则直接返回 if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") { return schema diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 7740c498..f0cf073f 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -8,6 +8,7 @@ import ( "io" "mime/multipart" "net/http" + "net/textproto" "one-api/common" constant2 "one-api/constant" "one-api/dto" @@ -25,8 +26,6 @@ import ( "path/filepath" "strings" - "net/textproto" - "github.com/gin-gonic/gin" ) 
@@ -68,9 +67,6 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayFormat == relaycommon.RelayFormatClaude { return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil } - if info.RelayMode == constant.RelayModeResponses { - return fmt.Sprintf("%s/v1/responses", info.BaseUrl), nil - } if info.RelayMode == constant.RelayModeRealtime { if strings.HasPrefix(info.BaseUrl, "https://") { baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") @@ -93,7 +89,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) task := strings.TrimPrefix(requestURL, "/v1/") model_ := info.UpstreamModelName - model_ = strings.Replace(model_, ".", "", -1) + // 2025年5月10日后创建的渠道不移除. + if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime { + model_ = strings.Replace(model_, ".", "", -1) + } // https://github.com/songquanpeng/one-api/issues/67 requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) if info.RelayMode == constant.RelayModeRealtime { @@ -173,7 +172,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn info.UpstreamModelName = request.Model // o系列模型developer适配(o1-mini除外) - if !strings.HasPrefix(request.Model, "o1-mini") { + if !strings.HasPrefix(request.Model, "o1-mini") && !strings.HasPrefix(request.Model, "o1-preview") { //修改第一个Message的内容,将system改为developer if len(request.Messages) > 0 && request.Messages[0].Role == "system" { request.Messages[0].Role = "developer" @@ -429,7 +428,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = OaiResponsesStreamHandler(c, resp, info) } else { - err, usage = OpenaiResponsesHandler(c, resp, info) + err, usage = OaiResponsesHandler(c, resp, info) } default: if info.IsStream { diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index e7ba2e7b..a068c544 100644 --- 
a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -187,3 +187,10 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream } } } + +func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) { + if data == "" { + return + } + helper.ResponseChunkData(c, streamResponse, data) +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index ef660564..86c47a15 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -215,10 +215,35 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI StatusCode: resp.StatusCode, }, nil } + + forceFormat := false + if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { + forceFormat = forceFmt + } + + if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { + completionTokens := 0 + for _, choice := range simpleResponse.Choices { + ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) + completionTokens += ctkm + } + simpleResponse.Usage = dto.Usage{ + PromptTokens: info.PromptTokens, + CompletionTokens: completionTokens, + TotalTokens: info.PromptTokens + completionTokens, + } + } switch info.RelayFormat { case relaycommon.RelayFormatOpenAI: - break + if forceFormat { + responseBody, err = json.Marshal(simpleResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + } else { + break + } case relaycommon.RelayFormatClaude: claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info) claudeRespStr, err := json.Marshal(claudeResp) @@ -244,18 +269,6 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI common.SysError("error copying response 
body: " + err.Error()) } resp.Body.Close() - if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { - completionTokens := 0 - for _, choice := range simpleResponse.Choices { - ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) - completionTokens += ctkm - } - simpleResponse.Usage = dto.Usage{ - PromptTokens: info.PromptTokens, - CompletionTokens: completionTokens, - TotalTokens: info.PromptTokens + completionTokens, - } - } return nil, &simpleResponse.Usage } @@ -644,102 +657,3 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm } return nil, &usageResp.Usage } - -func OpenaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - // read response body - var responsesResponse dto.OpenAIResponsesResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = common.DecodeJson(responseBody, &responsesResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil - } - if responsesResponse.Error != nil { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: responsesResponse.Error.Message, - Type: "openai_error", - Code: responsesResponse.Error.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } - - // reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. 
- // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - // copy response body - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - common.SysError("error copying response body: " + err.Error()) - } - resp.Body.Close() - // compute usage - usage := dto.Usage{} - usage.PromptTokens = responsesResponse.Usage.InputTokens - usage.CompletionTokens = responsesResponse.Usage.OutputTokens - usage.TotalTokens = responsesResponse.Usage.TotalTokens - return nil, &usage -} - -func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - if resp == nil || resp.Body == nil { - common.LogError(c, "invalid response or response body") - return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil - } - - var usage = &dto.Usage{} - var responseTextBuilder strings.Builder - - helper.StreamScannerHandler(c, resp, info, func(data string) bool { - - // 检查当前数据是否包含 completed 状态和 usage 信息 - var streamResponse dto.ResponsesStreamResponse - if err := common.DecodeJsonStr(data, &streamResponse); err == nil { - sendResponsesStreamData(c, streamResponse, data) - switch streamResponse.Type { - case "response.completed": - usage.PromptTokens = streamResponse.Response.Usage.InputTokens - usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens - usage.TotalTokens = streamResponse.Response.Usage.TotalTokens - case "response.output_text.delta": - // 处理输出文本 - responseTextBuilder.WriteString(streamResponse.Delta) - - } - } - return true - }) - - if usage.CompletionTokens == 0 { - // 计算输出文本的 token 数量 - tempStr := responseTextBuilder.String() - 
if len(tempStr) > 0 { - // 非正常结束,使用输出文本的 token 数量 - completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName) - usage.CompletionTokens = completionTokens - } - } - - return nil, usage -} - -func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) { - if data == "" { - return - } - helper.ResponseChunkData(c, streamResponse, data) -} diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go new file mode 100644 index 00000000..1d1e060e --- /dev/null +++ b/relay/channel/openai/relay_responses.go @@ -0,0 +1,119 @@ +package openai + +import ( + "bytes" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "strings" + + "github.com/gin-gonic/gin" +) + +func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + // read response body + var responsesResponse dto.OpenAIResponsesResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = common.DecodeJson(responseBody, &responsesResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if responsesResponse.Error != nil { + return &dto.OpenAIErrorWithStatusCode{ + Error: dto.OpenAIError{ + Message: responsesResponse.Error.Message, + Type: "openai_error", + Code: responsesResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + // reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response 
body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + // copy response body + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + common.SysError("error copying response body: " + err.Error()) + } + resp.Body.Close() + // compute usage + usage := dto.Usage{} + usage.PromptTokens = responsesResponse.Usage.InputTokens + usage.CompletionTokens = responsesResponse.Usage.OutputTokens + usage.TotalTokens = responsesResponse.Usage.TotalTokens + // 解析 Tools 用量(仅统计请求中注册过的工具,避免对缺失条目的 nil 指针解引用) + for _, tool := range responsesResponse.Tools { + if builtInTool, ok := info.ResponsesUsageInfo.BuiltInTools[tool.Type]; ok { builtInTool.CallCount++ } + } + return nil, &usage +} + +func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + if resp == nil || resp.Body == nil { + common.LogError(c, "invalid response or response body") + return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil + } + + var usage = &dto.Usage{} + var responseTextBuilder strings.Builder + + helper.StreamScannerHandler(c, resp, info, func(data string) bool { + + // 检查当前数据是否包含 completed 状态和 usage 信息 + var streamResponse dto.ResponsesStreamResponse + if err := common.DecodeJsonStr(data, &streamResponse); err == nil { + sendResponsesStreamData(c, streamResponse, data) + switch streamResponse.Type { + case "response.completed": + usage.PromptTokens = streamResponse.Response.Usage.InputTokens + usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens + usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + case "response.output_text.delta": + // 处理输出文本 + 
responseTextBuilder.WriteString(streamResponse.Delta) + case dto.ResponsesOutputTypeItemDone: + // 函数调用处理 + if streamResponse.Item != nil { + switch streamResponse.Item.Type { + case dto.BuildInCallWebSearchCall: + info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview].CallCount++ + } + } + } + } + return true + }) + + if usage.CompletionTokens == 0 { + // 计算输出文本的 token 数量 + tempStr := responseTextBuilder.String() + if len(tempStr) > 0 { + // 非正常结束,使用输出文本的 token 数量 + completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName) + usage.CompletionTokens = completionTokens + } + } + + return nil, usage +} diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index c1b64f11..7daf9a61 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -11,8 +11,8 @@ import ( "one-api/relay/channel/claude" "one-api/relay/channel/gemini" "one-api/relay/channel/openai" - "one-api/setting/model_setting" relaycommon "one-api/relay/common" + "one-api/setting/model_setting" "strings" "github.com/gin-gonic/gin" diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index 12634c84..b5896415 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -2,14 +2,16 @@ package xai import ( "errors" - "fmt" "io" "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "strings" + "one-api/relay/constant" + "github.com/gin-gonic/gin" ) @@ -28,15 +30,20 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - request.Size = "" - return request, nil + xaiRequest := ImageRequest{ + Model: request.Model, + Prompt: request.Prompt, + N: request.N, + ResponseFormat: request.ResponseFormat, + } + return xaiRequest, nil } func (a *Adaptor) Init(info 
*relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -89,15 +96,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = xAIStreamHandler(c, resp, info) - } else { - err, usage = xAIHandler(c, resp, info) + switch info.RelayMode { + case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: + err, usage = openai.OpenaiHandlerWithUsage(c, resp, info) + default: + if info.IsStream { + err, usage = xAIStreamHandler(c, resp, info) + } else { + err, usage = xAIHandler(c, resp, info) + } } - //if _, ok := usage.(*dto.Usage); ok && usage != nil { - // usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens - //} - return } diff --git a/relay/channel/xai/dto.go b/relay/channel/xai/dto.go index 7036d5f1..b8098475 100644 --- a/relay/channel/xai/dto.go +++ b/relay/channel/xai/dto.go @@ -12,3 +12,16 @@ type ChatCompletionResponse struct { Usage *dto.Usage `json:"usage"` SystemFingerprint string `json:"system_fingerprint"` } + +// quality, size or style are not supported by xAI API at the moment. 
+type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n,omitempty"` + // Size string `json:"size,omitempty"` + // Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + // Style string `json:"style,omitempty"` + // User string `json:"user,omitempty"` + // ExtraFields json.RawMessage `json:"extra_fields,omitempty"` +} \ No newline at end of file diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 915474e1..f4fc3c1e 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -36,6 +36,7 @@ type ClaudeConvertInfo struct { const ( RelayFormatOpenAI = "openai" RelayFormatClaude = "claude" + RelayFormatGemini = "gemini" ) type RerankerInfo struct { @@ -43,6 +44,16 @@ type RerankerInfo struct { ReturnDocuments bool } +type BuildInToolInfo struct { + ToolName string + CallCount int + SearchContextSize string +} + +type ResponsesUsageInfo struct { + BuiltInTools map[string]*BuildInToolInfo +} + type RelayInfo struct { ChannelType int ChannelId int @@ -87,9 +98,11 @@ type RelayInfo struct { UserQuota int RelayFormat string SendResponseCount int + ChannelCreateTime int64 ThinkingContentInfo *ClaudeConvertInfo *RerankerInfo + *ResponsesUsageInfo } // 定义支持流式选项的通道类型 @@ -103,6 +116,8 @@ var streamSupportedChannels = map[int]bool{ common.ChannelTypeVolcEngine: true, common.ChannelTypeOllama: true, common.ChannelTypeXai: true, + common.ChannelTypeDeepSeek: true, + common.ChannelTypeBaiduV2: true, } func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { @@ -134,6 +149,31 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { return info } +func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { + info := GenRelayInfo(c) + info.RelayMode = relayconstant.RelayModeResponses + info.ResponsesUsageInfo = &ResponsesUsageInfo{ + BuiltInTools: 
make(map[string]*BuildInToolInfo), + } + if len(req.Tools) > 0 { + for _, tool := range req.Tools { + info.ResponsesUsageInfo.BuiltInTools[tool.Type] = &BuildInToolInfo{ + ToolName: tool.Type, + CallCount: 0, + } + switch tool.Type { + case dto.BuildInToolWebSearchPreview: + if tool.SearchContextSize == "" { + tool.SearchContextSize = "medium" + } + info.ResponsesUsageInfo.BuiltInTools[tool.Type].SearchContextSize = tool.SearchContextSize + } + } + } + info.IsStream = req.Stream + return info +} + func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") @@ -170,14 +210,15 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { OriginModelName: c.GetString("original_model"), UpstreamModelName: c.GetString("original_model"), //RecodeModelName: c.GetString("original_model"), - IsModelMapped: false, - ApiType: apiType, - ApiVersion: c.GetString("api_version"), - ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), - Organization: c.GetString("channel_organization"), - ChannelSetting: channelSetting, - ParamOverride: paramOverride, - RelayFormat: RelayFormatOpenAI, + IsModelMapped: false, + ApiType: apiType, + ApiVersion: c.GetString("api_version"), + ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + Organization: c.GetString("channel_organization"), + ChannelSetting: channelSetting, + ChannelCreateTime: c.GetInt64("channel_create_time"), + ParamOverride: paramOverride, + RelayFormat: RelayFormatOpenAI, ThinkingContentInfo: ThinkingContentInfo{ IsFirstThinkingContent: true, SendLastThinkingContent: false, diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index fef38f23..3f1ecd78 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -33,6 +33,7 @@ const ( APITypeOpenRouter APITypeXinference APITypeXai + APITypeCoze APITypeDummy // this one is only for count, do not add any channel after this ) @@ -95,6 +96,8 @@ func 
ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeXinference case common.ChannelTypeXai: apiType = APITypeXai + case common.ChannelTypeCoze: + apiType = APITypeCoze } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/helper/common.go b/relay/helper/common.go index 6a8ca2d7..35d983f7 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -12,11 +12,19 @@ import ( ) func SetEventStreamHeaders(c *gin.Context) { - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("Transfer-Encoding", "chunked") - c.Writer.Header().Set("X-Accel-Buffering", "no") + // 检查是否已经设置过头部 + if _, exists := c.Get("event_stream_headers_set"); exists { + return + } + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") + + // 设置标志,表示头部已经设置过 + c.Set("event_stream_headers_set", true) } func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error { @@ -37,7 +45,7 @@ func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error { func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) { c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) - c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)}) + c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)}) if flusher, ok := c.Writer.(http.Flusher); ok { flusher.Flush() } @@ -45,7 +53,7 @@ func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) { func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) { c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) - c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: 
%s\n", data)}) + c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)}) if flusher, ok := c.Writer.(http.Flusher); ok { flusher.Flush() } diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go index 948c5226..9bf67c03 100644 --- a/relay/helper/model_mapped.go +++ b/relay/helper/model_mapped.go @@ -2,9 +2,11 @@ package helper import ( "encoding/json" + "errors" "fmt" - "github.com/gin-gonic/gin" "one-api/relay/common" + + "github.com/gin-gonic/gin" ) func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { @@ -16,9 +18,36 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { if err != nil { return fmt.Errorf("unmarshal_model_mapping_failed") } - if modelMap[info.OriginModelName] != "" { - info.UpstreamModelName = modelMap[info.OriginModelName] - info.IsModelMapped = true + + // 支持链式模型重定向,最终使用链尾的模型 + currentModel := info.OriginModelName + visitedModels := map[string]bool{ + currentModel: true, + } + for { + if mappedModel, exists := modelMap[currentModel]; exists && mappedModel != "" { + // 模型重定向循环检测,避免无限循环 + if visitedModels[mappedModel] { + if mappedModel == currentModel { + if currentModel == info.OriginModelName { + info.IsModelMapped = false + return nil + } else { + info.IsModelMapped = true + break + } + } + return errors.New("model_mapping_contains_cycle") + } + visitedModels[mappedModel] = true + currentModel = mappedModel + info.IsModelMapped = true + } else { + break + } + } + if info.IsModelMapped { + info.UpstreamModelName = currentModel } } return nil diff --git a/relay/helper/price.go b/relay/helper/price.go index 899c72b9..89efa1da 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -23,7 +23,7 @@ type PriceData struct { } func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %d", p.ModelPrice, 
p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) } func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index 2738ce2a..c1bc0d6e 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -3,7 +3,6 @@ package helper import ( "bufio" "context" - "github.com/bytedance/gopkg/util/gopool" "io" "net/http" "one-api/common" @@ -14,6 +13,8 @@ import ( "sync" "time" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" ) diff --git a/relay/relay-image.go b/relay/relay-image.go index 70219cc1..daed3d80 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -49,11 +49,11 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. 
// Not "256x256", "512x512", or "1024x1024" if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e") } } else if imageRequest.Model == "dall-e-3" { if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { - return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") + return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3") } if imageRequest.Quality == "" { imageRequest.Quality = "standard" diff --git a/relay/relay-responses.go b/relay/relay-responses.go index cdb37ae7..fd3ddb5a 100644 --- a/relay/relay-responses.go +++ b/relay/relay-responses.go @@ -19,7 +19,7 @@ import ( "github.com/gin-gonic/gin" ) -func getAndValidateResponsesRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.OpenAIResponsesRequest, error) { +func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) { request := &dto.OpenAIResponsesRequest{} err := common.UnmarshalBodyReusable(c, request) if err != nil { @@ -31,13 +31,11 @@ func getAndValidateResponsesRequest(c *gin.Context, relayInfo *relaycommon.Relay if len(request.Input) == 0 { return nil, errors.New("input is required") } - relayInfo.IsStream = request.Stream return request, nil } func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) { - sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input) return sensitiveWords, err } @@ -49,12 +47,14 @@ func getInputTokens(req *dto.OpenAIResponsesRequest, info 
*relaycommon.RelayInfo } func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { - relayInfo := relaycommon.GenRelayInfo(c) - req, err := getAndValidateResponsesRequest(c, relayInfo) + req, err := getAndValidateResponsesRequest(c) if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest) } + + relayInfo := relaycommon.GenRelayInfoResponses(c, req) + if setting.ShouldCheckPromptSensitive() { sensitiveWords, err := checkInputSensitive(req, relayInfo) if err != nil { diff --git a/relay/relay-text.go b/relay/relay-text.go index 4fdd435d..8d5cd384 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -18,6 +18,7 @@ import ( "one-api/service" "one-api/setting" "one-api/setting/model_setting" + "one-api/setting/operation_setting" "strings" "time" @@ -193,6 +194,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { var httpResp *http.Response resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } @@ -358,6 +360,34 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, ratio := dModelRatio.Mul(dGroupRatio) + // openai web search 工具计费 + var dWebSearchQuota decimal.Decimal + var webSearchPrice float64 + if relayInfo.ResponsesUsageInfo != nil { + if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 { + // 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率) + webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize) + dWebSearchQuota = decimal.NewFromFloat(webSearchPrice). + Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))). 
+ Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) + extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 $%s", + webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()) + } + } + // file search tool 计费 + var dFileSearchQuota decimal.Decimal + var fileSearchPrice float64 + if relayInfo.ResponsesUsageInfo != nil { + if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 { + fileSearchPrice = operation_setting.GetFileSearchPricePerThousand() + dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice). + Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))). + Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) + extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 $%s", + fileSearchTool.CallCount, dFileSearchQuota.String()) + } + } + var quotaCalculateDecimal decimal.Decimal if !priceData.UsePrice { nonCachedTokens := dPromptTokens.Sub(dCacheTokens) @@ -380,6 +410,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } else { quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio) } + // 添加 responses tools call 调用的配额 + quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota) + quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota) quota := int(quotaCalculateDecimal.Round(0).IntPart()) totalTokens := promptTokens + completionTokens @@ -430,6 +463,20 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other["image_ratio"] = imageRatio other["image_output"] = imageTokens } + if !dWebSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil { + if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists { + other["web_search"] = true + other["web_search_call_count"] = webSearchTool.CallCount + other["web_search_price"] = webSearchPrice + } + } + if !dFileSearchQuota.IsZero() && 
relayInfo.ResponsesUsageInfo != nil { + if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists { + other["file_search"] = true + other["file_search_call_count"] = fileSearchTool.CallCount + other["file_search_price"] = fileSearchPrice + } + } model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) } diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 8b4afcb3..7bf0da9f 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -10,6 +10,7 @@ import ( "one-api/relay/channel/claude" "one-api/relay/channel/cloudflare" "one-api/relay/channel/cohere" + "one-api/relay/channel/coze" "one-api/relay/channel/deepseek" "one-api/relay/channel/dify" "one-api/relay/channel/gemini" @@ -88,6 +89,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &openai.Adaptor{} case constant.APITypeXai: return &xai.Adaptor{} + case constant.APITypeCoze: + return &coze.Adaptor{} } return nil } diff --git a/service/cf_worker.go b/service/cf_worker.go index 40a1e294..ae6e1ffe 100644 --- a/service/cf_worker.go +++ b/service/cf_worker.go @@ -24,7 +24,7 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { if !setting.EnableWorker() { return nil, fmt.Errorf("worker not enabled") } - if !strings.HasPrefix(req.URL, "https") { + if !setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") { return nil, fmt.Errorf("only support https url") } diff --git a/service/http_client.go b/service/http_client.go index c3f8df7a..64a361cf 100644 --- a/service/http_client.go +++ b/service/http_client.go @@ -3,12 +3,13 @@ package service import ( "context" "fmt" - "golang.org/x/net/proxy" "net" "net/http" "net/url" "one-api/common" "time" + + "golang.org/x/net/proxy" ) var httpClient *http.Client @@ -55,7 +56,7 @@ func 
NewProxyHttpClient(proxyURL string) (*http.Client, error) { }, }, nil - case "socks5": + case "socks5", "socks5h": // 获取认证信息 var auth *proxy.Auth if parsedURL.User != nil { @@ -69,6 +70,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) { } // 创建 SOCKS5 代理拨号器 + // proxy.SOCKS5 使用 tcp 参数,所有 TCP 连接包括 DNS 查询都将通过代理进行。行为与 socks5h 相同 dialer, err := proxy.SOCKS5("tcp", parsedURL.Host, auth, proxy.Direct) if err != nil { return nil, err diff --git a/service/token_counter.go b/service/token_counter.go index 21b882af..d63b54ad 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -120,11 +120,12 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m var config image.Config var err error var format string + var b64str string if strings.HasPrefix(imageUrl.Url, "http") { config, format, err = DecodeUrlImageData(imageUrl.Url) } else { common.SysLog(fmt.Sprintf("decoding image")) - config, format, _, err = DecodeBase64ImageData(imageUrl.Url) + config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url) } if err != nil { return 0, err @@ -132,7 +133,12 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m imageUrl.MimeType = format if config.Width == 0 || config.Height == 0 { - return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url)) + // not an image + if format != "" && b64str != "" { + // file type + return 3 * baseTokens, nil + } + return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url)) } shortSide := config.Width diff --git a/setting/operation_setting/tools.go b/setting/operation_setting/tools.go new file mode 100644 index 00000000..974c4ed2 --- /dev/null +++ b/setting/operation_setting/tools.go @@ -0,0 +1,57 @@ +package operation_setting + +import "strings" + +const ( + // Web search + WebSearchHighTierModelPriceLow = 30.00 + WebSearchHighTierModelPriceMedium = 35.00 + WebSearchHighTierModelPriceHigh = 50.00 + 
WebSearchPriceLow = 25.00 + WebSearchPriceMedium = 27.50 + WebSearchPriceHigh = 30.00 + // File search + FileSearchPrice = 2.5 +) + +func GetWebSearchPricePerThousand(modelName string, contextSize string) float64 { + // 确定模型类型 + // https://platform.openai.com/docs/pricing Web search 价格按模型类型和 search context size 收费 + // gpt-4.1, gpt-4o, or gpt-4o-search-preview 更贵,gpt-4.1-mini, gpt-4o-mini, gpt-4o-mini-search-preview 更便宜 + isHighTierModel := (strings.HasPrefix(modelName, "gpt-4.1") || strings.HasPrefix(modelName, "gpt-4o")) && + !strings.Contains(modelName, "mini") + // 确定 search context size 对应的价格 + var priceWebSearchPerThousandCalls float64 + switch contextSize { + case "low": + if isHighTierModel { + priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceLow + } else { + priceWebSearchPerThousandCalls = WebSearchPriceLow + } + case "medium": + if isHighTierModel { + priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceMedium + } else { + priceWebSearchPerThousandCalls = WebSearchPriceMedium + } + case "high": + if isHighTierModel { + priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceHigh + } else { + priceWebSearchPerThousandCalls = WebSearchPriceHigh + } + default: + // search context size 默认为 medium + if isHighTierModel { + priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceMedium + } else { + priceWebSearchPerThousandCalls = WebSearchPriceMedium + } + } + return priceWebSearchPerThousandCalls +} + +func GetFileSearchPricePerThousand() float64 { + return FileSearchPrice +} diff --git a/setting/rate_limit.go b/setting/rate_limit.go index 4b216948..53b53f88 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -1,6 +1,64 @@ package setting +import ( + "encoding/json" + "fmt" + "one-api/common" + "sync" +) + var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitSuccessCount = 1000 +var ModelRequestRateLimitGroup = 
map[string][2]int{} +var ModelRequestRateLimitMutex sync.RWMutex + +func ModelRequestRateLimitGroup2JSONString() string { + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() + + jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup) + if err != nil { + common.SysError("error marshalling model ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateModelRequestRateLimitGroupByJSONString(jsonStr string) error { + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() + + ModelRequestRateLimitGroup = make(map[string][2]int) + return json.Unmarshal([]byte(jsonStr), &ModelRequestRateLimitGroup) +} + +func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { + ModelRequestRateLimitMutex.RLock() + defer ModelRequestRateLimitMutex.RUnlock() + + if ModelRequestRateLimitGroup == nil { + return 0, 0, false + } + + limits, found := ModelRequestRateLimitGroup[group] + if !found { + return 0, 0, false + } + return limits[0], limits[1], true +} + +func CheckModelRequestRateLimitGroup(jsonStr string) error { + checkModelRequestRateLimitGroup := make(map[string][2]int) + err := json.Unmarshal([]byte(jsonStr), &checkModelRequestRateLimitGroup) + if err != nil { + return err + } + for group, limits := range checkModelRequestRateLimitGroup { + if limits[0] < 0 || limits[1] < 1 { + return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1]) + } + } + + return nil +} diff --git a/setting/system_setting.go b/setting/system_setting.go index 15017d3d..c37a6123 100644 --- a/setting/system_setting.go +++ b/setting/system_setting.go @@ -3,6 +3,7 @@ package setting var ServerAddress = "http://localhost:3000" var WorkerUrl = "" var WorkerValidKey = "" +var WorkerAllowHttpImageRequestEnabled = false func EnableWorker() bool { return WorkerUrl != "" diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index 903677eb..6cf7e844 100644 
--- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -618,7 +618,6 @@ const LogsTable = () => { ); } - let content = other?.claude ? renderClaudeModelPriceSimple( other.model_ratio, @@ -935,6 +934,13 @@ const LogsTable = () => { other.model_price, other.group_ratio, other?.user_group_ratio, + false, + 1.0, + undefined, + other.web_search || false, + other.web_search_call_count || 0, + other.file_search || false, + other.file_search_call_count || 0, ), }); } @@ -995,6 +1001,12 @@ const LogsTable = () => { other?.image || false, other?.image_ratio || 0, other?.image_output || 0, + other?.web_search || false, + other?.web_search_call_count || 0, + other?.web_search_price || 0, + other?.file_search || false, + other?.file_search_call_count || 0, + other?.file_search_price || 0, ); } expandDataLocal.push({ diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js index d1e03db2..0f52c319 100644 --- a/web/src/components/PersonalSetting.js +++ b/web/src/components/PersonalSetting.js @@ -57,6 +57,7 @@ const PersonalSetting = () => { email_verification_code: '', email: '', self_account_deletion_confirmation: '', + original_password: '', set_new_password: '', set_new_password_confirmation: '', }); @@ -239,11 +240,24 @@ const PersonalSetting = () => { }; const changePassword = async () => { + if (inputs.original_password === '') { + showError(t('请输入原密码!')); + return; + } + if (inputs.set_new_password === '') { + showError(t('请输入新密码!')); + return; + } + if (inputs.original_password === inputs.set_new_password) { + showError(t('新密码需要和原密码不一致!')); + return; + } if (inputs.set_new_password !== inputs.set_new_password_confirmation) { showError(t('两次输入的密码不一致!')); return; } const res = await API.put(`/api/user/self`, { + original_password: inputs.original_password, password: inputs.set_new_password, }); const { success, message } = res.data; @@ -816,8 +830,8 @@ const PersonalSetting = () => { - - + +
{t('通知方式')}
@@ -993,23 +1007,36 @@ const PersonalSetting = () => {
- +
- {t('接受未设置价格模型')} + + {t('接受未设置价格模型')} +
handleNotificationSettingChange('acceptUnsetModelRatioModel', e.target.checked)} + checked={ + notificationSettings.acceptUnsetModelRatioModel + } + onChange={(e) => + handleNotificationSettingChange( + 'acceptUnsetModelRatioModel', + e.target.checked, + ) + } > {t('接受未设置价格模型')} - - {t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')} + + {t( + '当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用', + )}
-
@@ -799,7 +812,13 @@ const SystemSetting = () => { onChange={(value) => setEmailToAdd(value)} style={{ marginTop: 16 }} suffix={ - + } onEnterPress={handleAddEmail} /> diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index fa59bcce..054da535 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -118,6 +118,11 @@ export const CHANNEL_OPTIONS = [ { value: 48, color: 'blue', - label: 'xAI' - } + label: 'xAI', + }, + { + value: 49, + color: 'blue', + label: 'Coze', + }, ]; diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index 7b80da6f..5a59356b 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -317,6 +317,12 @@ export function renderModelPrice( image = false, imageRatio = 1.0, imageOutputTokens = 0, + webSearch = false, + webSearchCallCount = 0, + webSearchPrice = 0, + fileSearch = false, + fileSearchCallCount = 0, + fileSearchPrice = 0, ) { if (modelPrice !== -1) { return i18next.t( @@ -339,14 +345,17 @@ export function renderModelPrice( // Calculate effective input tokens (non-cached + cached with ratio applied) let effectiveInputTokens = inputTokens - cacheTokens + cacheTokens * cacheRatio; -// Handle image tokens if present + // Handle image tokens if present if (image && imageOutputTokens > 0) { - effectiveInputTokens = inputTokens - imageOutputTokens + imageOutputTokens * imageRatio; + effectiveInputTokens = + inputTokens - imageOutputTokens + imageOutputTokens * imageRatio; } let price = (effectiveInputTokens / 1000000) * inputRatioPrice * groupRatio + - (completionTokens / 1000000) * completionRatioPrice * groupRatio; + (completionTokens / 1000000) * completionRatioPrice * groupRatio + + (webSearchCallCount / 1000) * webSearchPrice * groupRatio + + (fileSearchCallCount / 1000) * fileSearchPrice * groupRatio; return ( <> @@ -391,9 +400,23 @@ export function renderModelPrice( )}

)} + {webSearch && webSearchCallCount > 0 && ( +

+ {i18next.t('Web搜索价格:${{price}} / 1K 次', { + price: webSearchPrice, + })} +

+ )} + {fileSearch && fileSearchCallCount > 0 && ( +

+ {i18next.t('文件搜索价格:${{price}} / 1K 次', { + price: fileSearchPrice, + })} +

+ )}

- {cacheTokens > 0 && !image + {cacheTokens > 0 && !image && !webSearch && !fileSearch ? i18next.t( '输入 {{nonCacheInput}} tokens / 1M tokens * ${{price}} + 缓存 {{cacheInput}} tokens / 1M tokens * ${{cachePrice}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', { @@ -407,31 +430,82 @@ export function renderModelPrice( total: price.toFixed(6), }, ) - : image && imageOutputTokens > 0 - ? i18next.t( - '输入 {{nonImageInput}} tokens + 图片输入 {{imageInput}} tokens * {{imageRatio}} / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', - { - nonImageInput: inputTokens - imageOutputTokens, - imageInput: imageOutputTokens, - imageRatio: imageRatio, - price: inputRatioPrice, - completion: completionTokens, - compPrice: completionRatioPrice, - ratio: groupRatio, - total: price.toFixed(6), - }, - ) - : i18next.t( - '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', - { - input: inputTokens, - price: inputRatioPrice, - completion: completionTokens, - compPrice: completionRatioPrice, - ratio: groupRatio, - total: price.toFixed(6), - }, - )} + : image && imageOutputTokens > 0 && !webSearch && !fileSearch + ? i18next.t( + '输入 {{nonImageInput}} tokens + 图片输入 {{imageInput}} tokens * {{imageRatio}} / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', + { + nonImageInput: inputTokens - imageOutputTokens, + imageInput: imageOutputTokens, + imageRatio: imageRatio, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + total: price.toFixed(6), + }, + ) + : webSearch && webSearchCallCount > 0 && !image && !fileSearch + ? 
i18next.t( + '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + Web搜索 {{webSearchCallCount}}次 / 1K 次 * ${{webSearchPrice}} * {{ratio}} = ${{total}}', + { + input: inputTokens, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + webSearchCallCount, + webSearchPrice, + total: price.toFixed(6), + }, + ) + : fileSearch && + fileSearchCallCount > 0 && + !image && + !webSearch + ? i18next.t( + '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + 文件搜索 {{fileSearchCallCount}}次 / 1K 次 * ${{fileSearchPrice}} * {{ratio}}= ${{total}}', + { + input: inputTokens, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + fileSearchCallCount, + fileSearchPrice, + total: price.toFixed(6), + }, + ) + : webSearch && + webSearchCallCount > 0 && + fileSearch && + fileSearchCallCount > 0 && + !image + ? i18next.t( + '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + Web搜索 {{webSearchCallCount}}次 / 1K 次 * ${{webSearchPrice}} * {{ratio}}+ 文件搜索 {{fileSearchCallCount}}次 / 1K 次 * ${{fileSearchPrice}} * {{ratio}}= ${{total}}', + { + input: inputTokens, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + webSearchCallCount, + webSearchPrice, + fileSearchCallCount, + fileSearchPrice, + total: price.toFixed(6), + }, + ) + : i18next.t( + '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', + { + input: inputTokens, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + total: price.toFixed(6), + }, + )}

{i18next.t('仅供参考,以实际扣费为准')}

@@ -448,33 +522,56 @@ export function renderLogContent( user_group_ratio, image = false, imageRatio = 1.0, - useUserGroupRatio = undefined + useUserGroupRatio = undefined, + webSearch = false, + webSearchCallCount = 0, + fileSearch = false, + fileSearchCallCount = 0, ) { - const ratioLabel = useUserGroupRatio ? i18next.t('专属倍率') : i18next.t('分组倍率'); + const ratioLabel = useUserGroupRatio + ? i18next.t('专属倍率') + : i18next.t('分组倍率'); const ratio = useUserGroupRatio ? user_group_ratio : groupRatio; if (modelPrice !== -1) { return i18next.t('模型价格 ${{price}},{{ratioType}} {{ratio}}', { price: modelPrice, ratioType: ratioLabel, - ratio + ratio, }); } else { if (image) { - return i18next.t('模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},图片输入倍率 {{imageRatio}},{{ratioType}} {{ratio}}', { - modelRatio: modelRatio, - completionRatio: completionRatio, - imageRatio: imageRatio, - ratioType: ratioLabel, - ratio - }); + return i18next.t( + '模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},图片输入倍率 {{imageRatio}},{{ratioType}} {{ratio}}', + { + modelRatio: modelRatio, + completionRatio: completionRatio, + imageRatio: imageRatio, + ratioType: ratioLabel, + ratio, + }, + ); + } else if (webSearch) { + return i18next.t( + '模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},{{ratioType}} {{ratio}},Web 搜索调用 {{webSearchCallCount}} 次', + { + modelRatio: modelRatio, + completionRatio: completionRatio, + ratioType: ratioLabel, + ratio, + webSearchCallCount, + }, + ); } else { - return i18next.t('模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},{{ratioType}} {{ratio}}', { - modelRatio: modelRatio, - completionRatio: completionRatio, - ratioType: ratioLabel, - ratio - }); + return i18next.t( + '模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},{{ratioType}} {{ratio}}', + { + modelRatio: modelRatio, + completionRatio: completionRatio, + ratioType: ratioLabel, + ratio, + }, + ); } } } diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index e9975f61..916329e7 100644 --- 
a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -493,6 +493,7 @@ "默认": "default", "图片演示": "Image demo", "注意,系统请求的时模型名称中的点会被剔除,例如:gpt-4.1会请求为gpt-41,所以在Azure部署的时候,部署模型名称需要手动改为gpt-41": "Note that the dot in the model name requested by the system will be removed, for example: gpt-4.1 will be requested as gpt-41, so when deploying on Azure, the deployment model name needs to be manually changed to gpt-41", + "2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的\".\"": "After May 10, 2025, channels added do not need to remove the dot in the model name during deployment", "模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!", "取消无限额度": "Cancel unlimited quota", "取消": "Cancel", @@ -1085,7 +1086,7 @@ "没有账户?": "No account? ", "请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com": "Please enter AZURE_OPENAI_ENDPOINT, e.g.: https://docs-test-001.openai.azure.com", "默认 API 版本": "Default API Version", - "请输入默认 API 版本,例如:2024-12-01-preview": "Please enter default API version, e.g.: 2024-12-01-preview.", + "请输入默认 API 版本,例如:2025-04-01-preview": "Please enter default API version, e.g.: 2025-04-01-preview.", "请为渠道命名": "Please name the channel", "请选择可以使用该渠道的分组": "Please select groups that can use this channel", "请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit Group ratios in system settings to add new groups:", @@ -1373,4 +1374,4 @@ "适用于展示系统功能的场景。": "Suitable for scenarios where the system functions are displayed.", "可在初始化后修改": "Can be modified after initialization", "初始化系统": "Initialize system" -} +} \ No newline at end of file diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index a793e149..f7fab057 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -24,7 +24,8 @@ import { TextArea, Checkbox, Banner, - Modal, ImagePreview + Modal, + ImagePreview, } from '@douyinfe/semi-ui'; import { getChannelModels, loadChannelModels } from '../../components/utils.js'; import { 
IconHelpCircle } from '@douyinfe/semi-icons'; @@ -306,7 +307,7 @@ const EditChannel = (props) => { fetchModels().then(); fetchGroups().then(); if (isEdit) { - loadChannel().then(() => { }); + loadChannel().then(() => {}); } else { setInputs(originInputs); let localModels = getChannelModels(inputs.type); @@ -477,24 +478,26 @@ const EditChannel = (props) => { type={'warning'} description={ <> - {t('注意,系统请求的时模型名称中的点会被剔除,例如:gpt-4.1会请求为gpt-41,所以在Azure部署的时候,部署模型名称需要手动改为gpt-41')} -
- { - setModalImageUrl( - '/azure_model_name.png', - ); - setIsModalOpenurl(true) + {t( + '2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的"."', + )} + {/*
*/} + {/* {*/} + {/* setModalImageUrl(*/} + {/* '/azure_model_name.png',*/} + {/* );*/} + {/* setIsModalOpenurl(true)*/} - }} - > - {t('查看示例')} -
+ {/* }}*/} + {/*>*/} + {/* {t('查看示例')}*/} + {/**/} } > @@ -522,7 +525,7 @@ const EditChannel = (props) => { { handleInputChange('other', value); }} @@ -584,25 +587,35 @@ const EditChannel = (props) => { value={inputs.name} autoComplete='new-password' /> - {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && inputs.type !== 45 && ( - <> -
- {t('API地址')}: -
- - { - handleInputChange('base_url', value); - }} - value={inputs.base_url} - autoComplete="new-password" - /> - - - )} + {inputs.type !== 3 && + inputs.type !== 8 && + inputs.type !== 22 && + inputs.type !== 36 && + inputs.type !== 45 && ( + <> +
+ {t('API地址')}: +
+ + { + handleInputChange('base_url', value); + }} + value={inputs.base_url} + autoComplete='new-password' + /> + + + )}
{t('密钥')}:
@@ -761,10 +774,10 @@ const EditChannel = (props) => { name='other' placeholder={t( '请输入部署地区,例如:us-central1\n支持使用模型映射格式\n' + - '{\n' + - ' "default": "us-central1",\n' + - ' "claude-3-5-sonnet-20240620": "europe-west1"\n' + - '}', + '{\n' + + ' "default": "us-central1",\n' + + ' "claude-3-5-sonnet-20240620": "europe-west1"\n' + + '}', )} autosize={{ minRows: 2 }} onChange={(value) => { @@ -825,6 +838,22 @@ const EditChannel = (props) => { /> )} + {inputs.type === 49 && ( + <> +
+ 智能体ID: +
+ { + handleInputChange('other', value); + }} + value={inputs.other} + autoComplete='new-password' + /> + + )}
{t('模型')}:
diff --git a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js index 800e9636..73626351 100644 --- a/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js +++ b/web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js @@ -6,6 +6,7 @@ import { showError, showSuccess, showWarning, + verifyJSON, } from '../../../helpers'; import { useTranslation } from 'react-i18next'; @@ -18,6 +19,7 @@ export default function RequestRateLimit(props) { ModelRequestRateLimitCount: -1, ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitDurationMinutes: 1, + ModelRequestRateLimitGroup: '', }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -46,6 +48,13 @@ export default function RequestRateLimit(props) { if (res.includes(undefined)) return showError(t('部分保存失败,请重试')); } + + for (let i = 0; i < res.length; i++) { + if (!res[i].data.success) { + return showError(res[i].data.message); + } + } + showSuccess(t('保存成功')); props.refresh(); }) @@ -147,6 +156,41 @@ export default function RequestRateLimit(props) { /> + + + verifyJSON(value), + message: t('不是合法的 JSON 字符串'), + }, + ]} + extraText={ +
+

{t('说明:')}

+
    +
  • {t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}
  • +
  • {t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}
  • +
  • {t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}
  • +
  • {t('分组速率配置优先级高于全局速率限制。')}
  • +
  • {t('限制周期统一使用上方配置的“限制周期”值。')}
  • +
+
+ } + onChange={(value) => { + setInputs({ ...inputs, ModelRequestRateLimitGroup: value }); + }} + /> + +