Merge branch 'main' into ui

This commit is contained in:
Apple\Apple
2025-05-17 11:22:22 +08:00
53 changed files with 1596 additions and 325 deletions

View File

@@ -107,7 +107,7 @@ For detailed configuration instructions, please refer to [Installation Guide-Env
- `GEMINI_VISION_MAX_IMAGE_NUM`: Maximum number of images for Gemini models, default is `16` - `GEMINI_VISION_MAX_IMAGE_NUM`: Maximum number of images for Gemini models, default is `16`
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default is `20` - `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default is `20`
- `CRYPTO_SECRET`: Encryption key used for encrypting database content - `CRYPTO_SECRET`: Encryption key used for encrypting database content
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2024-12-01-preview` - `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes - `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes
- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2` - `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2`

View File

@@ -107,7 +107,7 @@ New API提供了丰富的功能详细特性请参考[特性说明](https://do
- `GEMINI_VISION_MAX_IMAGE_NUM`Gemini模型最大图片数量默认 `16` - `GEMINI_VISION_MAX_IMAGE_NUM`Gemini模型最大图片数量默认 `16`
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小单位MB默认 `20` - `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小单位MB默认 `20`
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容 - `CRYPTO_SECRET`:加密密钥,用于加密数据库内容
- `AZURE_DEFAULT_API_VERSION`Azure渠道默认API版本默认 `2024-12-01-preview` - `AZURE_DEFAULT_API_VERSION`Azure渠道默认API版本默认 `2025-04-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟 - `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2` - `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`

View File

@@ -240,6 +240,7 @@ const (
ChannelTypeBaiduV2 = 46 ChannelTypeBaiduV2 = 46
ChannelTypeXinference = 47 ChannelTypeXinference = 47
ChannelTypeXai = 48 ChannelTypeXai = 48
ChannelTypeCoze = 49
ChannelTypeDummy // this one is only for count, do not add any channel after this ChannelTypeDummy // this one is only for count, do not add any channel after this
) )
@@ -294,4 +295,5 @@ var ChannelBaseURLs = []string{
"https://qianfan.baidubce.com", //46 "https://qianfan.baidubce.com", //46
"", //47 "", //47
"https://api.x.ai", //48 "https://api.x.ai", //48
"https://api.coze.cn", //49
} }

5
constant/azure.go Normal file
View File

@@ -0,0 +1,5 @@
package constant
import "time"
var AzureNoRemoveDotTime = time.Date(2025, time.May, 10, 0, 0, 0, 0, time.UTC).Unix()

View File

@@ -31,7 +31,7 @@ func InitEnv() {
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview") AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2) NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)

View File

@@ -108,6 +108,13 @@ type DeepSeekUsageResponse struct {
} `json:"balance_infos"` } `json:"balance_infos"`
} }
type OpenRouterCreditResponse struct {
Data struct {
TotalCredits float64 `json:"total_credits"`
TotalUsage float64 `json:"total_usage"`
} `json:"data"`
}
// GetAuthHeader get auth header // GetAuthHeader get auth header
func GetAuthHeader(token string) http.Header { func GetAuthHeader(token string) http.Header {
h := http.Header{} h := http.Header{}
@@ -281,6 +288,22 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
return response.TotalAvailable, nil return response.TotalAvailable, nil
} }
func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
url := "https://openrouter.ai/api/v1/credits"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := OpenRouterCreditResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
balance := response.Data.TotalCredits - response.Data.TotalUsage
channel.UpdateBalance(balance)
return balance, nil
}
func updateChannelBalance(channel *model.Channel) (float64, error) { func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type] baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" { if channel.GetBaseURL() == "" {
@@ -307,6 +330,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return updateChannelSiliconFlowBalance(channel) return updateChannelSiliconFlowBalance(channel)
case common.ChannelTypeDeepSeek: case common.ChannelTypeDeepSeek:
return updateChannelDeepSeekBalance(channel) return updateChannelDeepSeekBalance(channel)
case common.ChannelTypeOpenRouter:
return updateChannelOpenRouterBalance(channel)
default: default:
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
} }

View File

@@ -110,6 +110,15 @@ func UpdateOption(c *gin.Context) {
}) })
return return
} }
case "ModelRequestRateLimitGroup":
err = setting.CheckModelRequestRateLimitGroup(option.Value)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} }
err = model.UpdateOption(option.Key, option.Value) err = model.UpdateOption(option.Key, option.Value)

View File

@@ -592,7 +592,14 @@ func UpdateSelf(c *gin.Context) {
user.Password = "" // rollback to what it should be user.Password = "" // rollback to what it should be
cleanUser.Password = "" cleanUser.Password = ""
} }
updatePassword := user.Password != "" updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if err := cleanUser.Update(updatePassword); err != nil { if err := cleanUser.Update(updatePassword); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -608,6 +615,23 @@ func UpdateSelf(c *gin.Context) {
return return
} }
func checkUpdatePassword(originalPassword string, newPassword string, userId int) (updatePassword bool, err error) {
var currentUser *model.User
currentUser, err = model.GetUserById(userId, true)
if err != nil {
return
}
if !common.ValidatePasswordAndHash(originalPassword, currentUser.Password) {
err = fmt.Errorf("原密码错误")
return
}
if newPassword == "" {
return
}
updatePassword = true
return
}
func DeleteUser(c *gin.Context) { func DeleteUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {

View File

@@ -12,6 +12,8 @@ type ImageRequest struct {
Style string `json:"style,omitempty"` Style string `json:"style,omitempty"`
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
ExtraFields json.RawMessage `json:"extra_fields,omitempty"` ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
} }
type ImageResponse struct { type ImageResponse struct {

View File

@@ -195,28 +195,28 @@ type OutputTokenDetails struct {
} }
type OpenAIResponsesResponse struct { type OpenAIResponsesResponse struct {
ID string `json:"id"` ID string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
CreatedAt int `json:"created_at"` CreatedAt int `json:"created_at"`
Status string `json:"status"` Status string `json:"status"`
Error *OpenAIError `json:"error,omitempty"` Error *OpenAIError `json:"error,omitempty"`
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
Instructions string `json:"instructions"` Instructions string `json:"instructions"`
MaxOutputTokens int `json:"max_output_tokens"` MaxOutputTokens int `json:"max_output_tokens"`
Model string `json:"model"` Model string `json:"model"`
Output []ResponsesOutput `json:"output"` Output []ResponsesOutput `json:"output"`
ParallelToolCalls bool `json:"parallel_tool_calls"` ParallelToolCalls bool `json:"parallel_tool_calls"`
PreviousResponseID string `json:"previous_response_id"` PreviousResponseID string `json:"previous_response_id"`
Reasoning *Reasoning `json:"reasoning"` Reasoning *Reasoning `json:"reasoning"`
Store bool `json:"store"` Store bool `json:"store"`
Temperature float64 `json:"temperature"` Temperature float64 `json:"temperature"`
ToolChoice string `json:"tool_choice"` ToolChoice string `json:"tool_choice"`
Tools []interface{} `json:"tools"` Tools []ResponsesToolsCall `json:"tools"`
TopP float64 `json:"top_p"` TopP float64 `json:"top_p"`
Truncation string `json:"truncation"` Truncation string `json:"truncation"`
Usage *Usage `json:"usage"` Usage *Usage `json:"usage"`
User json.RawMessage `json:"user"` User json.RawMessage `json:"user"`
Metadata json.RawMessage `json:"metadata"` Metadata json.RawMessage `json:"metadata"`
} }
type IncompleteDetails struct { type IncompleteDetails struct {
@@ -238,8 +238,12 @@ type ResponsesOutputContent struct {
} }
const ( const (
BuildInTools_WebSearch = "web_search_preview" BuildInToolWebSearchPreview = "web_search_preview"
BuildInTools_FileSearch = "file_search" BuildInToolFileSearch = "file_search"
)
const (
BuildInCallWebSearchCall = "web_search_call"
) )
const ( const (
@@ -250,6 +254,7 @@ const (
// ResponsesStreamResponse 用于处理 /v1/responses 流式响应 // ResponsesStreamResponse 用于处理 /v1/responses 流式响应
type ResponsesStreamResponse struct { type ResponsesStreamResponse struct {
Type string `json:"type"` Type string `json:"type"`
Response *OpenAIResponsesResponse `json:"response"` Response *OpenAIResponsesResponse `json:"response,omitempty"`
Delta string `json:"delta,omitempty"` Delta string `json:"delta,omitempty"`
Item *ResponsesOutput `json:"item,omitempty"`
} }

View File

@@ -185,7 +185,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e") modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "gpt-image-1") modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode := relayconstant.RelayModeAudioSpeech relayMode := relayconstant.RelayModeAudioSpeech
@@ -213,6 +213,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("channel_id", channel.Id) c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name) c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type) c.Set("channel_type", channel.Type)
c.Set("channel_create_time", channel.CreatedTime)
c.Set("channel_setting", channel.GetSetting()) c.Set("channel_setting", channel.GetSetting())
c.Set("param_override", channel.GetParamOverride()) c.Set("param_override", channel.GetParamOverride())
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
@@ -239,5 +240,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("api_version", channel.Other) c.Set("api_version", channel.Other)
case common.ChannelTypeMokaAI: case common.ChannelTypeMokaAI:
c.Set("api_version", channel.Other) c.Set("api_version", channel.Other)
case common.ChannelTypeCoze:
c.Set("bot_id", channel.Other)
} }
} }

View File

@@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/limiter" "one-api/common/limiter"
"one-api/constant"
"one-api/setting" "one-api/setting"
"strconv" "strconv"
"time" "time"
@@ -93,25 +94,27 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
} }
//2.检查总请求数限制并记录总请求当totalMaxCount为0时会自动跳过使用令牌桶限流器 //2.检查总请求数限制并记录总请求当totalMaxCount为0时会自动跳过使用令牌桶限流器
totalKey := fmt.Sprintf("rateLimit:%s", userId) if totalMaxCount > 0 {
// 初始化 totalKey := fmt.Sprintf("rateLimit:%s", userId)
tb := limiter.New(ctx, rdb) // 初始化
allowed, err = tb.Allow( tb := limiter.New(ctx, rdb)
ctx, allowed, err = tb.Allow(
totalKey, ctx,
limiter.WithCapacity(int64(totalMaxCount)*duration), totalKey,
limiter.WithRate(int64(totalMaxCount)), limiter.WithCapacity(int64(totalMaxCount)*duration),
limiter.WithRequested(duration), limiter.WithRate(int64(totalMaxCount)),
) limiter.WithRequested(duration),
)
if err != nil { if err != nil {
fmt.Println("检查总请求数限制失败:", err.Error()) fmt.Println("检查总请求数限制失败:", err.Error())
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
return return
} }
if !allowed { if !allowed {
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
}
} }
// 4. 处理请求 // 4. 处理请求
@@ -173,6 +176,19 @@ func ModelRequestRateLimit() func(c *gin.Context) {
totalMaxCount := setting.ModelRequestRateLimitCount totalMaxCount := setting.ModelRequestRateLimitCount
successMaxCount := setting.ModelRequestRateLimitSuccessCount successMaxCount := setting.ModelRequestRateLimitSuccessCount
// 获取分组
group := c.GetString("token_group")
if group == "" {
group = c.GetString(constant.ContextKeyUserGroup)
}
//获取分组的限流配置
groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
if found {
totalMaxCount = groupTotalCount
successMaxCount = groupSuccessCount
}
// 根据存储类型选择并执行限流处理器 // 根据存储类型选择并执行限流处理器
if common.RedisEnabled { if common.RedisEnabled {
redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)

View File

@@ -67,6 +67,7 @@ func InitOptionMap() {
common.OptionMap["ServerAddress"] = "" common.OptionMap["ServerAddress"] = ""
common.OptionMap["WorkerUrl"] = setting.WorkerUrl common.OptionMap["WorkerUrl"] = setting.WorkerUrl
common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(setting.WorkerAllowHttpImageRequestEnabled)
common.OptionMap["PayAddress"] = "" common.OptionMap["PayAddress"] = ""
common.OptionMap["CustomCallbackAddress"] = "" common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["EpayId"] = "" common.OptionMap["EpayId"] = ""
@@ -92,6 +93,7 @@ func InitOptionMap() {
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
@@ -256,6 +258,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.StopOnSensitiveEnabled = boolValue setting.StopOnSensitiveEnabled = boolValue
case "SMTPSSLEnabled": case "SMTPSSLEnabled":
common.SMTPSSLEnabled = boolValue common.SMTPSSLEnabled = boolValue
case "WorkerAllowHttpImageRequestEnabled":
setting.WorkerAllowHttpImageRequestEnabled = boolValue
} }
} }
switch key { switch key {
@@ -338,6 +342,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value) setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
case "ModelRequestRateLimitSuccessCount": case "ModelRequestRateLimitSuccessCount":
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
case "ModelRequestRateLimitGroup":
err = setting.UpdateModelRequestRateLimitGroupByJSONString(value)
case "RetryTimes": case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value) common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval": case "DataExportInterval":

View File

@@ -18,6 +18,7 @@ type User struct {
Id int `json:"id"` Id int `json:"id"`
Username string `json:"username" gorm:"unique;index" validate:"max=12"` Username string `json:"username" gorm:"unique;index" validate:"max=12"`
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"` Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
OriginalPassword string `json:"original_password" gorm:"-:all"` // this field is only for Password change verification, don't save it to database!
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"` DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
Role int `json:"role" gorm:"type:int;default:1"` // admin, common Role int `json:"role" gorm:"type:int;default:1"` // admin, common
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled

View File

@@ -33,6 +33,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl) fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
case constant.RelayModeCompletions:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
default: default:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
} }

View File

@@ -1,16 +1,23 @@
package channel package channel
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io" "io"
"net/http" "net/http"
common2 "one-api/common" common2 "one-api/common"
"one-api/relay/common" "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting/operation_setting"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
) )
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) { func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
@@ -55,6 +62,9 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod
if err != nil { if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err) return nil, fmt.Errorf("get request url failed: %w", err)
} }
if common2.DebugEnabled {
println("fullRequestURL:", fullRequestURL)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil { if err != nil {
return nil, fmt.Errorf("new request failed: %w", err) return nil, fmt.Errorf("new request failed: %w", err)
@@ -105,7 +115,62 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
} else { } else {
client = service.GetHttpClient() client = service.GetHttpClient()
} }
// 流式请求 ping 保活
var stopPinger func()
generalSettings := operation_setting.GetGeneralSetting()
pingEnabled := generalSettings.PingIntervalEnabled
var pingerWg sync.WaitGroup
if info.IsStream {
helper.SetEventStreamHeaders(c)
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
var pingerCtx context.Context
pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
if pingEnabled {
pingerWg.Add(1)
gopool.Go(func() {
defer pingerWg.Done()
if pingInterval <= 0 {
pingInterval = helper.DefaultPingInterval
}
ticker := time.NewTicker(pingInterval)
defer ticker.Stop()
var pingMutex sync.Mutex
if common2.DebugEnabled {
println("SSE ping goroutine started")
}
for {
select {
case <-ticker.C:
pingMutex.Lock()
err2 := helper.PingData(c)
pingMutex.Unlock()
if err2 != nil {
common2.LogError(c, "SSE ping error: "+err.Error())
return
}
if common2.DebugEnabled {
println("SSE ping data sent.")
}
case <-pingerCtx.Done():
if common2.DebugEnabled {
println("SSE ping goroutine stopped.")
}
return
}
}
})
}
}
resp, err := client.Do(req) resp, err := client.Do(req)
// request结束后停止ping
if info.IsStream && pingEnabled {
stopPinger()
pingerWg.Wait()
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -0,0 +1,132 @@
package coze
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/common"
"time"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
// ConvertAudioRequest implements channel.Adaptor.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return nil, errors.New("not implemented")
}
// ConvertClaudeRequest implements channel.Adaptor.
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *common.RelayInfo, request *dto.ClaudeRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertEmbeddingRequest implements channel.Adaptor.
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *common.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertImageRequest implements channel.Adaptor.
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *common.RelayInfo, request dto.ImageRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertOpenAIRequest implements channel.Adaptor.
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return convertCozeChatRequest(c, *request), nil
}
// ConvertOpenAIResponsesRequest implements channel.Adaptor.
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *common.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertRerankRequest implements channel.Adaptor.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
// DoRequest implements channel.Adaptor.
func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
if info.IsStream {
return channel.DoApiRequest(a, c, info, requestBody)
}
// 首先发送创建消息请求,成功后再发送获取消息请求
// 发送创建消息请求
resp, err := channel.DoApiRequest(a, c, info, requestBody)
if err != nil {
return nil, err
}
// 解析 resp
var cozeResponse CozeChatResponse
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
err = json.Unmarshal(respBody, &cozeResponse)
if cozeResponse.Code != 0 {
return nil, errors.New(cozeResponse.Msg)
}
c.Set("coze_conversation_id", cozeResponse.Data.ConversationId)
c.Set("coze_chat_id", cozeResponse.Data.Id)
// 轮询检查消息是否完成
for {
err, isComplete := checkIfChatComplete(a, c, info)
if err != nil {
return nil, err
} else {
if isComplete {
break
}
}
time.Sleep(time.Second * 1)
}
// 发送获取消息请求
return getChatDetail(a, c, info)
}
// DoResponse implements channel.Adaptor.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = cozeChatStreamHandler(c, resp, info)
} else {
err, usage = cozeChatHandler(c, resp, info)
}
return
}
// GetChannelName implements channel.Adaptor.
func (a *Adaptor) GetChannelName() string {
return ChannelName
}
// GetModelList implements channel.Adaptor.
func (a *Adaptor) GetModelList() []string {
return ModelList
}
// GetRequestURL implements channel.Adaptor.
func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil
}
// Init implements channel.Adaptor.
func (a *Adaptor) Init(info *common.RelayInfo) {
}
// SetupRequestHeader implements channel.Adaptor.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *common.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}

View File

@@ -0,0 +1,30 @@
package coze
var ModelList = []string{
"moonshot-v1-8k",
"moonshot-v1-32k",
"moonshot-v1-128k",
"Baichuan4",
"abab6.5s-chat-pro",
"glm-4-0520",
"qwen-max",
"deepseek-r1",
"deepseek-v3",
"deepseek-r1-distill-qwen-32b",
"deepseek-r1-distill-qwen-7b",
"step-1v-8k",
"step-1.5v-mini",
"Doubao-pro-32k",
"Doubao-pro-256k",
"Doubao-lite-128k",
"Doubao-lite-32k",
"Doubao-vision-lite-32k",
"Doubao-vision-pro-32k",
"Doubao-1.5-pro-vision-32k",
"Doubao-1.5-lite-32k",
"Doubao-1.5-pro-32k",
"Doubao-1.5-thinking-pro",
"Doubao-1.5-pro-256k",
}
var ChannelName = "coze"

78
relay/channel/coze/dto.go Normal file
View File

@@ -0,0 +1,78 @@
package coze
import "encoding/json"
type CozeError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type CozeEnterMessage struct {
Role string `json:"role"`
Type string `json:"type,omitempty"`
Content json.RawMessage `json:"content,omitempty"`
MetaData json.RawMessage `json:"meta_data,omitempty"`
ContentType string `json:"content_type,omitempty"`
}
type CozeChatRequest struct {
BotId string `json:"bot_id"`
UserId string `json:"user_id"`
AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"`
Stream bool `json:"stream,omitempty"`
CustomVariables json.RawMessage `json:"custom_variables,omitempty"`
AutoSaveHistory bool `json:"auto_save_history,omitempty"`
MetaData json.RawMessage `json:"meta_data,omitempty"`
ExtraParams json.RawMessage `json:"extra_params,omitempty"`
ShortcutCommand json.RawMessage `json:"shortcut_command,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
}
type CozeChatResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data CozeChatResponseData `json:"data"`
}
type CozeChatResponseData struct {
Id string `json:"id"`
ConversationId string `json:"conversation_id"`
BotId string `json:"bot_id"`
CreatedAt int64 `json:"created_at"`
LastError CozeError `json:"last_error"`
Status string `json:"status"`
Usage CozeChatUsage `json:"usage"`
}
type CozeChatUsage struct {
TokenCount int `json:"token_count"`
OutputCount int `json:"output_count"`
InputCount int `json:"input_count"`
}
type CozeChatDetailResponse struct {
Data []CozeChatV3MessageDetail `json:"data"`
Code int `json:"code"`
Msg string `json:"msg"`
Detail CozeResponseDetail `json:"detail"`
}
type CozeChatV3MessageDetail struct {
Id string `json:"id"`
Role string `json:"role"`
Type string `json:"type"`
BotId string `json:"bot_id"`
ChatId string `json:"chat_id"`
Content json.RawMessage `json:"content"`
MetaData json.RawMessage `json:"meta_data"`
CreatedAt int64 `json:"created_at"`
SectionId string `json:"section_id"`
UpdatedAt int64 `json:"updated_at"`
ContentType string `json:"content_type"`
ConversationId string `json:"conversation_id"`
ReasoningContent string `json:"reasoning_content"`
}
type CozeResponseDetail struct {
Logid string `json:"logid"`
}

View File

@@ -0,0 +1,300 @@
package coze
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest {
var messages []CozeEnterMessage
// 将 request的messages的role为user的content转换为CozeMessage
for _, message := range request.Messages {
if message.Role == "user" {
messages = append(messages, CozeEnterMessage{
Role: "user",
Content: message.Content,
// TODO: support more content type
ContentType: "text",
})
}
}
user := request.User
if user == "" {
user = helper.GetResponseID(c)
}
cozeRequest := &CozeChatRequest{
BotId: c.GetString("bot_id"),
UserId: user,
AdditionalMessages: messages,
Stream: request.Stream,
}
return cozeRequest
}
// cozeChatHandler converts a non-stream Coze chat-detail response into an
// OpenAI-style text response and writes it to the client. Usage counters are
// read from the gin context, where the polling step previously stored them.
func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	if closeErr := resp.Body.Close(); closeErr != nil {
		return service.OpenAIErrorWrapperLocal(closeErr, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	// Decode before touching the writer so a parse failure can still be
	// reported through the normal error path.
	var cozeResponse CozeChatDetailResponse
	if unmarshalErr := json.Unmarshal(body, &cozeResponse); unmarshalErr != nil {
		return service.OpenAIErrorWrapper(unmarshalErr, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if cozeResponse.Code != 0 {
		return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil
	}
	// Usage was stashed on the context by the chat-completion polling step.
	usage := dto.Usage{
		PromptTokens:     c.GetInt("coze_input_count"),
		CompletionTokens: c.GetInt("coze_output_count"),
		TotalTokens:      c.GetInt("coze_token_count"),
	}
	var response dto.TextResponse
	response.Model = info.UpstreamModelName
	response.Usage = usage
	response.Id = helper.GetResponseID(c)
	// Pick the answer message out of the detail list; if several answers are
	// present the last one wins, matching the original behavior.
	var answerContent json.RawMessage
	for _, item := range cozeResponse.Data {
		if item.Type == "answer" {
			answerContent = item.Content
			response.Created = item.CreatedAt
		}
	}
	response.Choices = []dto.OpenAITextResponseChoice{
		{
			Index:        0,
			Message:      dto.Message{Role: "assistant", Content: answerContent},
			FinishReason: "stop",
		},
	}
	jsonResponse, marshalErr := json.Marshal(response)
	if marshalErr != nil {
		return service.OpenAIErrorWrapper(marshalErr, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, _ = c.Writer.Write(jsonResponse)
	return nil, &usage
}
// cozeChatStreamHandler relays a Coze SSE chat stream to the client as
// OpenAI-style chunks. SSE "event:" / "data:" line pairs are buffered and
// dispatched to handleCozeEvent once a blank line (or EOF) terminates the
// event. When the upstream never reports usage, completion tokens are
// estimated locally from the relayed text.
//
// Returns a wrapped OpenAI error on scanner failure, otherwise nil plus the
// accumulated usage.
func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	scanner := bufio.NewScanner(resp.Body)
	scanner.Split(bufio.ScanLines)
	helper.SetEventStreamHeaders(c)
	id := helper.GetResponseID(c)
	var responseText string
	var currentEvent string
	var currentData string
	var usage dto.Usage
	for scanner.Scan() {
		line := scanner.Text()
		if line == "" {
			// A blank line terminates one SSE event; flush the buffered pair.
			if currentEvent != "" && currentData != "" {
				handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
				currentEvent = ""
				currentData = ""
			}
			continue
		}
		if strings.HasPrefix(line, "event:") {
			currentEvent = strings.TrimSpace(line[6:])
			continue
		}
		if strings.HasPrefix(line, "data:") {
			currentData = strings.TrimSpace(line[5:])
			continue
		}
	}
	// Flush a trailing event that was not followed by a blank line.
	if currentEvent != "" && currentData != "" {
		handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
	}
	if err := scanner.Err(); err != nil {
		return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil
	}
	helper.Done(c)
	if usage.TotalTokens == 0 {
		// Upstream never reported usage; estimate it locally.
		// Fix: CountTextToken takes (text, model) — the arguments were
		// previously swapped, which counted the tokens of the literal
		// string "gpt-3.5-turbo" instead of the response text.
		usage.PromptTokens = info.PromptTokens
		usage.CompletionTokens, _ = service.CountTextToken(responseText, "gpt-3.5-turbo")
		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	}
	return nil, &usage
}
// handleCozeEvent processes a single buffered SSE event from a Coze chat
// stream and forwards it to the client in OpenAI chunk format.
//
//   - "conversation.chat.completed": record usage and emit a stop chunk.
//   - "conversation.message.delta":  accumulate text and emit a delta chunk.
//   - "error":                        log the upstream error.
//
// Malformed event payloads are logged and skipped; they never abort the
// stream.
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
	switch event {
	case "conversation.chat.completed":
		// The completed event carries the authoritative usage counters.
		var chatData CozeChatResponseData
		err := json.Unmarshal([]byte(data), &chatData)
		if err != nil {
			common.SysError("error_unmarshalling_stream_response: " + err.Error())
			return
		}
		usage.PromptTokens = chatData.Usage.InputCount
		usage.CompletionTokens = chatData.Usage.OutputCount
		usage.TotalTokens = chatData.Usage.TokenCount
		finishReason := "stop"
		stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
		helper.ObjectData(c, stopResponse)
	case "conversation.message.delta":
		var messageData CozeChatV3MessageDetail
		err := json.Unmarshal([]byte(data), &messageData)
		if err != nil {
			common.SysError("error_unmarshalling_stream_response: " + err.Error())
			return
		}
		// Content is a JSON-encoded string; unwrap it before relaying.
		var content string
		err = json.Unmarshal(messageData.Content, &content)
		if err != nil {
			common.SysError("error_unmarshalling_stream_response: " + err.Error())
			return
		}
		*responseText += content
		openaiResponse := dto.ChatCompletionsStreamResponse{
			Id:      id,
			Object:  "chat.completion.chunk",
			Created: common.GetTimestamp(),
			Model:   info.UpstreamModelName,
		}
		choice := dto.ChatCompletionsStreamResponseChoice{
			Index: 0,
		}
		choice.Delta.SetContentString(content)
		openaiResponse.Choices = append(openaiResponse.Choices, choice)
		helper.ObjectData(c, openaiResponse)
	case "error":
		var errorData CozeError
		err := json.Unmarshal([]byte(data), &errorData)
		if err != nil {
			common.SysError("error_unmarshalling_stream_response: " + err.Error())
			return
		}
		// Fix: the original format string had no verbs, so the code and
		// message were rendered as %!(EXTRA ...) noise.
		common.SysError(fmt.Sprintf("stream event error: code=%d, message=%s", errorData.Code, errorData.Message))
	}
}
// checkIfChatComplete polls the Coze chat-retrieve endpoint and reports
// whether the chat identified by the context's conversation/chat IDs has
// finished. On completion it stores the usage counters on the gin context so
// the detail handler can bill them. Terminal failure states are returned as
// errors; any other status means "still running".
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
	retrieveURL := fmt.Sprintf("%s/v3/chat/retrieve?conversation_id=%s&chat_id=%s",
		info.BaseUrl, c.GetString("coze_conversation_id"), c.GetString("coze_chat_id"))
	req, err := http.NewRequest("GET", retrieveURL, nil)
	if err != nil {
		return err, false
	}
	if err = a.SetupRequestHeader(c, &req.Header, info); err != nil {
		return err, false
	}
	resp, err := doRequest(req, info)
	if err != nil {
		return err, false
	}
	// Guard against a nil response so the deref below cannot panic.
	if resp == nil {
		return fmt.Errorf("resp is nil"), false
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("read response body failed: %w", err), false
	}
	var cozeResponse CozeChatResponse
	if err = json.Unmarshal(body, &cozeResponse); err != nil {
		return fmt.Errorf("unmarshal response body failed: %w", err), false
	}
	switch cozeResponse.Data.Status {
	case "completed":
		// Stash usage on the context for the non-stream handler.
		c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount)
		c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount)
		c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount)
		return nil, true
	case "failed", "canceled", "requires_action":
		return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false
	default:
		// Chat is still in progress; caller should poll again.
		return nil, false
	}
}
// getChatDetail fetches the message list of a finished Coze chat using the
// conversation/chat IDs stored on the gin context. The caller owns the
// returned response and must close its body.
func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
	listURL := fmt.Sprintf("%s/v3/chat/message/list?conversation_id=%s&chat_id=%s",
		info.BaseUrl, c.GetString("coze_conversation_id"), c.GetString("coze_chat_id"))
	req, err := http.NewRequest("GET", listURL, nil)
	if err != nil {
		return nil, fmt.Errorf("new request failed: %w", err)
	}
	if err = a.SetupRequestHeader(c, &req.Header, info); err != nil {
		return nil, fmt.Errorf("setup request header failed: %w", err)
	}
	resp, err := doRequest(req, info)
	if err != nil {
		return nil, fmt.Errorf("do request failed: %w", err)
	}
	return resp, nil
}
// doRequest executes req with either a proxy-aware client (when the channel
// settings contain a "proxy" entry) or the shared default HTTP client. The
// caller owns the returned response and must close its body.
func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
	var client *http.Client
	if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
		proxyClient, perr := service.NewProxyHttpClient(proxyURL.(string))
		if perr != nil {
			return nil, fmt.Errorf("new proxy http client failed: %w", perr)
		}
		client = proxyClient
	} else {
		client = service.GetHttpClient()
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("client.Do failed: %w", err)
	}
	return resp, nil
}

View File

@@ -391,6 +391,7 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac
} }
// 删除所有的title字段 // 删除所有的title字段
delete(v, "title") delete(v, "title")
delete(v, "$schema")
// 如果type不为object和array则直接返回 // 如果type不为object和array则直接返回
if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") { if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
return schema return schema

View File

@@ -8,6 +8,7 @@ import (
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/textproto"
"one-api/common" "one-api/common"
constant2 "one-api/constant" constant2 "one-api/constant"
"one-api/dto" "one-api/dto"
@@ -25,8 +26,6 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"net/textproto"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -68,9 +67,6 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude { if info.RelayFormat == relaycommon.RelayFormatClaude {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
} }
if info.RelayMode == constant.RelayModeResponses {
return fmt.Sprintf("%s/v1/responses", info.BaseUrl), nil
}
if info.RelayMode == constant.RelayModeRealtime { if info.RelayMode == constant.RelayModeRealtime {
if strings.HasPrefix(info.BaseUrl, "https://") { if strings.HasPrefix(info.BaseUrl, "https://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
@@ -93,7 +89,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
task := strings.TrimPrefix(requestURL, "/v1/") task := strings.TrimPrefix(requestURL, "/v1/")
model_ := info.UpstreamModelName model_ := info.UpstreamModelName
model_ = strings.Replace(model_, ".", "", -1) // 2025年5月10日后创建的渠道不移除.
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
model_ = strings.Replace(model_, ".", "", -1)
}
// https://github.com/songquanpeng/one-api/issues/67 // https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
if info.RelayMode == constant.RelayModeRealtime { if info.RelayMode == constant.RelayModeRealtime {
@@ -173,7 +172,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
info.UpstreamModelName = request.Model info.UpstreamModelName = request.Model
// o系列模型developer适配o1-mini除外 // o系列模型developer适配o1-mini除外
if !strings.HasPrefix(request.Model, "o1-mini") { if !strings.HasPrefix(request.Model, "o1-mini") && !strings.HasPrefix(request.Model, "o1-preview") {
//修改第一个Message的内容将system改为developer //修改第一个Message的内容将system改为developer
if len(request.Messages) > 0 && request.Messages[0].Role == "system" { if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
request.Messages[0].Role = "developer" request.Messages[0].Role = "developer"
@@ -429,7 +428,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream { if info.IsStream {
err, usage = OaiResponsesStreamHandler(c, resp, info) err, usage = OaiResponsesStreamHandler(c, resp, info)
} else { } else {
err, usage = OpenaiResponsesHandler(c, resp, info) err, usage = OaiResponsesHandler(c, resp, info)
} }
default: default:
if info.IsStream { if info.IsStream {

View File

@@ -187,3 +187,10 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
} }
} }
} }
// sendResponsesStreamData forwards one Responses-API stream chunk to the
// client, skipping empty payloads.
func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
if data == "" {
return
}
helper.ResponseChunkData(c, streamResponse, data)
}

View File

@@ -216,9 +216,34 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
}, nil }, nil
} }
forceFormat := false
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt
}
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: completionTokens,
TotalTokens: info.PromptTokens + completionTokens,
}
}
switch info.RelayFormat { switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI: case relaycommon.RelayFormatOpenAI:
break if forceFormat {
responseBody, err = json.Marshal(simpleResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
} else {
break
}
case relaycommon.RelayFormatClaude: case relaycommon.RelayFormatClaude:
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info) claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
claudeRespStr, err := json.Marshal(claudeResp) claudeRespStr, err := json.Marshal(claudeResp)
@@ -244,18 +269,6 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
common.SysError("error copying response body: " + err.Error()) common.SysError("error copying response body: " + err.Error())
} }
resp.Body.Close() resp.Body.Close()
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: completionTokens,
TotalTokens: info.PromptTokens + completionTokens,
}
}
return nil, &simpleResponse.Usage return nil, &simpleResponse.Usage
} }
@@ -644,102 +657,3 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
} }
return nil, &usageResp.Usage return nil, &usageResp.Usage
} }
func OpenaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
// read response body
var responsesResponse dto.OpenAIResponsesResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = common.DecodeJson(responseBody, &responsesResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if responsesResponse.Error != nil {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: responsesResponse.Error.Message,
Type: "openai_error",
Code: responsesResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
// reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
// copy response body
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
common.SysError("error copying response body: " + err.Error())
}
resp.Body.Close()
// compute usage
usage := dto.Usage{}
usage.PromptTokens = responsesResponse.Usage.InputTokens
usage.CompletionTokens = responsesResponse.Usage.OutputTokens
usage.TotalTokens = responsesResponse.Usage.TotalTokens
return nil, &usage
}
func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body")
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
}
var usage = &dto.Usage{}
var responseTextBuilder strings.Builder
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
// 检查当前数据是否包含 completed 状态和 usage 信息
var streamResponse dto.ResponsesStreamResponse
if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
sendResponsesStreamData(c, streamResponse, data)
switch streamResponse.Type {
case "response.completed":
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
case "response.output_text.delta":
// 处理输出文本
responseTextBuilder.WriteString(streamResponse.Delta)
}
}
return true
})
if usage.CompletionTokens == 0 {
// 计算输出文本的 token 数量
tempStr := responseTextBuilder.String()
if len(tempStr) > 0 {
// 非正常结束,使用输出文本的 token 数量
completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
usage.CompletionTokens = completionTokens
}
}
return nil, usage
}
func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
if data == "" {
return
}
helper.ResponseChunkData(c, streamResponse, data)
}

View File

@@ -0,0 +1,119 @@
package openai
import (
"bytes"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
// OaiResponsesHandler relays a non-stream OpenAI /v1/responses reply to the
// client, extracts usage, and records built-in tool call counts on the relay
// info.
func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	// read response body
	var responsesResponse dto.OpenAIResponsesResponse
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	err = common.DecodeJson(responseBody, &responsesResponse)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if responsesResponse.Error != nil {
		return &dto.OpenAIErrorWithStatusCode{
			Error: dto.OpenAIError{
				Message: responsesResponse.Error.Message,
				Type:    "openai_error",
				Code:    responsesResponse.Error.Code,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	// reset response body
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
	// We shouldn't set the header before we parse the response body, because the parse part may fail.
	// And then we will have to send an error response, but in this case, the header has already been set.
	// So the httpClient will be confused by the response.
	// For example, Postman will report error, and we cannot check the response at all.
	for k, v := range resp.Header {
		c.Writer.Header().Set(k, v[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)
	// copy response body
	_, err = io.Copy(c.Writer, resp.Body)
	if err != nil {
		common.SysError("error copying response body: " + err.Error())
	}
	resp.Body.Close()
	// compute usage
	usage := dto.Usage{}
	usage.PromptTokens = responsesResponse.Usage.InputTokens
	usage.CompletionTokens = responsesResponse.Usage.OutputTokens
	usage.TotalTokens = responsesResponse.Usage.TotalTokens
	// Count built-in tool usage. Fix: guard the map lookup — BuiltInTools is
	// populated only from the tools present in the original request, so an
	// unguarded index on an upstream-reported tool type nil-panics.
	for _, tool := range responsesResponse.Tools {
		if info.ResponsesUsageInfo == nil {
			break
		}
		if toolInfo, ok := info.ResponsesUsageInfo.BuiltInTools[tool.Type]; ok && toolInfo != nil {
			toolInfo.CallCount++
		}
	}
	return nil, &usage
}
// OaiResponsesStreamHandler relays an OpenAI /v1/responses SSE stream to the
// client while accumulating usage and built-in tool call counts. If the
// stream ends without a usage report, completion tokens are estimated from
// the relayed output text.
func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	if resp == nil || resp.Body == nil {
		common.LogError(c, "invalid response or response body")
		return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
	}
	var usage = &dto.Usage{}
	var responseTextBuilder strings.Builder
	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
		// Inspect each chunk for completion/usage info before relaying it.
		var streamResponse dto.ResponsesStreamResponse
		if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
			sendResponsesStreamData(c, streamResponse, data)
			switch streamResponse.Type {
			case "response.completed":
				usage.PromptTokens = streamResponse.Response.Usage.InputTokens
				usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
				usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
			case "response.output_text.delta":
				// Accumulate output text for the fallback token estimate.
				responseTextBuilder.WriteString(streamResponse.Delta)
			case dto.ResponsesOutputTypeItemDone:
				// Built-in tool call finished; count it for billing.
				if streamResponse.Item != nil {
					switch streamResponse.Item.Type {
					case dto.BuildInCallWebSearchCall:
						// Fix: guard the map lookup — an unguarded index
						// nil-panics when web search was not declared in
						// the original request's tools.
						if info.ResponsesUsageInfo != nil {
							if toolInfo, ok := info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; ok && toolInfo != nil {
								toolInfo.CallCount++
							}
						}
					}
				}
			}
		}
		return true
	})
	if usage.CompletionTokens == 0 {
		// Stream ended without a usage report; estimate from relayed text.
		tempStr := responseTextBuilder.String()
		if len(tempStr) > 0 {
			completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
			usage.CompletionTokens = completionTokens
		}
	}
	return nil, usage
}

View File

@@ -11,8 +11,8 @@ import (
"one-api/relay/channel/claude" "one-api/relay/channel/claude"
"one-api/relay/channel/gemini" "one-api/relay/channel/gemini"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
"one-api/setting/model_setting"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"

View File

@@ -2,14 +2,16 @@ package xai
import ( import (
"errors" "errors"
"fmt"
"io" "io"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"strings" "strings"
"one-api/relay/constant"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -28,15 +30,20 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
} }
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
request.Size = "" xaiRequest := ImageRequest{
return request, nil Model: request.Model,
Prompt: request.Prompt,
N: request.N,
ResponseFormat: request.ResponseFormat,
}
return xaiRequest, nil
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -89,15 +96,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { switch info.RelayMode {
err, usage = xAIStreamHandler(c, resp, info) case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
} else { err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
err, usage = xAIHandler(c, resp, info) default:
if info.IsStream {
err, usage = xAIStreamHandler(c, resp, info)
} else {
err, usage = xAIHandler(c, resp, info)
}
} }
//if _, ok := usage.(*dto.Usage); ok && usage != nil {
// usage.(*dto.Usage).CompletionTokens = usage.(*dto.Usage).TotalTokens - usage.(*dto.Usage).PromptTokens
//}
return return
} }

View File

@@ -12,3 +12,16 @@ type ChatCompletionResponse struct {
Usage *dto.Usage `json:"usage"` Usage *dto.Usage `json:"usage"`
SystemFingerprint string `json:"system_fingerprint"` SystemFingerprint string `json:"system_fingerprint"`
} }
// quality, size or style are not supported by xAI API at the moment.
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
// Size string `json:"size,omitempty"`
// Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
// Style string `json:"style,omitempty"`
// User string `json:"user,omitempty"`
// ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
}

View File

@@ -36,6 +36,7 @@ type ClaudeConvertInfo struct {
const ( const (
RelayFormatOpenAI = "openai" RelayFormatOpenAI = "openai"
RelayFormatClaude = "claude" RelayFormatClaude = "claude"
RelayFormatGemini = "gemini"
) )
type RerankerInfo struct { type RerankerInfo struct {
@@ -43,6 +44,16 @@ type RerankerInfo struct {
ReturnDocuments bool ReturnDocuments bool
} }
type BuildInToolInfo struct {
ToolName string
CallCount int
SearchContextSize string
}
type ResponsesUsageInfo struct {
BuiltInTools map[string]*BuildInToolInfo
}
type RelayInfo struct { type RelayInfo struct {
ChannelType int ChannelType int
ChannelId int ChannelId int
@@ -87,9 +98,11 @@ type RelayInfo struct {
UserQuota int UserQuota int
RelayFormat string RelayFormat string
SendResponseCount int SendResponseCount int
ChannelCreateTime int64
ThinkingContentInfo ThinkingContentInfo
*ClaudeConvertInfo *ClaudeConvertInfo
*RerankerInfo *RerankerInfo
*ResponsesUsageInfo
} }
// 定义支持流式选项的通道类型 // 定义支持流式选项的通道类型
@@ -103,6 +116,8 @@ var streamSupportedChannels = map[int]bool{
common.ChannelTypeVolcEngine: true, common.ChannelTypeVolcEngine: true,
common.ChannelTypeOllama: true, common.ChannelTypeOllama: true,
common.ChannelTypeXai: true, common.ChannelTypeXai: true,
common.ChannelTypeDeepSeek: true,
common.ChannelTypeBaiduV2: true,
} }
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
@@ -134,6 +149,31 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
return info return info
} }
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeResponses
info.ResponsesUsageInfo = &ResponsesUsageInfo{
BuiltInTools: make(map[string]*BuildInToolInfo),
}
if len(req.Tools) > 0 {
for _, tool := range req.Tools {
info.ResponsesUsageInfo.BuiltInTools[tool.Type] = &BuildInToolInfo{
ToolName: tool.Type,
CallCount: 0,
}
switch tool.Type {
case dto.BuildInToolWebSearchPreview:
if tool.SearchContextSize == "" {
tool.SearchContextSize = "medium"
}
info.ResponsesUsageInfo.BuiltInTools[tool.Type].SearchContextSize = tool.SearchContextSize
}
}
}
info.IsStream = req.Stream
return info
}
func GenRelayInfo(c *gin.Context) *RelayInfo { func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel_type") channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
@@ -170,14 +210,15 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
OriginModelName: c.GetString("original_model"), OriginModelName: c.GetString("original_model"),
UpstreamModelName: c.GetString("original_model"), UpstreamModelName: c.GetString("original_model"),
//RecodeModelName: c.GetString("original_model"), //RecodeModelName: c.GetString("original_model"),
IsModelMapped: false, IsModelMapped: false,
ApiType: apiType, ApiType: apiType,
ApiVersion: c.GetString("api_version"), ApiVersion: c.GetString("api_version"),
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"), Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting, ChannelSetting: channelSetting,
ParamOverride: paramOverride, ChannelCreateTime: c.GetInt64("channel_create_time"),
RelayFormat: RelayFormatOpenAI, ParamOverride: paramOverride,
RelayFormat: RelayFormatOpenAI,
ThinkingContentInfo: ThinkingContentInfo{ ThinkingContentInfo: ThinkingContentInfo{
IsFirstThinkingContent: true, IsFirstThinkingContent: true,
SendLastThinkingContent: false, SendLastThinkingContent: false,

View File

@@ -33,6 +33,7 @@ const (
APITypeOpenRouter APITypeOpenRouter
APITypeXinference APITypeXinference
APITypeXai APITypeXai
APITypeCoze
APITypeDummy // this one is only for count, do not add any channel after this APITypeDummy // this one is only for count, do not add any channel after this
) )
@@ -95,6 +96,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeXinference apiType = APITypeXinference
case common.ChannelTypeXai: case common.ChannelTypeXai:
apiType = APITypeXai apiType = APITypeXai
case common.ChannelTypeCoze:
apiType = APITypeCoze
} }
if apiType == -1 { if apiType == -1 {
return APITypeOpenAI, false return APITypeOpenAI, false

View File

@@ -12,11 +12,19 @@ import (
) )
func SetEventStreamHeaders(c *gin.Context) { func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream") // 检查是否已经设置过头部
c.Writer.Header().Set("Cache-Control", "no-cache") if _, exists := c.Get("event_stream_headers_set"); exists {
c.Writer.Header().Set("Connection", "keep-alive") return
c.Writer.Header().Set("Transfer-Encoding", "chunked") }
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
// 设置标志,表示头部已经设置过
c.Set("event_stream_headers_set", true)
} }
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error { func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
@@ -37,7 +45,7 @@ func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) { func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)}) c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
if flusher, ok := c.Writer.(http.Flusher); ok { if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush() flusher.Flush()
} }
@@ -45,7 +53,7 @@ func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) { func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)}) c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)}) c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)})
if flusher, ok := c.Writer.(http.Flusher); ok { if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush() flusher.Flush()
} }

View File

@@ -2,9 +2,11 @@ package helper
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"one-api/relay/common" "one-api/relay/common"
"github.com/gin-gonic/gin"
) )
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
@@ -16,9 +18,36 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
if err != nil { if err != nil {
return fmt.Errorf("unmarshal_model_mapping_failed") return fmt.Errorf("unmarshal_model_mapping_failed")
} }
if modelMap[info.OriginModelName] != "" {
info.UpstreamModelName = modelMap[info.OriginModelName] // 支持链式模型重定向,最终使用链尾的模型
info.IsModelMapped = true currentModel := info.OriginModelName
visitedModels := map[string]bool{
currentModel: true,
}
for {
if mappedModel, exists := modelMap[currentModel]; exists && mappedModel != "" {
// 模型重定向循环检测,避免无限循环
if visitedModels[mappedModel] {
if mappedModel == currentModel {
if currentModel == info.OriginModelName {
info.IsModelMapped = false
return nil
} else {
info.IsModelMapped = true
break
}
}
return errors.New("model_mapping_contains_cycle")
}
visitedModels[mappedModel] = true
currentModel = mappedModel
info.IsModelMapped = true
} else {
break
}
}
if info.IsModelMapped {
info.UpstreamModelName = currentModel
} }
} }
return nil return nil

View File

@@ -23,7 +23,7 @@ type PriceData struct {
} }
func (p PriceData) ToSetting() string { func (p PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
} }
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {

View File

@@ -3,7 +3,6 @@ package helper
import ( import (
"bufio" "bufio"
"context" "context"
"github.com/bytedance/gopkg/util/gopool"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@@ -14,6 +13,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -49,11 +49,11 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
// Not "256x256", "512x512", or "1024x1024" // Not "256x256", "512x512", or "1024x1024"
if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
} }
} else if imageRequest.Model == "dall-e-3" { } else if imageRequest.Model == "dall-e-3" {
if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
} }
if imageRequest.Quality == "" { if imageRequest.Quality == "" {
imageRequest.Quality = "standard" imageRequest.Quality = "standard"

View File

@@ -19,7 +19,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func getAndValidateResponsesRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.OpenAIResponsesRequest, error) { func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
request := &dto.OpenAIResponsesRequest{} request := &dto.OpenAIResponsesRequest{}
err := common.UnmarshalBodyReusable(c, request) err := common.UnmarshalBodyReusable(c, request)
if err != nil { if err != nil {
@@ -31,13 +31,11 @@ func getAndValidateResponsesRequest(c *gin.Context, relayInfo *relaycommon.Relay
if len(request.Input) == 0 { if len(request.Input) == 0 {
return nil, errors.New("input is required") return nil, errors.New("input is required")
} }
relayInfo.IsStream = request.Stream
return request, nil return request, nil
} }
func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) { func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) {
sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input) sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input)
return sensitiveWords, err return sensitiveWords, err
} }
@@ -49,12 +47,14 @@ func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo
} }
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c) req, err := getAndValidateResponsesRequest(c)
req, err := getAndValidateResponsesRequest(c, relayInfo)
if err != nil { if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error())) common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest)
} }
relayInfo := relaycommon.GenRelayInfoResponses(c, req)
if setting.ShouldCheckPromptSensitive() { if setting.ShouldCheckPromptSensitive() {
sensitiveWords, err := checkInputSensitive(req, relayInfo) sensitiveWords, err := checkInputSensitive(req, relayInfo)
if err != nil { if err != nil {

View File

@@ -18,6 +18,7 @@ import (
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
"one-api/setting/model_setting" "one-api/setting/model_setting"
"one-api/setting/operation_setting"
"strings" "strings"
"time" "time"
@@ -193,6 +194,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
var httpResp *http.Response var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody) resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }
@@ -358,6 +360,34 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
ratio := dModelRatio.Mul(dGroupRatio) ratio := dModelRatio.Mul(dGroupRatio)
// openai web search 工具计费
var dWebSearchQuota decimal.Decimal
var webSearchPrice float64
if relayInfo.ResponsesUsageInfo != nil {
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
// 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率)
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize)
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s调用花费 $%s",
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
}
}
// file search tool 计费
var dFileSearchQuota decimal.Decimal
var fileSearchPrice float64
if relayInfo.ResponsesUsageInfo != nil {
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
fileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 $%s",
fileSearchTool.CallCount, dFileSearchQuota.String())
}
}
var quotaCalculateDecimal decimal.Decimal var quotaCalculateDecimal decimal.Decimal
if !priceData.UsePrice { if !priceData.UsePrice {
nonCachedTokens := dPromptTokens.Sub(dCacheTokens) nonCachedTokens := dPromptTokens.Sub(dCacheTokens)
@@ -380,6 +410,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
} else { } else {
quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio) quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
} }
// 添加 responses tools call 调用的配额
quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
quota := int(quotaCalculateDecimal.Round(0).IntPart()) quota := int(quotaCalculateDecimal.Round(0).IntPart())
totalTokens := promptTokens + completionTokens totalTokens := promptTokens + completionTokens
@@ -430,6 +463,20 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
other["image_ratio"] = imageRatio other["image_ratio"] = imageRatio
other["image_output"] = imageTokens other["image_output"] = imageTokens
} }
if !dWebSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
other["web_search"] = true
other["web_search_call_count"] = webSearchTool.CallCount
other["web_search_price"] = webSearchPrice
}
}
if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
other["file_search"] = true
other["file_search_call_count"] = fileSearchTool.CallCount
other["file_search_price"] = fileSearchPrice
}
}
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
} }

View File

@@ -10,6 +10,7 @@ import (
"one-api/relay/channel/claude" "one-api/relay/channel/claude"
"one-api/relay/channel/cloudflare" "one-api/relay/channel/cloudflare"
"one-api/relay/channel/cohere" "one-api/relay/channel/cohere"
"one-api/relay/channel/coze"
"one-api/relay/channel/deepseek" "one-api/relay/channel/deepseek"
"one-api/relay/channel/dify" "one-api/relay/channel/dify"
"one-api/relay/channel/gemini" "one-api/relay/channel/gemini"
@@ -88,6 +89,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &openai.Adaptor{} return &openai.Adaptor{}
case constant.APITypeXai: case constant.APITypeXai:
return &xai.Adaptor{} return &xai.Adaptor{}
case constant.APITypeCoze:
return &coze.Adaptor{}
} }
return nil return nil
} }

View File

@@ -24,7 +24,7 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
if !setting.EnableWorker() { if !setting.EnableWorker() {
return nil, fmt.Errorf("worker not enabled") return nil, fmt.Errorf("worker not enabled")
} }
if !strings.HasPrefix(req.URL, "https") { if !setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
return nil, fmt.Errorf("only support https url") return nil, fmt.Errorf("only support https url")
} }

View File

@@ -3,12 +3,13 @@ package service
import ( import (
"context" "context"
"fmt" "fmt"
"golang.org/x/net/proxy"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"one-api/common" "one-api/common"
"time" "time"
"golang.org/x/net/proxy"
) )
var httpClient *http.Client var httpClient *http.Client
@@ -55,7 +56,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
}, },
}, nil }, nil
case "socks5": case "socks5", "socks5h":
// 获取认证信息 // 获取认证信息
var auth *proxy.Auth var auth *proxy.Auth
if parsedURL.User != nil { if parsedURL.User != nil {
@@ -69,6 +70,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
} }
// 创建 SOCKS5 代理拨号器 // 创建 SOCKS5 代理拨号器
// proxy.SOCKS5 使用 tcp 参数,所有 TCP 连接包括 DNS 查询都将通过代理进行。行为与 socks5h 相同
dialer, err := proxy.SOCKS5("tcp", parsedURL.Host, auth, proxy.Direct) dialer, err := proxy.SOCKS5("tcp", parsedURL.Host, auth, proxy.Direct)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -120,11 +120,12 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
var config image.Config var config image.Config
var err error var err error
var format string var format string
var b64str string
if strings.HasPrefix(imageUrl.Url, "http") { if strings.HasPrefix(imageUrl.Url, "http") {
config, format, err = DecodeUrlImageData(imageUrl.Url) config, format, err = DecodeUrlImageData(imageUrl.Url)
} else { } else {
common.SysLog(fmt.Sprintf("decoding image")) common.SysLog(fmt.Sprintf("decoding image"))
config, format, _, err = DecodeBase64ImageData(imageUrl.Url) config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url)
} }
if err != nil { if err != nil {
return 0, err return 0, err
@@ -132,7 +133,12 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
imageUrl.MimeType = format imageUrl.MimeType = format
if config.Width == 0 || config.Height == 0 { if config.Width == 0 || config.Height == 0 {
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url)) // not an image
if format != "" && b64str != "" {
// file type
return 3 * baseTokens, nil
}
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url))
} }
shortSide := config.Width shortSide := config.Width

View File

@@ -0,0 +1,57 @@
package operation_setting
import "strings"
const (
// Web search
WebSearchHighTierModelPriceLow = 30.00
WebSearchHighTierModelPriceMedium = 35.00
WebSearchHighTierModelPriceHigh = 50.00
WebSearchPriceLow = 25.00
WebSearchPriceMedium = 27.50
WebSearchPriceHigh = 30.00
// File search
FileSearchPrice = 2.5
)
func GetWebSearchPricePerThousand(modelName string, contextSize string) float64 {
// 确定模型类型
// https://platform.openai.com/docs/pricing Web search 价格按模型类型和 search context size 收费
// gpt-4.1, gpt-4o, or gpt-4o-search-preview 更贵gpt-4.1-mini, gpt-4o-mini, gpt-4o-mini-search-preview 更便宜
isHighTierModel := (strings.HasPrefix(modelName, "gpt-4.1") || strings.HasPrefix(modelName, "gpt-4o")) &&
!strings.Contains(modelName, "mini")
// 确定 search context size 对应的价格
var priceWebSearchPerThousandCalls float64
switch contextSize {
case "low":
if isHighTierModel {
priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceLow
} else {
priceWebSearchPerThousandCalls = WebSearchPriceLow
}
case "medium":
if isHighTierModel {
priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceMedium
} else {
priceWebSearchPerThousandCalls = WebSearchPriceMedium
}
case "high":
if isHighTierModel {
priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceHigh
} else {
priceWebSearchPerThousandCalls = WebSearchPriceHigh
}
default:
// search context size 默认为 medium
if isHighTierModel {
priceWebSearchPerThousandCalls = WebSearchHighTierModelPriceMedium
} else {
priceWebSearchPerThousandCalls = WebSearchPriceMedium
}
}
return priceWebSearchPerThousandCalls
}
func GetFileSearchPricePerThousand() float64 {
return FileSearchPrice
}

View File

@@ -1,6 +1,64 @@
package setting package setting
import (
"encoding/json"
"fmt"
"one-api/common"
"sync"
)
var ModelRequestRateLimitEnabled = false var ModelRequestRateLimitEnabled = false
var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitDurationMinutes = 1
var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitCount = 0
var ModelRequestRateLimitSuccessCount = 1000 var ModelRequestRateLimitSuccessCount = 1000
var ModelRequestRateLimitGroup = map[string][2]int{}
var ModelRequestRateLimitMutex sync.RWMutex
func ModelRequestRateLimitGroup2JSONString() string {
ModelRequestRateLimitMutex.RLock()
defer ModelRequestRateLimitMutex.RUnlock()
jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup)
if err != nil {
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateModelRequestRateLimitGroupByJSONString(jsonStr string) error {
ModelRequestRateLimitMutex.RLock()
defer ModelRequestRateLimitMutex.RUnlock()
ModelRequestRateLimitGroup = make(map[string][2]int)
return json.Unmarshal([]byte(jsonStr), &ModelRequestRateLimitGroup)
}
func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
ModelRequestRateLimitMutex.RLock()
defer ModelRequestRateLimitMutex.RUnlock()
if ModelRequestRateLimitGroup == nil {
return 0, 0, false
}
limits, found := ModelRequestRateLimitGroup[group]
if !found {
return 0, 0, false
}
return limits[0], limits[1], true
}
func CheckModelRequestRateLimitGroup(jsonStr string) error {
checkModelRequestRateLimitGroup := make(map[string][2]int)
err := json.Unmarshal([]byte(jsonStr), &checkModelRequestRateLimitGroup)
if err != nil {
return err
}
for group, limits := range checkModelRequestRateLimitGroup {
if limits[0] < 0 || limits[1] < 1 {
return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1])
}
}
return nil
}

View File

@@ -3,6 +3,7 @@ package setting
var ServerAddress = "http://localhost:3000" var ServerAddress = "http://localhost:3000"
var WorkerUrl = "" var WorkerUrl = ""
var WorkerValidKey = "" var WorkerValidKey = ""
var WorkerAllowHttpImageRequestEnabled = false
func EnableWorker() bool { func EnableWorker() bool {
return WorkerUrl != "" return WorkerUrl != ""

View File

@@ -618,7 +618,6 @@ const LogsTable = () => {
</Paragraph> </Paragraph>
); );
} }
let content = other?.claude let content = other?.claude
? renderClaudeModelPriceSimple( ? renderClaudeModelPriceSimple(
other.model_ratio, other.model_ratio,
@@ -935,6 +934,13 @@ const LogsTable = () => {
other.model_price, other.model_price,
other.group_ratio, other.group_ratio,
other?.user_group_ratio, other?.user_group_ratio,
false,
1.0,
undefined,
other.web_search || false,
other.web_search_call_count || 0,
other.file_search || false,
other.file_search_call_count || 0,
), ),
}); });
} }
@@ -995,6 +1001,12 @@ const LogsTable = () => {
other?.image || false, other?.image || false,
other?.image_ratio || 0, other?.image_ratio || 0,
other?.image_output || 0, other?.image_output || 0,
other?.web_search || false,
other?.web_search_call_count || 0,
other?.web_search_price || 0,
other?.file_search || false,
other?.file_search_call_count || 0,
other?.file_search_price || 0,
); );
} }
expandDataLocal.push({ expandDataLocal.push({

View File

@@ -57,6 +57,7 @@ const PersonalSetting = () => {
email_verification_code: '', email_verification_code: '',
email: '', email: '',
self_account_deletion_confirmation: '', self_account_deletion_confirmation: '',
original_password: '',
set_new_password: '', set_new_password: '',
set_new_password_confirmation: '', set_new_password_confirmation: '',
}); });
@@ -239,11 +240,24 @@ const PersonalSetting = () => {
}; };
const changePassword = async () => { const changePassword = async () => {
if (inputs.original_password === '') {
showError(t('请输入原密码!'));
return;
}
if (inputs.set_new_password === '') {
showError(t('请输入新密码!'));
return;
}
if (inputs.original_password === inputs.set_new_password) {
showError(t('新密码需要和原密码不一致!'));
return;
}
if (inputs.set_new_password !== inputs.set_new_password_confirmation) { if (inputs.set_new_password !== inputs.set_new_password_confirmation) {
showError(t('两次输入的密码不一致!')); showError(t('两次输入的密码不一致!'));
return; return;
} }
const res = await API.put(`/api/user/self`, { const res = await API.put(`/api/user/self`, {
original_password: inputs.original_password,
password: inputs.set_new_password, password: inputs.set_new_password,
}); });
const { success, message } = res.data; const { success, message } = res.data;
@@ -816,8 +830,8 @@ const PersonalSetting = () => {
</div> </div>
</Card> </Card>
<Card style={{ marginTop: 10 }}> <Card style={{ marginTop: 10 }}>
<Tabs type="line" defaultActiveKey="notification"> <Tabs type='line' defaultActiveKey='notification'>
<TabPane tab={t('通知设置')} itemKey="notification"> <TabPane tab={t('通知设置')} itemKey='notification'>
<div style={{ marginTop: 20 }}> <div style={{ marginTop: 20 }}>
<Typography.Text strong>{t('通知方式')}</Typography.Text> <Typography.Text strong>{t('通知方式')}</Typography.Text>
<div style={{ marginTop: 10 }}> <div style={{ marginTop: 10 }}>
@@ -993,23 +1007,36 @@ const PersonalSetting = () => {
</Typography.Text> </Typography.Text>
</div> </div>
</TabPane> </TabPane>
<TabPane tab={t('价格设置')} itemKey="price"> <TabPane tab={t('价格设置')} itemKey='price'>
<div style={{ marginTop: 20 }}> <div style={{ marginTop: 20 }}>
<Typography.Text strong>{t('接受未设置价格模型')}</Typography.Text> <Typography.Text strong>
{t('接受未设置价格模型')}
</Typography.Text>
<div style={{ marginTop: 10 }}> <div style={{ marginTop: 10 }}>
<Checkbox <Checkbox
checked={notificationSettings.acceptUnsetModelRatioModel} checked={
onChange={e => handleNotificationSettingChange('acceptUnsetModelRatioModel', e.target.checked)} notificationSettings.acceptUnsetModelRatioModel
}
onChange={(e) =>
handleNotificationSettingChange(
'acceptUnsetModelRatioModel',
e.target.checked,
)
}
> >
{t('接受未设置价格模型')} {t('接受未设置价格模型')}
</Checkbox> </Checkbox>
<Typography.Text type="secondary" style={{ marginTop: 8, display: 'block' }}> <Typography.Text
{t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')} type='secondary'
style={{ marginTop: 8, display: 'block' }}
>
{t(
'当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用',
)}
</Typography.Text> </Typography.Text>
</div> </div>
</div> </div>
</TabPane> </TabPane>
</Tabs> </Tabs>
<div style={{ marginTop: 20 }}> <div style={{ marginTop: 20 }}>
<Button type='primary' onClick={saveNotificationSettings}> <Button type='primary' onClick={saveNotificationSettings}>
@@ -1118,6 +1145,16 @@ const PersonalSetting = () => {
> >
<div style={{ marginTop: 20 }}> <div style={{ marginTop: 20 }}>
<Input <Input
name='original_password'
placeholder={t('原密码')}
type='password'
value={inputs.original_password}
onChange={(value) =>
handleInputChange('original_password', value)
}
/>
<Input
style={{ marginTop: 20 }}
name='set_new_password' name='set_new_password'
placeholder={t('新密码')} placeholder={t('新密码')}
value={inputs.set_new_password} value={inputs.set_new_password}

View File

@@ -13,6 +13,7 @@ const RateLimitSetting = () => {
ModelRequestRateLimitCount: 0, ModelRequestRateLimitCount: 0,
ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1, ModelRequestRateLimitDurationMinutes: 1,
ModelRequestRateLimitGroup: '',
}); });
let [loading, setLoading] = useState(false); let [loading, setLoading] = useState(false);
@@ -23,10 +24,14 @@ const RateLimitSetting = () => {
if (success) { if (success) {
let newInputs = {}; let newInputs = {};
data.forEach((item) => { data.forEach((item) => {
if (item.key.endsWith('Enabled')) { if (item.key === 'ModelRequestRateLimitGroup') {
newInputs[item.key] = item.value === 'true' ? true : false; item.value = JSON.stringify(JSON.parse(item.value), null, 2);
} else { }
newInputs[item.key] = item.value;
if (item.key.endsWith('Enabled')) {
newInputs[item.key] = item.value === 'true' ? true : false;
} else {
newInputs[item.key] = item.value;
} }
}); });

View File

@@ -19,7 +19,7 @@ import {
verifyJSON, verifyJSON,
} from '../helpers/utils'; } from '../helpers/utils';
import { API } from '../helpers/api'; import { API } from '../helpers/api';
import axios from "axios"; import axios from 'axios';
const SystemSetting = () => { const SystemSetting = () => {
let [inputs, setInputs] = useState({ let [inputs, setInputs] = useState({
@@ -45,6 +45,7 @@ const SystemSetting = () => {
ServerAddress: '', ServerAddress: '',
WorkerUrl: '', WorkerUrl: '',
WorkerValidKey: '', WorkerValidKey: '',
WorkerAllowHttpImageRequestEnabled: '',
EpayId: '', EpayId: '',
EpayKey: '', EpayKey: '',
Price: 7.3, Price: 7.3,
@@ -111,6 +112,7 @@ const SystemSetting = () => {
case 'SMTPSSLEnabled': case 'SMTPSSLEnabled':
case 'LinuxDOOAuthEnabled': case 'LinuxDOOAuthEnabled':
case 'oidc.enabled': case 'oidc.enabled':
case 'WorkerAllowHttpImageRequestEnabled':
item.value = item.value === 'true'; item.value = item.value === 'true';
break; break;
case 'Price': case 'Price':
@@ -206,7 +208,11 @@ const SystemSetting = () => {
let WorkerUrl = removeTrailingSlash(inputs.WorkerUrl); let WorkerUrl = removeTrailingSlash(inputs.WorkerUrl);
const options = [ const options = [
{ key: 'WorkerUrl', value: WorkerUrl }, { key: 'WorkerUrl', value: WorkerUrl },
] {
key: 'WorkerAllowHttpImageRequestEnabled',
value: inputs.WorkerAllowHttpImageRequestEnabled ? 'true' : 'false',
},
];
if (inputs.WorkerValidKey !== '' || WorkerUrl === '') { if (inputs.WorkerValidKey !== '' || WorkerUrl === '') {
options.push({ key: 'WorkerValidKey', value: inputs.WorkerValidKey }); options.push({ key: 'WorkerValidKey', value: inputs.WorkerValidKey });
} }
@@ -302,7 +308,8 @@ const SystemSetting = () => {
const domain = emailToAdd.trim(); const domain = emailToAdd.trim();
// 验证域名格式 // 验证域名格式
const domainRegex = /^([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$/; const domainRegex =
/^([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$/;
if (!domainRegex.test(domain)) { if (!domainRegex.test(domain)) {
showError('邮箱域名格式不正确,请输入有效的域名,如 gmail.com'); showError('邮箱域名格式不正确,请输入有效的域名,如 gmail.com');
return; return;
@@ -577,6 +584,12 @@ const SystemSetting = () => {
/> />
</Col> </Col>
</Row> </Row>
<Form.Checkbox
field='WorkerAllowHttpImageRequestEnabled'
noLabel
>
允许 HTTP 协议图片请求适用于自部署代理
</Form.Checkbox>
<Button onClick={submitWorker}>更新Worker设置</Button> <Button onClick={submitWorker}>更新Worker设置</Button>
</Form.Section> </Form.Section>
</Card> </Card>
@@ -799,7 +812,13 @@ const SystemSetting = () => {
onChange={(value) => setEmailToAdd(value)} onChange={(value) => setEmailToAdd(value)}
style={{ marginTop: 16 }} style={{ marginTop: 16 }}
suffix={ suffix={
<Button theme="solid" type="primary" onClick={handleAddEmail}>添加</Button> <Button
theme='solid'
type='primary'
onClick={handleAddEmail}
>
添加
</Button>
} }
onEnterPress={handleAddEmail} onEnterPress={handleAddEmail}
/> />

View File

@@ -118,6 +118,11 @@ export const CHANNEL_OPTIONS = [
{ {
value: 48, value: 48,
color: 'blue', color: 'blue',
label: 'xAI' label: 'xAI',
} },
{
value: 49,
color: 'blue',
label: 'Coze',
},
]; ];

View File

@@ -317,6 +317,12 @@ export function renderModelPrice(
image = false, image = false,
imageRatio = 1.0, imageRatio = 1.0,
imageOutputTokens = 0, imageOutputTokens = 0,
webSearch = false,
webSearchCallCount = 0,
webSearchPrice = 0,
fileSearch = false,
fileSearchCallCount = 0,
fileSearchPrice = 0,
) { ) {
if (modelPrice !== -1) { if (modelPrice !== -1) {
return i18next.t( return i18next.t(
@@ -339,14 +345,17 @@ export function renderModelPrice(
// Calculate effective input tokens (non-cached + cached with ratio applied) // Calculate effective input tokens (non-cached + cached with ratio applied)
let effectiveInputTokens = let effectiveInputTokens =
inputTokens - cacheTokens + cacheTokens * cacheRatio; inputTokens - cacheTokens + cacheTokens * cacheRatio;
// Handle image tokens if present // Handle image tokens if present
if (image && imageOutputTokens > 0) { if (image && imageOutputTokens > 0) {
effectiveInputTokens = inputTokens - imageOutputTokens + imageOutputTokens * imageRatio; effectiveInputTokens =
inputTokens - imageOutputTokens + imageOutputTokens * imageRatio;
} }
let price = let price =
(effectiveInputTokens / 1000000) * inputRatioPrice * groupRatio + (effectiveInputTokens / 1000000) * inputRatioPrice * groupRatio +
(completionTokens / 1000000) * completionRatioPrice * groupRatio; (completionTokens / 1000000) * completionRatioPrice * groupRatio +
(webSearchCallCount / 1000) * webSearchPrice * groupRatio +
(fileSearchCallCount / 1000) * fileSearchPrice * groupRatio;
return ( return (
<> <>
@@ -391,9 +400,23 @@ export function renderModelPrice(
)} )}
</p> </p>
)} )}
{webSearch && webSearchCallCount > 0 && (
<p>
{i18next.t('Web搜索价格${{price}} / 1K 次', {
price: webSearchPrice,
})}
</p>
)}
{fileSearch && fileSearchCallCount > 0 && (
<p>
{i18next.t('文件搜索价格:${{price}} / 1K 次', {
price: fileSearchPrice,
})}
</p>
)}
<p></p> <p></p>
<p> <p>
{cacheTokens > 0 && !image {cacheTokens > 0 && !image && !webSearch && !fileSearch
? i18next.t( ? i18next.t(
'输入 {{nonCacheInput}} tokens / 1M tokens * ${{price}} + 缓存 {{cacheInput}} tokens / 1M tokens * ${{cachePrice}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', '输入 {{nonCacheInput}} tokens / 1M tokens * ${{price}} + 缓存 {{cacheInput}} tokens / 1M tokens * ${{cachePrice}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}',
{ {
@@ -407,31 +430,82 @@ export function renderModelPrice(
total: price.toFixed(6), total: price.toFixed(6),
}, },
) )
: image && imageOutputTokens > 0 : image && imageOutputTokens > 0 && !webSearch && !fileSearch
? i18next.t( ? i18next.t(
'输入 {{nonImageInput}} tokens + 图片输入 {{imageInput}} tokens * {{imageRatio}} / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', '输入 {{nonImageInput}} tokens + 图片输入 {{imageInput}} tokens * {{imageRatio}} / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}',
{ {
nonImageInput: inputTokens - imageOutputTokens, nonImageInput: inputTokens - imageOutputTokens,
imageInput: imageOutputTokens, imageInput: imageOutputTokens,
imageRatio: imageRatio, imageRatio: imageRatio,
price: inputRatioPrice, price: inputRatioPrice,
completion: completionTokens, completion: completionTokens,
compPrice: completionRatioPrice, compPrice: completionRatioPrice,
ratio: groupRatio, ratio: groupRatio,
total: price.toFixed(6), total: price.toFixed(6),
}, },
) )
: i18next.t( : webSearch && webSearchCallCount > 0 && !image && !fileSearch
'输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}', ? i18next.t(
{ '输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + Web搜索 {{webSearchCallCount}}次 / 1K 次 * ${{webSearchPrice}} * {{ratio}} = ${{total}}',
input: inputTokens, {
price: inputRatioPrice, input: inputTokens,
completion: completionTokens, price: inputRatioPrice,
compPrice: completionRatioPrice, completion: completionTokens,
ratio: groupRatio, compPrice: completionRatioPrice,
total: price.toFixed(6), ratio: groupRatio,
}, webSearchCallCount,
)} webSearchPrice,
total: price.toFixed(6),
},
)
: fileSearch &&
fileSearchCallCount > 0 &&
!image &&
!webSearch
? i18next.t(
'输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + 文件搜索 {{fileSearchCallCount}}次 / 1K 次 * ${{fileSearchPrice}} * {{ratio}}= ${{total}}',
{
input: inputTokens,
price: inputRatioPrice,
completion: completionTokens,
compPrice: completionRatioPrice,
ratio: groupRatio,
fileSearchCallCount,
fileSearchPrice,
total: price.toFixed(6),
},
)
: webSearch &&
webSearchCallCount > 0 &&
fileSearch &&
fileSearchCallCount > 0 &&
!image
? i18next.t(
'输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} + Web搜索 {{webSearchCallCount}}次 / 1K 次 * ${{webSearchPrice}} * {{ratio}}+ 文件搜索 {{fileSearchCallCount}}次 / 1K 次 * ${{fileSearchPrice}} * {{ratio}}= ${{total}}',
{
input: inputTokens,
price: inputRatioPrice,
completion: completionTokens,
compPrice: completionRatioPrice,
ratio: groupRatio,
webSearchCallCount,
webSearchPrice,
fileSearchCallCount,
fileSearchPrice,
total: price.toFixed(6),
},
)
: i18next.t(
'输入 {{input}} tokens / 1M tokens * ${{price}} + 输出 {{completion}} tokens / 1M tokens * ${{compPrice}} * 分组 {{ratio}} = ${{total}}',
{
input: inputTokens,
price: inputRatioPrice,
completion: completionTokens,
compPrice: completionRatioPrice,
ratio: groupRatio,
total: price.toFixed(6),
},
)}
</p> </p>
<p>{i18next.t('仅供参考,以实际扣费为准')}</p> <p>{i18next.t('仅供参考,以实际扣费为准')}</p>
</article> </article>
@@ -448,33 +522,56 @@ export function renderLogContent(
user_group_ratio, user_group_ratio,
image = false, image = false,
imageRatio = 1.0, imageRatio = 1.0,
useUserGroupRatio = undefined useUserGroupRatio = undefined,
webSearch = false,
webSearchCallCount = 0,
fileSearch = false,
fileSearchCallCount = 0,
) { ) {
const ratioLabel = useUserGroupRatio ? i18next.t('专属倍率') : i18next.t('分组倍率'); const ratioLabel = useUserGroupRatio
? i18next.t('专属倍率')
: i18next.t('分组倍率');
const ratio = useUserGroupRatio ? user_group_ratio : groupRatio; const ratio = useUserGroupRatio ? user_group_ratio : groupRatio;
if (modelPrice !== -1) { if (modelPrice !== -1) {
return i18next.t('模型价格 ${{price}}{{ratioType}} {{ratio}}', { return i18next.t('模型价格 ${{price}}{{ratioType}} {{ratio}}', {
price: modelPrice, price: modelPrice,
ratioType: ratioLabel, ratioType: ratioLabel,
ratio ratio,
}); });
} else { } else {
if (image) { if (image) {
return i18next.t('模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},图片输入倍率 {{imageRatio}}{{ratioType}} {{ratio}}', { return i18next.t(
modelRatio: modelRatio, '模型倍率 {{modelRatio}},输出倍率 {{completionRatio}},图片输入倍率 {{imageRatio}}{{ratioType}} {{ratio}}',
completionRatio: completionRatio, {
imageRatio: imageRatio, modelRatio: modelRatio,
ratioType: ratioLabel, completionRatio: completionRatio,
ratio imageRatio: imageRatio,
}); ratioType: ratioLabel,
ratio,
},
);
} else if (webSearch) {
return i18next.t(
'模型倍率 {{modelRatio}},输出倍率 {{completionRatio}}{{ratioType}} {{ratio}}Web 搜索调用 {{webSearchCallCount}} 次',
{
modelRatio: modelRatio,
completionRatio: completionRatio,
ratioType: ratioLabel,
ratio,
webSearchCallCount,
},
);
} else { } else {
return i18next.t('模型倍率 {{modelRatio}},输出倍率 {{completionRatio}}{{ratioType}} {{ratio}}', { return i18next.t(
modelRatio: modelRatio, '模型倍率 {{modelRatio}},输出倍率 {{completionRatio}}{{ratioType}} {{ratio}}',
completionRatio: completionRatio, {
ratioType: ratioLabel, modelRatio: modelRatio,
ratio completionRatio: completionRatio,
}); ratioType: ratioLabel,
ratio,
},
);
} }
} }
} }

View File

@@ -493,6 +493,7 @@
"默认": "default", "默认": "default",
"图片演示": "Image demo", "图片演示": "Image demo",
"注意系统请求的时模型名称中的点会被剔除例如gpt-4.1会请求为gpt-41所以在Azure部署的时候部署模型名称需要手动改为gpt-41": "Note that the dot in the model name requested by the system will be removed, for example: gpt-4.1 will be requested as gpt-41, so when deploying on Azure, the deployment model name needs to be manually changed to gpt-41", "注意系统请求的时模型名称中的点会被剔除例如gpt-4.1会请求为gpt-41所以在Azure部署的时候部署模型名称需要手动改为gpt-41": "Note that the dot in the model name requested by the system will be removed, for example: gpt-4.1 will be requested as gpt-41, so when deploying on Azure, the deployment model name needs to be manually changed to gpt-41",
"2025年5月10日后添加的渠道不需要再在部署的时候移除模型名称中的\".\"": "After May 10, 2025, channels added do not need to remove the dot in the model name during deployment",
"模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!", "模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!",
"取消无限额度": "Cancel unlimited quota", "取消无限额度": "Cancel unlimited quota",
"取消": "Cancel", "取消": "Cancel",
@@ -1085,7 +1086,7 @@
"没有账户?": "No account? ", "没有账户?": "No account? ",
"请输入 AZURE_OPENAI_ENDPOINT例如https://docs-test-001.openai.azure.com": "Please enter AZURE_OPENAI_ENDPOINT, e.g.: https://docs-test-001.openai.azure.com", "请输入 AZURE_OPENAI_ENDPOINT例如https://docs-test-001.openai.azure.com": "Please enter AZURE_OPENAI_ENDPOINT, e.g.: https://docs-test-001.openai.azure.com",
"默认 API 版本": "Default API Version", "默认 API 版本": "Default API Version",
"请输入默认 API 版本例如2024-12-01-preview": "Please enter default API version, e.g.: 2024-12-01-preview.", "请输入默认 API 版本例如2025-04-01-preview": "Please enter default API version, e.g.: 2025-04-01-preview.",
"请为渠道命名": "Please name the channel", "请为渠道命名": "Please name the channel",
"请选择可以使用该渠道的分组": "Please select groups that can use this channel", "请选择可以使用该渠道的分组": "Please select groups that can use this channel",
"请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit Group ratios in system settings to add new groups:", "请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit Group ratios in system settings to add new groups:",

View File

@@ -24,7 +24,8 @@ import {
TextArea, TextArea,
Checkbox, Checkbox,
Banner, Banner,
Modal, ImagePreview Modal,
ImagePreview,
} from '@douyinfe/semi-ui'; } from '@douyinfe/semi-ui';
import { getChannelModels, loadChannelModels } from '../../components/utils.js'; import { getChannelModels, loadChannelModels } from '../../components/utils.js';
import { IconHelpCircle } from '@douyinfe/semi-icons'; import { IconHelpCircle } from '@douyinfe/semi-icons';
@@ -306,7 +307,7 @@ const EditChannel = (props) => {
fetchModels().then(); fetchModels().then();
fetchGroups().then(); fetchGroups().then();
if (isEdit) { if (isEdit) {
loadChannel().then(() => { }); loadChannel().then(() => {});
} else { } else {
setInputs(originInputs); setInputs(originInputs);
let localModels = getChannelModels(inputs.type); let localModels = getChannelModels(inputs.type);
@@ -477,24 +478,26 @@ const EditChannel = (props) => {
type={'warning'} type={'warning'}
description={ description={
<> <>
{t('注意系统请求的时模型名称中的点会被剔除例如gpt-4.1会请求为gpt-41所以在Azure部署的时候部署模型名称需要手动改为gpt-41')} {t(
<br /> '2025年5月10日后添加的渠道不需要再在部署的时候移除模型名称中的"."',
<Typography.Text )}
style={{ {/*<br />*/}
color: 'rgba(var(--semi-blue-5), 1)', {/*<Typography.Text*/}
userSelect: 'none', {/* style={{*/}
cursor: 'pointer', {/* color: 'rgba(var(--semi-blue-5), 1)',*/}
}} {/* userSelect: 'none',*/}
onClick={() => { {/* cursor: 'pointer',*/}
setModalImageUrl( {/* }}*/}
'/azure_model_name.png', {/* onClick={() => {*/}
); {/* setModalImageUrl(*/}
setIsModalOpenurl(true) {/* '/azure_model_name.png',*/}
{/* );*/}
{/* setIsModalOpenurl(true)*/}
}} {/* }}*/}
> {/*>*/}
{t('查看示例')} {/* {t('查看示例')}*/}
</Typography.Text> {/*</Typography.Text>*/}
</> </>
} }
></Banner> ></Banner>
@@ -522,7 +525,7 @@ const EditChannel = (props) => {
<Input <Input
label={t('默认 API 版本')} label={t('默认 API 版本')}
name='azure_other' name='azure_other'
placeholder={t('请输入默认 API 版本例如2024-12-01-preview')} placeholder={t('请输入默认 API 版本例如2025-04-01-preview')}
onChange={(value) => { onChange={(value) => {
handleInputChange('other', value); handleInputChange('other', value);
}} }}
@@ -584,25 +587,35 @@ const EditChannel = (props) => {
value={inputs.name} value={inputs.name}
autoComplete='new-password' autoComplete='new-password'
/> />
{inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && inputs.type !== 45 && ( {inputs.type !== 3 &&
<> inputs.type !== 8 &&
<div style={{ marginTop: 10 }}> inputs.type !== 22 &&
<Typography.Text strong>{t('API地址')}</Typography.Text> inputs.type !== 36 &&
</div> inputs.type !== 45 && (
<Tooltip content={t('对于官方渠道new-api已经内置地址除非是第三方代理站点或者Azure的特殊接入地址否则不需要填写')}> <>
<Input <div style={{ marginTop: 10 }}>
label={t('API地址')} <Typography.Text strong>{t('API地址')}</Typography.Text>
name="base_url" </div>
placeholder={t('此项可选用于通过自定义API地址来进行 API 调用,末尾不要带/v1和/')} <Tooltip
onChange={(value) => { content={t(
handleInputChange('base_url', value); '对于官方渠道new-api已经内置地址除非是第三方代理站点或者Azure的特殊接入地址否则不需要填写',
}} )}
value={inputs.base_url} >
autoComplete="new-password" <Input
/> label={t('API地址')}
</Tooltip> name='base_url'
</> placeholder={t(
)} '此项可选用于通过自定义API地址来进行 API 调用,末尾不要带/v1和/',
)}
onChange={(value) => {
handleInputChange('base_url', value);
}}
value={inputs.base_url}
autoComplete='new-password'
/>
</Tooltip>
</>
)}
<div style={{ marginTop: 10 }}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>{t('密钥')}</Typography.Text> <Typography.Text strong>{t('密钥')}</Typography.Text>
</div> </div>
@@ -761,10 +774,10 @@ const EditChannel = (props) => {
name='other' name='other'
placeholder={t( placeholder={t(
'请输入部署地区例如us-central1\n支持使用模型映射格式\n' + '请输入部署地区例如us-central1\n支持使用模型映射格式\n' +
'{\n' + '{\n' +
' "default": "us-central1",\n' + ' "default": "us-central1",\n' +
' "claude-3-5-sonnet-20240620": "europe-west1"\n' + ' "claude-3-5-sonnet-20240620": "europe-west1"\n' +
'}', '}',
)} )}
autosize={{ minRows: 2 }} autosize={{ minRows: 2 }}
onChange={(value) => { onChange={(value) => {
@@ -825,6 +838,22 @@ const EditChannel = (props) => {
/> />
</> </>
)} )}
{inputs.type === 49 && (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>智能体ID</Typography.Text>
</div>
<Input
name='other'
placeholder={'请输入智能体ID例如7342866812345'}
onChange={(value) => {
handleInputChange('other', value);
}}
value={inputs.other}
autoComplete='new-password'
/>
</>
)}
<div style={{ marginTop: 10 }}> <div style={{ marginTop: 10 }}>
<Typography.Text strong>{t('模型')}</Typography.Text> <Typography.Text strong>{t('模型')}</Typography.Text>
</div> </div>

View File

@@ -6,6 +6,7 @@ import {
showError, showError,
showSuccess, showSuccess,
showWarning, showWarning,
verifyJSON,
} from '../../../helpers'; } from '../../../helpers';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@@ -18,6 +19,7 @@ export default function RequestRateLimit(props) {
ModelRequestRateLimitCount: -1, ModelRequestRateLimitCount: -1,
ModelRequestRateLimitSuccessCount: 1000, ModelRequestRateLimitSuccessCount: 1000,
ModelRequestRateLimitDurationMinutes: 1, ModelRequestRateLimitDurationMinutes: 1,
ModelRequestRateLimitGroup: '',
}); });
const refForm = useRef(); const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs); const [inputsRow, setInputsRow] = useState(inputs);
@@ -46,6 +48,13 @@ export default function RequestRateLimit(props) {
if (res.includes(undefined)) if (res.includes(undefined))
return showError(t('部分保存失败,请重试')); return showError(t('部分保存失败,请重试'));
} }
for (let i = 0; i < res.length; i++) {
if (!res[i].data.success) {
return showError(res[i].data.message);
}
}
showSuccess(t('保存成功')); showSuccess(t('保存成功'));
props.refresh(); props.refresh();
}) })
@@ -147,6 +156,41 @@ export default function RequestRateLimit(props) {
/> />
</Col> </Col>
</Row> </Row>
<Row>
<Col xs={24} sm={16}>
<Form.TextArea
label={t('分组速率限制')}
placeholder={t(
'{\n "default": [200, 100],\n "vip": [0, 1000]\n}',
)}
field={'ModelRequestRateLimitGroup'}
autosize={{ minRows: 5, maxRows: 15 }}
trigger='blur'
stopValidateWithError
rules={[
{
validator: (rule, value) => verifyJSON(value),
message: t('不是合法的 JSON 字符串'),
},
]}
extraText={
<div>
<p style={{ marginBottom: -15 }}>{t('说明:')}</p>
<ul>
<li>{t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}</li>
<li>{t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}</li>
<li>{t('[最多请求次数]必须大于等于0[最多请求完成次数]必须大于等于1。')}</li>
<li>{t('分组速率配置优先级高于全局速率限制。')}</li>
<li>{t('限制周期统一使用上方配置的“限制周期”值。')}</li>
</ul>
</div>
}
onChange={(value) => {
setInputs({ ...inputs, ModelRequestRateLimitGroup: value });
}}
/>
</Col>
</Row>
<Row> <Row>
<Button size='default' onClick={onSubmit}> <Button size='default' onClick={onSubmit}>
{t('保存模型速率限制')} {t('保存模型速率限制')}